├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── VERSION ├── VGGFace2-HQ.png ├── VGGFace2-HQ.pptx ├── docs └── img │ ├── 2.png │ ├── VGGFace2-HQ.png │ ├── girl2-RGB.png │ ├── girl2.gif │ ├── logo.png │ ├── simswap.png │ ├── title.png │ └── vggface2_hq_compare.png ├── gfpgan ├── __init__.py ├── archs │ ├── __init__.py │ ├── arcface_arch.py │ ├── gfpganv1_arch.py │ ├── gfpganv1_clean_arch.py │ └── stylegan2_clean_arch.py ├── data │ ├── __init__.py │ └── ffhq_degradation_dataset.py ├── models │ ├── __init__.py │ └── gfpgan_model.py ├── train.py ├── utils.py └── weights │ └── README.md ├── inference_gfpgan.py ├── insightface_func ├── __init__.py ├── face_detect_crop.py ├── face_detect_crop_ffhq_newarcAlign.py └── utils │ ├── face_align.py │ └── face_align_ffhqandnewarc.py ├── options ├── train_gfpgan_v1.yml └── train_gfpgan_v1_simple.yml ├── requirements.txt ├── scripts ├── crop_align_vggface2_FFHQalign.py ├── crop_align_vggface2_FFHQalignandNewarcalign.py ├── inference_gfpgan_forvggface2.py └── vggface_dataset.py ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # ignored folders 2 | datasets/* 3 | experiments/* 4 | results/* 5 | tb_logger/* 6 | wandb/* 7 | tmp/* 8 | 9 | version.py 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | .ppt 141 | .pptx -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # flake8 3 | - repo: https://github.com/PyCQA/flake8 4 | rev: 3.8.3 5 | hooks: 6 | - id: flake8 7 | args: ["--config=setup.cfg", "--ignore=W504, W503"] 8 | 9 | # modify known_third_party 10 | - repo: https://github.com/asottile/seed-isort-config 11 | rev: v2.2.0 12 | hooks: 13 | - id: seed-isort-config 14 | 15 | # isort 16 | - repo: https://github.com/timothycrosley/isort 17 | rev: 5.2.2 18 | hooks: 19 | - id: isort 20 | 21 | # yapf 22 | - repo: https://github.com/pre-commit/mirrors-yapf 23 | rev: v0.30.0 24 | hooks: 25 | - id: yapf 26 | 27 | # codespell 28 | - repo: https://github.com/codespell-project/codespell 29 | rev: v2.1.0 30 | hooks: 31 | - id: codespell 32 | 33 | # pre-commit-hooks 34 | - repo: https://github.com/pre-commit/pre-commit-hooks 35 | rev: v3.2.0 36 | hooks: 37 | - id: trailing-whitespace # Trim trailing whitespace 38 | - id: check-yaml # Attempt to load all yaml files to verify syntax 39 | - id: check-merge-conflict # Check for files that contain merge conflict strings 40 | - id: double-quote-string-fixer # Replace double quoted strings with single quoted strings 41 | - id: end-of-file-fixer # Make sure files end in a newline and only a newline 42 | - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0 43 | - id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*- 44 | args: ["--remove"] 45 | - id: mixed-line-ending # Replace or check mixed line ending 46 | args: ["--fix=lf"] 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 
14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. 
For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. 
Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 
218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. 
THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. 
No term or condition of this Public License will be waived and no
372 | failure to comply consented to unless expressly agreed to by the
373 | Licensor.
374 |
375 | d. Nothing in this Public License constitutes or may be interpreted
376 | as a limitation upon, or waiver of, any privileges and immunities
377 | that apply to the Licensor or You, including from the legal
378 | processes of any jurisdiction or authority.
379 |
380 | =======================================================================
381 |
382 | Creative Commons is not a party to its public
383 | licenses. Notwithstanding, Creative Commons may elect to apply one of
384 | its public licenses to material it publishes and in those instances
385 | will be considered the “Licensor.” The text of the Creative Commons
386 | public licenses is dedicated to the public domain under the CC0 Public
387 | Domain Dedication. Except for the limited purpose of indicating that
388 | material is shared under a Creative Commons public license or as
389 | otherwise permitted by the Creative Commons policies published at
390 | creativecommons.org/policies, Creative Commons does not authorize the
391 | use of the trademark "Creative Commons" or any other trademark or logo
392 | of Creative Commons without its prior written consent including,
393 | without limitation, in connection with any unauthorized modifications
394 | to any of its public licenses or any other arrangements,
395 | understandings, or agreements concerning use of licensed material. For
396 | the avoidance of doubt, this paragraph does not form part of the
397 | public licenses.
398 |
399 | Creative Commons may be contacted at creativecommons.org.
400 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include assets/*
2 | include inputs/*
3 | include scripts/*.py
4 | include inference_gfpgan.py
5 | include VERSION
6 | include LICENSE
7 | include requirements.txt
8 | include gfpgan/weights/README.md
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VGGFace2-HQ
2 |
3 | Related paper: [TPAMI](https://github.com/neuralchen/SimSwapPlus)
4 |
5 | ## The first open-source high-resolution dataset for face swapping!!!
6 |
7 | A high-resolution version of [VGGFace2](https://github.com/ox-vgg/vgg_face2) for academic face editing purposes. This project uses [GFPGAN](https://github.com/TencentARC/GFPGAN) for image restoration and [insightface](https://github.com/deepinsight/insightface) for data preprocessing (crop and align).
8 |
9 | [![logo](./VGGFace2-HQ.png)](https://github.com/NNNNAI/VGGFace2-HQ)
10 |
11 | We provide a download link for the data, as well as guidance on how to generate the VGGFace2-HQ dataset from scratch.
12 |
13 | If you find this project useful, please star it. That is the greatest appreciation of our work.
14 |
15 |
16 |
17 | # Get the VGGFace2-HQ dataset from the cloud!
18 |
19 | We have uploaded the VGGFace2-HQ dataset to cloud storage, and you can download it from the links below.
20 |
21 | ### Google Drive
22 |
23 | [[Google Drive]](https://drive.google.com/drive/folders/1ZHy7jrd6cGb2lUa4qYugXe41G_Ef9Ibw?usp=sharing)
24 |
25 | ***We are especially grateful to [Kairui Feng](https://scholar.google.com.hk/citations?user=4N5hE8YAAAAJ&hl=zh-CN), a PhD student at Princeton University.***
26 |
27 | ### Baidu Drive
28 |
29 | [[Baidu Drive]](https://pan.baidu.com/s/1LwPFhgbdBj5AeoPTXgoqDw) Password: ```sjtu```
30 |
31 |
32 | # Generate the HQ dataset by yourself (if you want to do so)
33 | ## Preparation
34 | ### Installation
35 | **We highly recommend that you use Anaconda for installation.**
36 | ```
37 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch
38 | pip install insightface==0.2.1 onnxruntime
39 | (optional) pip install onnxruntime-gpu==1.2.0
40 |
41 | pip install basicsr
42 | pip install facexlib
43 | pip install -r requirements.txt
44 | python setup.py develop
45 | ```
46 | - The PyTorch and CUDA versions above are the recommended ones; other versions may also work.
47 | - Using a different version of insightface is not recommended. Please use this specific version.
48 | - These settings have been tested on both Windows and Ubuntu.
49 | ### Pretrained model
50 | - We use the face detection and alignment methods from **[insightface](https://github.com/deepinsight/insightface)** for image preprocessing. Please download the relevant files from [this link](https://onedrive.live.com/?authkey=%21ADJ0aAOSsc90neY&cid=4A83B6B633B029CC&id=4A83B6B633B029CC%215837&parId=4A83B6B633B029CC%215834&action=locate) and unzip them to **./insightface_func/models**.
51 | - Download [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth) from the GFPGAN official repo. Place "GFPGANCleanv1-NoCE-C2.pth" in **./experiments/pretrained_models**.
52 |
53 | ### Data preparation
54 | - Download the VGGFace2 dataset from [VGGFace2 Dataset for Face Recognition](https://github.com/ox-vgg/vgg_face2)
55 |
56 | ## Inference
57 |
58 | - First, preprocess all photos in VGGFace2, that is, detect faces and align them to the same alignment format as the FFHQ dataset.
59 | ```
60 | python scripts/crop_align_vggface2_FFHQalign.py --input_dir $DATAPATH$/VGGface2/train --output_dir_ffhqalign $ALIGN_OUTDIR$ --mode ffhq --crop_size 256
61 | ```
62 | - Then, run image restoration with GFPGAN on the processed photos.
63 | ```
64 | python scripts/inference_gfpgan_forvggface2.py --input_path $ALIGN_OUTDIR$ --batchSize 8 --save_dir $HQ_OUTDIR$
65 | ```
66 |
67 | ## Citation
68 |
69 | If you find our work useful in your research, please consider citing:
70 |
71 | ```
72 | @Article{simswapplusplus,
73 |   author  = {Xuanhong Chen and
74 |              Bingbing Ni and
75 |              Yutian Liu and
76 |              Naiyuan Liu and
77 |              Zhilin Zeng and
78 |              Hang Wang},
79 |   title   = {SimSwap++: Towards Faster and High-Quality Identity Swapping},
80 |   journal = {{IEEE} Trans. Pattern Anal. Mach. Intell.},
81 |   volume  = {46},
82 |   number  = {1},
83 |   pages   = {576--592},
84 |   year    = {2024}
85 | }
86 | ```
87 |
88 | ## Related Projects
89 |
90 | ***Please visit our popular face swapping project***
91 |
92 | [![logo](./docs/img/simswap.png)](https://github.com/neuralchen/SimSwap)
93 |
94 | ***Please also visit our ACMMM2020 high-quality style transfer project***
95 |
96 | [![logo](./docs/img/logo.png)](https://github.com/neuralchen/ASMAGAN)
97 |
98 | [![title](/docs/img/title.png)](https://github.com/neuralchen/ASMAGAN)
99 |
100 | ***Please visit our AAAI2021 sketch-based rendering project***
101 |
102 | [![logo](./docs/img/girl2.gif)](https://github.com/TZYSJTU/Sketch-Generation-with-Drawing-Process-Guided-by-Vector-Flow-and-Grayscale)
103 | [![title](/docs/img/girl2-RGB.png)](https://github.com/TZYSJTU/Sketch-Generation-with-Drawing-Process-Guided-by-Vector-Flow-and-Grayscale)
104 |
105 | Learn about our other projects:
106 |
107 | [[VGGFace2-HQ]](https://github.com/NNNNAI/VGGFace2-HQ);
108 |
109 | [[RainNet]](https://neuralchen.github.io/RainNet);
110 |
111 | [[Sketch Generation]](https://github.com/TZYSJTU/Sketch-Generation-with-Drawing-Process-Guided-by-Vector-Flow-and-Grayscale);
112 |
113 | [[CooGAN]](https://github.com/neuralchen/CooGAN);
114 |
115 | [[Knowledge Style Transfer]](https://github.com/AceSix/Knowledge_Transfer);
116 |
117 | [[SimSwap]](https://github.com/neuralchen/SimSwap);
118 |
119 | [[ASMA-GAN]](https://github.com/neuralchen/ASMAGAN);
120 |
121 | [[SNGAN-Projection-pytorch]](https://github.com/neuralchen/SNGAN_Projection)
122 |
123 | [[Pretrained_VGG19]](https://github.com/neuralchen/Pretrained_VGG19).
124 |
125 |
126 |
127 | # Acknowledgements
128 |
129 |
130 | * [GFPGAN](https://github.com/TencentARC/GFPGAN)
131 | * [Insightface](https://github.com/deepinsight/insightface)
132 | * [VGGFace2 Dataset for Face Recognition](https://github.com/ox-vgg/vgg_face2)
133 |
134 |
--------------------------------------------------------------------------------
/VERSION:
--------------------------------------------------------------------------------
1 | 0.2.3
2 |
--------------------------------------------------------------------------------
/VGGFace2-HQ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/VGGFace2-HQ.png
--------------------------------------------------------------------------------
/VGGFace2-HQ.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/VGGFace2-HQ.pptx
--------------------------------------------------------------------------------
/docs/img/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/2.png
--------------------------------------------------------------------------------
/docs/img/VGGFace2-HQ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/VGGFace2-HQ.png
--------------------------------------------------------------------------------
/docs/img/girl2-RGB.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/girl2-RGB.png -------------------------------------------------------------------------------- /docs/img/girl2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/girl2.gif -------------------------------------------------------------------------------- /docs/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/logo.png -------------------------------------------------------------------------------- /docs/img/simswap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/simswap.png -------------------------------------------------------------------------------- /docs/img/title.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/title.png -------------------------------------------------------------------------------- /docs/img/vggface2_hq_compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/vggface2_hq_compare.png -------------------------------------------------------------------------------- /gfpgan/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .archs import * 3 | from .data import * 4 | from .models import * 5 | from .utils import * 6 | from .version import __gitsha__, __version__ 7 | -------------------------------------------------------------------------------- /gfpgan/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import arch modules for registry 6 | # scan all the files that end with '_arch.py' under the archs folder 7 | arch_folder = osp.dirname(osp.abspath(__file__)) 8 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 9 | # import all the arch modules 10 | _arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames] 11 | -------------------------------------------------------------------------------- /gfpgan/archs/arcface_arch.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from basicsr.utils.registry import ARCH_REGISTRY 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | expansion = 1 12 | 13 | def __init__(self, inplanes, planes, stride=1, downsample=None): 14 | super(BasicBlock, self).__init__() 15 | self.conv1 = conv3x3(inplanes, planes, stride) 16 | self.bn1 = nn.BatchNorm2d(planes) 17 | self.relu = nn.ReLU(inplace=True) 18 | self.conv2 = conv3x3(planes, planes) 19 | self.bn2 = 
nn.BatchNorm2d(planes) 20 | self.downsample = downsample 21 | self.stride = stride 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | 33 | if self.downsample is not None: 34 | residual = self.downsample(x) 35 | 36 | out += residual 37 | out = self.relu(out) 38 | 39 | return out 40 | 41 | 42 | class IRBlock(nn.Module): 43 | expansion = 1 44 | 45 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): 46 | super(IRBlock, self).__init__() 47 | self.bn0 = nn.BatchNorm2d(inplanes) 48 | self.conv1 = conv3x3(inplanes, inplanes) 49 | self.bn1 = nn.BatchNorm2d(inplanes) 50 | self.prelu = nn.PReLU() 51 | self.conv2 = conv3x3(inplanes, planes, stride) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.downsample = downsample 54 | self.stride = stride 55 | self.use_se = use_se 56 | if self.use_se: 57 | self.se = SEBlock(planes) 58 | 59 | def forward(self, x): 60 | residual = x 61 | out = self.bn0(x) 62 | out = self.conv1(out) 63 | out = self.bn1(out) 64 | out = self.prelu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | if self.use_se: 69 | out = self.se(out) 70 | 71 | if self.downsample is not None: 72 | residual = self.downsample(x) 73 | 74 | out += residual 75 | out = self.prelu(out) 76 | 77 | return out 78 | 79 | 80 | class Bottleneck(nn.Module): 81 | expansion = 4 82 | 83 | def __init__(self, inplanes, planes, stride=1, downsample=None): 84 | super(Bottleneck, self).__init__() 85 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 86 | self.bn1 = nn.BatchNorm2d(planes) 87 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 88 | self.bn2 = nn.BatchNorm2d(planes) 89 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 90 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.downsample = downsample 93 | self.stride = stride 94 | 95 | def forward(self, x): 96 | residual = x 97 | 98 | out = self.conv1(x) 99 | out = self.bn1(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv2(out) 103 | out = self.bn2(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv3(out) 107 | out = self.bn3(out) 108 | 109 | if self.downsample is not None: 110 | residual = self.downsample(x) 111 | 112 | out += residual 113 | out = self.relu(out) 114 | 115 | return out 116 | 117 | 118 | class SEBlock(nn.Module): 119 | 120 | def __init__(self, channel, reduction=16): 121 | super(SEBlock, self).__init__() 122 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 123 | self.fc = nn.Sequential( 124 | nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel), 125 | nn.Sigmoid()) 126 | 127 | def forward(self, x): 128 | b, c, _, _ = x.size() 129 | y = self.avg_pool(x).view(b, c) 130 | y = self.fc(y).view(b, c, 1, 1) 131 | return x * y 132 | 133 | 134 | @ARCH_REGISTRY.register() 135 | class ResNetArcFace(nn.Module): 136 | 137 | def __init__(self, block, layers, use_se=True): 138 | if block == 'IRBlock': 139 | block = IRBlock 140 | self.inplanes = 64 141 | self.use_se = use_se 142 | super(ResNetArcFace, self).__init__() 143 | self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False) 144 | self.bn1 = nn.BatchNorm2d(64) 145 | self.prelu = nn.PReLU() 146 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 147 | self.layer1 = self._make_layer(block, 64, layers[0]) 148 | 
self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 149 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 150 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 151 | self.bn4 = nn.BatchNorm2d(512) 152 | self.dropout = nn.Dropout() 153 | self.fc5 = nn.Linear(512 * 8 * 8, 512) 154 | self.bn5 = nn.BatchNorm1d(512) 155 | 156 | for m in self.modules(): 157 | if isinstance(m, nn.Conv2d): 158 | nn.init.xavier_normal_(m.weight) 159 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 160 | nn.init.constant_(m.weight, 1) 161 | nn.init.constant_(m.bias, 0) 162 | elif isinstance(m, nn.Linear): 163 | nn.init.xavier_normal_(m.weight) 164 | nn.init.constant_(m.bias, 0) 165 | 166 | def _make_layer(self, block, planes, blocks, stride=1): 167 | downsample = None 168 | if stride != 1 or self.inplanes != planes * block.expansion: 169 | downsample = nn.Sequential( 170 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 171 | nn.BatchNorm2d(planes * block.expansion), 172 | ) 173 | layers = [] 174 | layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) 175 | self.inplanes = planes 176 | for _ in range(1, blocks): 177 | layers.append(block(self.inplanes, planes, use_se=self.use_se)) 178 | 179 | return nn.Sequential(*layers) 180 | 181 | def forward(self, x): 182 | x = self.conv1(x) 183 | x = self.bn1(x) 184 | x = self.prelu(x) 185 | x = self.maxpool(x) 186 | 187 | x = self.layer1(x) 188 | x = self.layer2(x) 189 | x = self.layer3(x) 190 | x = self.layer4(x) 191 | x = self.bn4(x) 192 | x = self.dropout(x) 193 | x = x.view(x.size(0), -1) 194 | x = self.fc5(x) 195 | x = self.bn5(x) 196 | 197 | return x 198 | -------------------------------------------------------------------------------- /gfpgan/archs/gfpganv1_arch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU, 5 | StyleGAN2Generator) 6 | from basicsr.ops.fused_act import FusedLeakyReLU 7 | from basicsr.utils.registry import ARCH_REGISTRY 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | 12 | class StyleGAN2GeneratorSFT(StyleGAN2Generator): 13 | """StyleGAN2 Generator. 14 | 15 | Args: 16 | out_size (int): The spatial size of outputs. 17 | num_style_feat (int): Channel number of style features. Default: 512. 18 | num_mlp (int): Layer number of MLP style layers. Default: 8. 19 | channel_multiplier (int): Channel multiplier for large networks of 20 | StyleGAN2. Default: 2. 21 | resample_kernel (list[int]): A list indicating the 1D resample kernel 22 | magnitude. A cross production will be applied to extent 1D resample 23 | kernel to 2D resample kernel. Default: [1, 3, 3, 1]. 24 | lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. 
25 | """ 26 | 27 | def __init__(self, 28 | out_size, 29 | num_style_feat=512, 30 | num_mlp=8, 31 | channel_multiplier=2, 32 | resample_kernel=(1, 3, 3, 1), 33 | lr_mlp=0.01, 34 | narrow=1, 35 | sft_half=False): 36 | super(StyleGAN2GeneratorSFT, self).__init__( 37 | out_size, 38 | num_style_feat=num_style_feat, 39 | num_mlp=num_mlp, 40 | channel_multiplier=channel_multiplier, 41 | resample_kernel=resample_kernel, 42 | lr_mlp=lr_mlp, 43 | narrow=narrow) 44 | self.sft_half = sft_half 45 | 46 | def forward(self, 47 | styles, 48 | conditions, 49 | input_is_latent=False, 50 | noise=None, 51 | randomize_noise=True, 52 | truncation=1, 53 | truncation_latent=None, 54 | inject_index=None, 55 | return_latents=False): 56 | """Forward function for StyleGAN2Generator. 57 | 58 | Args: 59 | styles (list[Tensor]): Sample codes of styles. 60 | input_is_latent (bool): Whether input is latent style. 61 | Default: False. 62 | noise (Tensor | None): Input noise or None. Default: None. 63 | randomize_noise (bool): Randomize noise, used when 'noise' is 64 | False. Default: True. 65 | truncation (float): TODO. Default: 1. 66 | truncation_latent (Tensor | None): TODO. Default: None. 67 | inject_index (int | None): The injection index for mixing noise. 68 | Default: None. 69 | return_latents (bool): Whether to return style latents. 70 | Default: False. 71 | """ 72 | # style codes -> latents with Style MLP layer 73 | if not input_is_latent: 74 | styles = [self.style_mlp(s) for s in styles] 75 | # noises 76 | if noise is None: 77 | if randomize_noise: 78 | noise = [None] * self.num_layers # for each style conv layer 79 | else: # use the stored noise 80 | noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] 81 | # style truncation 82 | if truncation < 1: 83 | style_truncation = [] 84 | for style in styles: 85 | style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) 86 | styles = style_truncation 87 | # get style latent with injection 88 | if len(styles) == 1: 89 | inject_index = self.num_latent 90 | 91 | if styles[0].ndim < 3: 92 | # repeat latent code for all the layers 93 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 94 | else: # used for encoder with different latent code for each layer 95 | latent = styles[0] 96 | elif len(styles) == 2: # mixing noises 97 | if inject_index is None: 98 | inject_index = random.randint(1, self.num_latent - 1) 99 | latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 100 | latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) 101 | latent = torch.cat([latent1, latent2], 1) 102 | 103 | # main generation 104 | out = self.constant_input(latent.shape[0]) 105 | out = self.style_conv1(out, latent[:, 0], noise=noise[0]) 106 | skip = self.to_rgb1(out, latent[:, 1]) 107 | 108 | i = 1 109 | for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], 110 | noise[2::2], self.to_rgbs): 111 | out = conv1(out, latent[:, i], noise=noise1) 112 | 113 | # the conditions may have fewer levels 114 | if i < len(conditions): 115 | # SFT part to combine the conditions 116 | if self.sft_half: 117 | out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) 118 | out_sft = out_sft * conditions[i - 1] + conditions[i] 119 | out = torch.cat([out_same, out_sft], dim=1) 120 | else: 121 | out = out * conditions[i - 1] + conditions[i] 122 | 123 | out = conv2(out, latent[:, i + 1], noise=noise2) 124 | skip = to_rgb(out, latent[:, i + 2], skip) 125 | i += 2 126 | 127 
| image = skip 128 | 129 | if return_latents: 130 | return image, latent 131 | else: 132 | return image, None 133 | 134 | 135 | class ConvUpLayer(nn.Module): 136 | """Conv Up Layer. Bilinear upsample + Conv. 137 | 138 | Args: 139 | in_channels (int): Channel number of the input. 140 | out_channels (int): Channel number of the output. 141 | kernel_size (int): Size of the convolving kernel. 142 | stride (int): Stride of the convolution. Default: 1 143 | padding (int): Zero-padding added to both sides of the input. 144 | Default: 0. 145 | bias (bool): If ``True``, adds a learnable bias to the output. 146 | Default: ``True``. 147 | bias_init_val (float): Bias initialized value. Default: 0. 148 | activate (bool): Whether use activateion. Default: True. 149 | """ 150 | 151 | def __init__(self, 152 | in_channels, 153 | out_channels, 154 | kernel_size, 155 | stride=1, 156 | padding=0, 157 | bias=True, 158 | bias_init_val=0, 159 | activate=True): 160 | super(ConvUpLayer, self).__init__() 161 | self.in_channels = in_channels 162 | self.out_channels = out_channels 163 | self.kernel_size = kernel_size 164 | self.stride = stride 165 | self.padding = padding 166 | self.scale = 1 / math.sqrt(in_channels * kernel_size**2) 167 | 168 | self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) 169 | 170 | if bias and not activate: 171 | self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val)) 172 | else: 173 | self.register_parameter('bias', None) 174 | 175 | # activation 176 | if activate: 177 | if bias: 178 | self.activation = FusedLeakyReLU(out_channels) 179 | else: 180 | self.activation = ScaledLeakyReLU(0.2) 181 | else: 182 | self.activation = None 183 | 184 | def forward(self, x): 185 | # bilinear upsample 186 | out = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 187 | # conv 188 | out = F.conv2d( 189 | out, 190 | self.weight * self.scale, 191 | bias=self.bias, 192 | stride=self.stride, 193 | padding=self.padding, 194 | ) 195 | # activation 196 | if self.activation is not None: 197 | out = self.activation(out) 198 | return out 199 | 200 | 201 | class ResUpBlock(nn.Module): 202 | """Residual block with upsampling. 203 | 204 | Args: 205 | in_channels (int): Channel number of the input. 206 | out_channels (int): Channel number of the output. 
207 | """ 208 | 209 | def __init__(self, in_channels, out_channels): 210 | super(ResUpBlock, self).__init__() 211 | 212 | self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True) 213 | self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True) 214 | self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False) 215 | 216 | def forward(self, x): 217 | out = self.conv1(x) 218 | out = self.conv2(out) 219 | skip = self.skip(x) 220 | out = (out + skip) / math.sqrt(2) 221 | return out 222 | 223 | 224 | @ARCH_REGISTRY.register() 225 | class GFPGANv1(nn.Module): 226 | """Unet + StyleGAN2 decoder with SFT.""" 227 | 228 | def __init__( 229 | self, 230 | out_size, 231 | num_style_feat=512, 232 | channel_multiplier=1, 233 | resample_kernel=(1, 3, 3, 1), 234 | decoder_load_path=None, 235 | fix_decoder=True, 236 | # for stylegan decoder 237 | num_mlp=8, 238 | lr_mlp=0.01, 239 | input_is_latent=False, 240 | different_w=False, 241 | narrow=1, 242 | sft_half=False): 243 | 244 | super(GFPGANv1, self).__init__() 245 | self.input_is_latent = input_is_latent 246 | self.different_w = different_w 247 | self.num_style_feat = num_style_feat 248 | 249 | unet_narrow = narrow * 0.5 250 | channels = { 251 | '4': int(512 * unet_narrow), 252 | '8': int(512 * unet_narrow), 253 | '16': int(512 * unet_narrow), 254 | '32': int(512 * unet_narrow), 255 | '64': int(256 * channel_multiplier * unet_narrow), 256 | '128': int(128 * channel_multiplier * unet_narrow), 257 | '256': int(64 * channel_multiplier * unet_narrow), 258 | '512': int(32 * channel_multiplier * unet_narrow), 259 | '1024': int(16 * channel_multiplier * unet_narrow) 260 | } 261 | 262 | self.log_size = int(math.log(out_size, 2)) 263 | first_out_size = 2**(int(math.log(out_size, 2))) 264 | 265 | self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True) 266 | 267 | # downsample 268 | in_channels = channels[f'{first_out_size}'] 269 | self.conv_body_down = nn.ModuleList() 270 | for i in range(self.log_size, 2, -1): 271 | out_channels = channels[f'{2**(i - 1)}'] 272 | self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel)) 273 | in_channels = out_channels 274 | 275 | self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True) 276 | 277 | # upsample 278 | in_channels = channels['4'] 279 | self.conv_body_up = nn.ModuleList() 280 | for i in range(3, self.log_size + 1): 281 | out_channels = channels[f'{2**i}'] 282 | self.conv_body_up.append(ResUpBlock(in_channels, out_channels)) 283 | in_channels = out_channels 284 | 285 | # to RGB 286 | self.toRGB = nn.ModuleList() 287 | for i in range(3, self.log_size + 1): 288 | self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0)) 289 | 290 | if different_w: 291 | linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat 292 | else: 293 | linear_out_channel = num_style_feat 294 | 295 | self.final_linear = EqualLinear( 296 | channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None) 297 | 298 | self.stylegan_decoder = StyleGAN2GeneratorSFT( 299 | out_size=out_size, 300 | num_style_feat=num_style_feat, 301 | num_mlp=num_mlp, 302 | channel_multiplier=channel_multiplier, 303 | resample_kernel=resample_kernel, 304 | lr_mlp=lr_mlp, 305 | narrow=narrow, 306 | sft_half=sft_half) 307 | 308 | if decoder_load_path: 309 | self.stylegan_decoder.load_state_dict( 310 | 
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema']) 311 | if fix_decoder: 312 | for _, param in self.stylegan_decoder.named_parameters(): 313 | param.requires_grad = False 314 | 315 | # for SFT 316 | self.condition_scale = nn.ModuleList() 317 | self.condition_shift = nn.ModuleList() 318 | for i in range(3, self.log_size + 1): 319 | out_channels = channels[f'{2**i}'] 320 | if sft_half: 321 | sft_out_channels = out_channels 322 | else: 323 | sft_out_channels = out_channels * 2 324 | self.condition_scale.append( 325 | nn.Sequential( 326 | EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), 327 | ScaledLeakyReLU(0.2), 328 | EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1))) 329 | self.condition_shift.append( 330 | nn.Sequential( 331 | EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), 332 | ScaledLeakyReLU(0.2), 333 | EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0))) 334 | 335 | def forward(self, 336 | x, 337 | return_latents=False, 338 | save_feat_path=None, 339 | load_feat_path=None, 340 | return_rgb=True, 341 | randomize_noise=True): 342 | conditions = [] 343 | unet_skips = [] 344 | out_rgbs = [] 345 | 346 | # encoder 347 | feat = self.conv_body_first(x) 348 | for i in range(self.log_size - 2): 349 | feat = self.conv_body_down[i](feat) 350 | unet_skips.insert(0, feat) 351 | 352 | feat = self.final_conv(feat) 353 | 354 | # style code 355 | style_code = self.final_linear(feat.view(feat.size(0), -1)) 356 | if self.different_w: 357 | style_code = style_code.view(style_code.size(0), -1, self.num_style_feat) 358 | 359 | # decode 360 | for i in range(self.log_size - 2): 361 | # add unet skip 362 | feat = feat + unet_skips[i] 363 | # ResUpLayer 364 | feat = self.conv_body_up[i](feat) 365 | # generate scale and shift for SFT layer 366 | scale = self.condition_scale[i](feat) 367 | conditions.append(scale.clone()) 368 | shift = self.condition_shift[i](feat) 369 | conditions.append(shift.clone()) 370 | # generate rgb images 371 | if return_rgb: 372 | out_rgbs.append(self.toRGB[i](feat)) 373 | 374 | if save_feat_path is not None: 375 | torch.save(conditions, save_feat_path) 376 | if load_feat_path is not None: 377 | conditions = torch.load(load_feat_path) 378 | conditions = [v.cuda() for v in conditions] 379 | 380 | # decoder 381 | image, _ = self.stylegan_decoder([style_code], 382 | conditions, 383 | return_latents=return_latents, 384 | input_is_latent=self.input_is_latent, 385 | randomize_noise=randomize_noise) 386 | 387 | return image, out_rgbs 388 | 389 | 390 | @ARCH_REGISTRY.register() 391 | class FacialComponentDiscriminator(nn.Module): 392 | 393 | def __init__(self): 394 | super(FacialComponentDiscriminator, self).__init__() 395 | 396 | self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) 397 | self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) 398 | self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) 399 | self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) 400 | self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True) 401 | self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False) 402 | 403 | def 
forward(self, x, return_feats=False): 404 | feat = self.conv1(x) 405 | feat = self.conv3(self.conv2(feat)) 406 | rlt_feats = [] 407 | if return_feats: 408 | rlt_feats.append(feat.clone()) 409 | feat = self.conv5(self.conv4(feat)) 410 | if return_feats: 411 | rlt_feats.append(feat.clone()) 412 | out = self.final_conv(feat) 413 | 414 | if return_feats: 415 | return out, rlt_feats 416 | else: 417 | return out, None 418 | -------------------------------------------------------------------------------- /gfpgan/archs/gfpganv1_clean_arch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from .stylegan2_clean_arch import StyleGAN2GeneratorClean 8 | 9 | 10 | class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean): 11 | """StyleGAN2 Generator. 12 | 13 | Args: 14 | out_size (int): The spatial size of outputs. 15 | num_style_feat (int): Channel number of style features. Default: 512. 16 | num_mlp (int): Layer number of MLP style layers. Default: 8. 17 | channel_multiplier (int): Channel multiplier for large networks of 18 | StyleGAN2. Default: 2. 19 | """ 20 | 21 | def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False): 22 | super(StyleGAN2GeneratorCSFT, self).__init__( 23 | out_size, 24 | num_style_feat=num_style_feat, 25 | num_mlp=num_mlp, 26 | channel_multiplier=channel_multiplier, 27 | narrow=narrow) 28 | 29 | self.sft_half = sft_half 30 | 31 | def forward(self, 32 | styles, 33 | conditions, 34 | input_is_latent=False, 35 | noise=None, 36 | randomize_noise=True, 37 | truncation=1, 38 | truncation_latent=None, 39 | inject_index=None, 40 | return_latents=False): 41 | """Forward function for StyleGAN2Generator. 42 | 43 | Args: 44 | styles (list[Tensor]): Sample codes of styles. 45 | input_is_latent (bool): Whether input is latent style. 46 | Default: False. 47 | noise (Tensor | None): Input noise or None. Default: None. 48 | randomize_noise (bool): Randomize noise, used when 'noise' is 49 | False. Default: True. 50 | truncation (float): TODO. Default: 1. 51 | truncation_latent (Tensor | None): TODO. Default: None. 52 | inject_index (int | None): The injection index for mixing noise. 53 | Default: None. 54 | return_latents (bool): Whether to return style latents. 55 | Default: False. 
56 | """ 57 | # style codes -> latents with Style MLP layer 58 | if not input_is_latent: 59 | styles = [self.style_mlp(s) for s in styles] 60 | # noises 61 | if noise is None: 62 | if randomize_noise: 63 | noise = [None] * self.num_layers # for each style conv layer 64 | else: # use the stored noise 65 | noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] 66 | # style truncation 67 | if truncation < 1: 68 | style_truncation = [] 69 | for style in styles: 70 | style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) 71 | styles = style_truncation 72 | # get style latent with injection 73 | if len(styles) == 1: 74 | inject_index = self.num_latent 75 | 76 | if styles[0].ndim < 3: 77 | # repeat latent code for all the layers 78 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 79 | else: # used for encoder with different latent code for each layer 80 | latent = styles[0] 81 | elif len(styles) == 2: # mixing noises 82 | if inject_index is None: 83 | inject_index = random.randint(1, self.num_latent - 1) 84 | latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 85 | latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) 86 | latent = torch.cat([latent1, latent2], 1) 87 | 88 | # main generation 89 | out = self.constant_input(latent.shape[0]) 90 | out = self.style_conv1(out, latent[:, 0], noise=noise[0]) 91 | skip = self.to_rgb1(out, latent[:, 1]) 92 | 93 | i = 1 94 | for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], 95 | noise[2::2], self.to_rgbs): 96 | out = conv1(out, latent[:, i], noise=noise1) 97 | 98 | # the conditions may have fewer levels 99 | if i < len(conditions): 100 | # SFT part to combine the conditions 101 | if self.sft_half: 102 | out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) 103 | out_sft = out_sft * conditions[i - 1] + conditions[i] 104 | out = torch.cat([out_same, out_sft], dim=1) 105 | else: 106 | out = out * conditions[i - 1] + conditions[i] 107 | 108 | out = conv2(out, latent[:, i + 1], noise=noise2) 109 | skip = to_rgb(out, latent[:, i + 2], skip) 110 | i += 2 111 | 112 | image = skip 113 | 114 | if return_latents: 115 | return image, latent 116 | else: 117 | return image, None 118 | 119 | 120 | class ResBlock(nn.Module): 121 | """Residual block with upsampling/downsampling. 122 | 123 | Args: 124 | in_channels (int): Channel number of the input. 125 | out_channels (int): Channel number of the output. 
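mode (str): Resampling mode: 'down' for 0.5x bilinear downsampling or 'up' for 2x bilinear upsampling. Default: 'down'.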
126 | """ 127 | 128 | def __init__(self, in_channels, out_channels, mode='down'): 129 | super(ResBlock, self).__init__() 130 | 131 | self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1) 132 | self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1) 133 | self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False) 134 | if mode == 'down': 135 | self.scale_factor = 0.5 136 | elif mode == 'up': 137 | self.scale_factor = 2 138 | 139 | def forward(self, x): 140 | out = F.leaky_relu_(self.conv1(x), negative_slope=0.2) 141 | # upsample/downsample 142 | out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) 143 | out = F.leaky_relu_(self.conv2(out), negative_slope=0.2) 144 | # skip 145 | x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) 146 | skip = self.skip(x) 147 | out = out + skip 148 | return out 149 | 150 | 151 | class GFPGANv1Clean(nn.Module): 152 | """GFPGANv1 Clean version.""" 153 | 154 | def __init__( 155 | self, 156 | out_size, 157 | num_style_feat=512, 158 | channel_multiplier=1, 159 | decoder_load_path=None, 160 | fix_decoder=True, 161 | # for stylegan decoder 162 | num_mlp=8, 163 | input_is_latent=False, 164 | different_w=False, 165 | narrow=1, 166 | sft_half=False): 167 | 168 | super(GFPGANv1Clean, self).__init__() 169 | self.input_is_latent = input_is_latent 170 | self.different_w = different_w 171 | self.num_style_feat = num_style_feat 172 | 173 | unet_narrow = narrow * 0.5 174 | channels = { 175 | '4': int(512 * unet_narrow), 176 | '8': int(512 * unet_narrow), 177 | '16': int(512 * unet_narrow), 178 | '32': int(512 * unet_narrow), 179 | '64': int(256 * channel_multiplier * unet_narrow), 180 | '128': int(128 * channel_multiplier * unet_narrow), 181 | '256': int(64 * channel_multiplier * unet_narrow), 182 | '512': int(32 * channel_multiplier * unet_narrow), 183 | '1024': int(16 * channel_multiplier * unet_narrow) 184 | } 185 | 186 | self.log_size = int(math.log(out_size, 2)) 187 | first_out_size = 2**(int(math.log(out_size, 2))) 188 | 189 | self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1) 190 | 191 | # downsample 192 | in_channels = channels[f'{first_out_size}'] 193 | self.conv_body_down = nn.ModuleList() 194 | for i in range(self.log_size, 2, -1): 195 | out_channels = channels[f'{2**(i - 1)}'] 196 | self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down')) 197 | in_channels = out_channels 198 | 199 | self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1) 200 | 201 | # upsample 202 | in_channels = channels['4'] 203 | self.conv_body_up = nn.ModuleList() 204 | for i in range(3, self.log_size + 1): 205 | out_channels = channels[f'{2**i}'] 206 | self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up')) 207 | in_channels = out_channels 208 | 209 | # to RGB 210 | self.toRGB = nn.ModuleList() 211 | for i in range(3, self.log_size + 1): 212 | self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1)) 213 | 214 | if different_w: 215 | linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat 216 | else: 217 | linear_out_channel = num_style_feat 218 | 219 | self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel) 220 | 221 | self.stylegan_decoder = StyleGAN2GeneratorCSFT( 222 | out_size=out_size, 223 | num_style_feat=num_style_feat, 224 | num_mlp=num_mlp, 225 | channel_multiplier=channel_multiplier, 226 | narrow=narrow, 227 | sft_half=sft_half) 228 | 229 | if decoder_load_path: 230 | 
self.stylegan_decoder.load_state_dict( 231 | torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema']) 232 | if fix_decoder: 233 | for _, param in self.stylegan_decoder.named_parameters(): 234 | param.requires_grad = False 235 | 236 | # for SFT 237 | self.condition_scale = nn.ModuleList() 238 | self.condition_shift = nn.ModuleList() 239 | for i in range(3, self.log_size + 1): 240 | out_channels = channels[f'{2**i}'] 241 | if sft_half: 242 | sft_out_channels = out_channels 243 | else: 244 | sft_out_channels = out_channels * 2 245 | self.condition_scale.append( 246 | nn.Sequential( 247 | nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True), 248 | nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1))) 249 | self.condition_shift.append( 250 | nn.Sequential( 251 | nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True), 252 | nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1))) 253 | 254 | def forward(self, 255 | x, 256 | return_latents=False, 257 | save_feat_path=None, 258 | load_feat_path=None, 259 | return_rgb=True, 260 | randomize_noise=True): 261 | conditions = [] 262 | unet_skips = [] 263 | out_rgbs = [] 264 | 265 | # encoder 266 | feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2) 267 | for i in range(self.log_size - 2): 268 | feat = self.conv_body_down[i](feat) 269 | unet_skips.insert(0, feat) 270 | feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2) 271 | 272 | # style code 273 | style_code = self.final_linear(feat.view(feat.size(0), -1)) 274 | if self.different_w: 275 | style_code = style_code.view(style_code.size(0), -1, self.num_style_feat) 276 | # decode 277 | for i in range(self.log_size - 2): 278 | # add unet skip 279 | feat = feat + unet_skips[i] 280 | # ResUpLayer 281 | feat = self.conv_body_up[i](feat) 282 | # generate scale and shift for SFT layer 283 | scale = self.condition_scale[i](feat) 284 | conditions.append(scale.clone()) 285 | shift = self.condition_shift[i](feat) 286 | conditions.append(shift.clone()) 287 | # generate rgb images 288 | if return_rgb: 289 | out_rgbs.append(self.toRGB[i](feat)) 290 | 291 | if save_feat_path is not None: 292 | torch.save(conditions, save_feat_path) 293 | if load_feat_path is not None: 294 | conditions = torch.load(load_feat_path) 295 | conditions = [v.cuda() for v in conditions] 296 | 297 | # decoder 298 | image, _ = self.stylegan_decoder([style_code], 299 | conditions, 300 | return_latents=return_latents, 301 | input_is_latent=self.input_is_latent, 302 | randomize_noise=randomize_noise) 303 | 304 | return image, out_rgbs 305 | -------------------------------------------------------------------------------- /gfpgan/archs/stylegan2_clean_arch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | from basicsr.archs.arch_util import default_init_weights 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | 10 | class NormStyleCode(nn.Module): 11 | 12 | def forward(self, x): 13 | """Normalize the style codes. 14 | 15 | Args: 16 | x (Tensor): Style codes with shape (b, c). 17 | 18 | Returns: 19 | Tensor: Normalized tensor. 20 | """ 21 | return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8) 22 | 23 | 24 | class ModulatedConv2d(nn.Module): 25 | """Modulated Conv2d used in StyleGAN2. 26 | 27 | There is no bias in ModulatedConv2d. 
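The weight is modulated per sample by the style code (w' = w * s) and, if ``demodulate`` is True, rescaled by 1 / sqrt(sum(w'^2) + eps) over the (c_in, k, k) dimensions, as in StyleGAN2.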
28 | 29 | Args: 30 | in_channels (int): Channel number of the input. 31 | out_channels (int): Channel number of the output. 32 | kernel_size (int): Size of the convolving kernel. 33 | num_style_feat (int): Channel number of style features. 34 | demodulate (bool): Whether to demodulate in the conv layer. 35 | Default: True. 36 | sample_mode (str | None): Indicating 'upsample', 'downsample' or None. 37 | Default: None. 38 | eps (float): A value added to the denominator for numerical stability. 39 | Default: 1e-8. 40 | """ 41 | 42 | def __init__(self, 43 | in_channels, 44 | out_channels, 45 | kernel_size, 46 | num_style_feat, 47 | demodulate=True, 48 | sample_mode=None, 49 | eps=1e-8): 50 | super(ModulatedConv2d, self).__init__() 51 | self.in_channels = in_channels 52 | self.out_channels = out_channels 53 | self.kernel_size = kernel_size 54 | self.demodulate = demodulate 55 | self.sample_mode = sample_mode 56 | self.eps = eps 57 | 58 | # modulation inside each modulated conv 59 | self.modulation = nn.Linear(num_style_feat, in_channels, bias=True) 60 | # initialization 61 | default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear') 62 | 63 | self.weight = nn.Parameter( 64 | torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) / 65 | math.sqrt(in_channels * kernel_size**2)) 66 | self.padding = kernel_size // 2 67 | 68 | def forward(self, x, style): 69 | """Forward function. 70 | 71 | Args: 72 | x (Tensor): Tensor with shape (b, c, h, w). 73 | style (Tensor): Tensor with shape (b, num_style_feat). 74 | 75 | Returns: 76 | Tensor: Modulated tensor after convolution. 77 | """ 78 | b, c, h, w = x.shape # c = c_in 79 | # weight modulation 80 | style = self.modulation(style).view(b, 1, c, 1, 1) 81 | # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1) 82 | weight = self.weight * style # (b, c_out, c_in, k, k) 83 | 84 | if self.demodulate: 85 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) 86 | weight = weight * demod.view(b, self.out_channels, 1, 1, 1) 87 | 88 | weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size) 89 | 90 | if self.sample_mode == 'upsample': 91 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 92 | elif self.sample_mode == 'downsample': 93 | x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False) 94 | 95 | b, c, h, w = x.shape 96 | x = x.view(1, b * c, h, w) 97 | # weight: (b*c_out, c_in, k, k), groups=b 98 | out = F.conv2d(x, weight, padding=self.padding, groups=b) 99 | out = out.view(b, self.out_channels, *out.shape[2:4]) 100 | 101 | return out 102 | 103 | def __repr__(self): 104 | return (f'{self.__class__.__name__}(in_channels={self.in_channels}, ' 105 | f'out_channels={self.out_channels}, ' 106 | f'kernel_size={self.kernel_size}, ' 107 | f'demodulate={self.demodulate}, sample_mode={self.sample_mode})') 108 | 109 | 110 | class StyleConv(nn.Module): 111 | """Style conv. 112 | 113 | Args: 114 | in_channels (int): Channel number of the input. 115 | out_channels (int): Channel number of the output. 116 | kernel_size (int): Size of the convolving kernel. 117 | num_style_feat (int): Channel number of style features. 118 | demodulate (bool): Whether demodulate in the conv layer. Default: True. 119 | sample_mode (str | None): Indicating 'upsample', 'downsample' or None. 120 | Default: None. 
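Note: the modulated convolution output is scaled by sqrt(2); noise is injected through a single learned scalar weight, and a per-channel bias is added before the LeakyReLU activation.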
121 | """ 122 | 123 | def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None): 124 | super(StyleConv, self).__init__() 125 | self.modulated_conv = ModulatedConv2d( 126 | in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode) 127 | self.weight = nn.Parameter(torch.zeros(1)) # for noise injection 128 | self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) 129 | self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True) 130 | 131 | def forward(self, x, style, noise=None): 132 | # modulate 133 | out = self.modulated_conv(x, style) * 2**0.5 # for conversion 134 | # noise injection 135 | if noise is None: 136 | b, _, h, w = out.shape 137 | noise = out.new_empty(b, 1, h, w).normal_() 138 | out = out + self.weight * noise 139 | # add bias 140 | out = out + self.bias 141 | # activation 142 | out = self.activate(out) 143 | return out 144 | 145 | 146 | class ToRGB(nn.Module): 147 | """To RGB from features. 148 | 149 | Args: 150 | in_channels (int): Channel number of input. 151 | num_style_feat (int): Channel number of style features. 152 | upsample (bool): Whether to upsample. Default: True. 153 | """ 154 | 155 | def __init__(self, in_channels, num_style_feat, upsample=True): 156 | super(ToRGB, self).__init__() 157 | self.upsample = upsample 158 | self.modulated_conv = ModulatedConv2d( 159 | in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None) 160 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 161 | 162 | def forward(self, x, style, skip=None): 163 | """Forward function. 164 | 165 | Args: 166 | x (Tensor): Feature tensor with shape (b, c, h, w). 167 | style (Tensor): Tensor with shape (b, num_style_feat). 168 | skip (Tensor): Base/skip tensor. Default: None. 169 | 170 | Returns: 171 | Tensor: RGB images. 172 | """ 173 | out = self.modulated_conv(x, style) 174 | out = out + self.bias 175 | if skip is not None: 176 | if self.upsample: 177 | skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False) 178 | out = out + skip 179 | return out 180 | 181 | 182 | class ConstantInput(nn.Module): 183 | """Constant input. 184 | 185 | Args: 186 | num_channel (int): Channel number of constant input. 187 | size (int): Spatial size of constant input. 188 | """ 189 | 190 | def __init__(self, num_channel, size): 191 | super(ConstantInput, self).__init__() 192 | self.weight = nn.Parameter(torch.randn(1, num_channel, size, size)) 193 | 194 | def forward(self, batch): 195 | out = self.weight.repeat(batch, 1, 1, 1) 196 | return out 197 | 198 | 199 | @ARCH_REGISTRY.register() 200 | class StyleGAN2GeneratorClean(nn.Module): 201 | """Clean version of StyleGAN2 Generator. 202 | 203 | Args: 204 | out_size (int): The spatial size of outputs. 205 | num_style_feat (int): Channel number of style features. Default: 512. 206 | num_mlp (int): Layer number of MLP style layers. Default: 8. 207 | channel_multiplier (int): Channel multiplier for large networks of 208 | StyleGAN2. Default: 2. 209 | narrow (float): Narrow ratio for channels. Default: 1.0. 
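Example: a minimal, illustrative usage sketch (shapes follow from the arguments; values are placeholders, not an official snippet):
>>> import torch
>>> net = StyleGAN2GeneratorClean(out_size=256)
>>> z = torch.randn(2, 512)  # a batch of two style codes
>>> img, _ = net([z])  # img has shape (2, 3, 256, 256)
>>> fixed_noise = net.make_noise()  # optional: reuse fixed per-layer noise
>>> img, _ = net([z], noise=fixed_noise, randomize_noise=False)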
210 | """ 211 | 212 | def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1): 213 | super(StyleGAN2GeneratorClean, self).__init__() 214 | # Style MLP layers 215 | self.num_style_feat = num_style_feat 216 | style_mlp_layers = [NormStyleCode()] 217 | for i in range(num_mlp): 218 | style_mlp_layers.extend( 219 | [nn.Linear(num_style_feat, num_style_feat, bias=True), 220 | nn.LeakyReLU(negative_slope=0.2, inplace=True)]) 221 | self.style_mlp = nn.Sequential(*style_mlp_layers) 222 | # initialization 223 | default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu') 224 | 225 | channels = { 226 | '4': int(512 * narrow), 227 | '8': int(512 * narrow), 228 | '16': int(512 * narrow), 229 | '32': int(512 * narrow), 230 | '64': int(256 * channel_multiplier * narrow), 231 | '128': int(128 * channel_multiplier * narrow), 232 | '256': int(64 * channel_multiplier * narrow), 233 | '512': int(32 * channel_multiplier * narrow), 234 | '1024': int(16 * channel_multiplier * narrow) 235 | } 236 | self.channels = channels 237 | 238 | self.constant_input = ConstantInput(channels['4'], size=4) 239 | self.style_conv1 = StyleConv( 240 | channels['4'], 241 | channels['4'], 242 | kernel_size=3, 243 | num_style_feat=num_style_feat, 244 | demodulate=True, 245 | sample_mode=None) 246 | self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False) 247 | 248 | self.log_size = int(math.log(out_size, 2)) 249 | self.num_layers = (self.log_size - 2) * 2 + 1 250 | self.num_latent = self.log_size * 2 - 2 251 | 252 | self.style_convs = nn.ModuleList() 253 | self.to_rgbs = nn.ModuleList() 254 | self.noises = nn.Module() 255 | 256 | in_channels = channels['4'] 257 | # noise 258 | for layer_idx in range(self.num_layers): 259 | resolution = 2**((layer_idx + 5) // 2) 260 | shape = [1, 1, resolution, resolution] 261 | self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape)) 262 | # style convs and to_rgbs 263 | for i in range(3, self.log_size + 1): 264 | out_channels = channels[f'{2**i}'] 265 | self.style_convs.append( 266 | StyleConv( 267 | in_channels, 268 | out_channels, 269 | kernel_size=3, 270 | num_style_feat=num_style_feat, 271 | demodulate=True, 272 | sample_mode='upsample')) 273 | self.style_convs.append( 274 | StyleConv( 275 | out_channels, 276 | out_channels, 277 | kernel_size=3, 278 | num_style_feat=num_style_feat, 279 | demodulate=True, 280 | sample_mode=None)) 281 | self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True)) 282 | in_channels = out_channels 283 | 284 | def make_noise(self): 285 | """Make noise for noise injection.""" 286 | device = self.constant_input.weight.device 287 | noises = [torch.randn(1, 1, 4, 4, device=device)] 288 | 289 | for i in range(3, self.log_size + 1): 290 | for _ in range(2): 291 | noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) 292 | 293 | return noises 294 | 295 | def get_latent(self, x): 296 | return self.style_mlp(x) 297 | 298 | def mean_latent(self, num_latent): 299 | latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device) 300 | latent = self.style_mlp(latent_in).mean(0, keepdim=True) 301 | return latent 302 | 303 | def forward(self, 304 | styles, 305 | input_is_latent=False, 306 | noise=None, 307 | randomize_noise=True, 308 | truncation=1, 309 | truncation_latent=None, 310 | inject_index=None, 311 | return_latents=False): 312 | """Forward function for StyleGAN2Generator. 
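Returns a tuple of (image, latent) when ``return_latents`` is True, otherwise (image, None).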
313 | 314 | Args: 315 | styles (list[Tensor]): Sample codes of styles. 316 | input_is_latent (bool): Whether input is latent style. 317 | Default: False. 318 | noise (Tensor | None): Input noise or None. Default: None. 319 | randomize_noise (bool): Randomize noise, used when 'noise' is 320 | False. Default: True. 321 | truncation (float): TODO. Default: 1. 322 | truncation_latent (Tensor | None): TODO. Default: None. 323 | inject_index (int | None): The injection index for mixing noise. 324 | Default: None. 325 | return_latents (bool): Whether to return style latents. 326 | Default: False. 327 | """ 328 | # style codes -> latents with Style MLP layer 329 | if not input_is_latent: 330 | styles = [self.style_mlp(s) for s in styles] 331 | # noises 332 | if noise is None: 333 | if randomize_noise: 334 | noise = [None] * self.num_layers # for each style conv layer 335 | else: # use the stored noise 336 | noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] 337 | # style truncation 338 | if truncation < 1: 339 | style_truncation = [] 340 | for style in styles: 341 | style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) 342 | styles = style_truncation 343 | # get style latent with injection 344 | if len(styles) == 1: 345 | inject_index = self.num_latent 346 | 347 | if styles[0].ndim < 3: 348 | # repeat latent code for all the layers 349 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 350 | else: # used for encoder with different latent code for each layer 351 | latent = styles[0] 352 | elif len(styles) == 2: # mixing noises 353 | if inject_index is None: 354 | inject_index = random.randint(1, self.num_latent - 1) 355 | latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 356 | latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) 357 | latent = torch.cat([latent1, latent2], 1) 358 | 359 | # main generation 360 | out = self.constant_input(latent.shape[0]) 361 | out = self.style_conv1(out, latent[:, 0], noise=noise[0]) 362 | skip = self.to_rgb1(out, latent[:, 1]) 363 | 364 | i = 1 365 | for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], 366 | noise[2::2], self.to_rgbs): 367 | out = conv1(out, latent[:, i], noise=noise1) 368 | out = conv2(out, latent[:, i + 1], noise=noise2) 369 | skip = to_rgb(out, latent[:, i + 2], skip) 370 | i += 2 371 | 372 | image = skip 373 | 374 | if return_latents: 375 | return image, latent 376 | else: 377 | return image, None 378 | -------------------------------------------------------------------------------- /gfpgan/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import dataset modules for registry 6 | # scan all the files that end with '_dataset.py' under the data folder 7 | data_folder = osp.dirname(osp.abspath(__file__)) 8 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 9 | # import all the dataset modules 10 | _dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames] 11 | -------------------------------------------------------------------------------- /gfpgan/data/ffhq_degradation_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 
4 | import os.path as osp 5 | import torch 6 | import torch.utils.data as data 7 | from basicsr.data import degradations as degradations 8 | from basicsr.data.data_util import paths_from_folder 9 | from basicsr.data.transforms import augment 10 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor 11 | from basicsr.utils.registry import DATASET_REGISTRY 12 | from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation, 13 | normalize) 14 | 15 | 16 | @DATASET_REGISTRY.register() 17 | class FFHQDegradationDataset(data.Dataset): 18 | 19 | def __init__(self, opt): 20 | super(FFHQDegradationDataset, self).__init__() 21 | self.opt = opt 22 | # file client (io backend) 23 | self.file_client = None 24 | self.io_backend_opt = opt['io_backend'] 25 | 26 | self.gt_folder = opt['dataroot_gt'] 27 | self.mean = opt['mean'] 28 | self.std = opt['std'] 29 | self.out_size = opt['out_size'] 30 | 31 | self.crop_components = opt.get('crop_components', False) # facial components 32 | self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) 33 | 34 | if self.crop_components: 35 | self.components_list = torch.load(opt.get('component_path')) 36 | 37 | if self.io_backend_opt['type'] == 'lmdb': 38 | self.io_backend_opt['db_paths'] = self.gt_folder 39 | if not self.gt_folder.endswith('.lmdb'): 40 | raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") 41 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 42 | self.paths = [line.split('.')[0] for line in fin] 43 | else: 44 | self.paths = paths_from_folder(self.gt_folder) 45 | 46 | # degradations 47 | self.blur_kernel_size = opt['blur_kernel_size'] 48 | self.kernel_list = opt['kernel_list'] 49 | self.kernel_prob = opt['kernel_prob'] 50 | self.blur_sigma = opt['blur_sigma'] 51 | self.downsample_range = opt['downsample_range'] 52 | self.noise_range = opt['noise_range'] 53 | self.jpeg_range = opt['jpeg_range'] 54 | 55 | # color jitter 56 | self.color_jitter_prob = opt.get('color_jitter_prob') 57 | self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob') 58 | self.color_jitter_shift = opt.get('color_jitter_shift', 20) 59 | # to gray 60 | self.gray_prob = opt.get('gray_prob') 61 | 62 | logger = get_root_logger() 63 | logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, ' 64 | f'sigma: [{", ".join(map(str, self.blur_sigma))}]') 65 | logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]') 66 | logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]') 67 | logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]') 68 | 69 | if self.color_jitter_prob is not None: 70 | logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, ' 71 | f'shift: {self.color_jitter_shift}') 72 | if self.gray_prob is not None: 73 | logger.info(f'Use random gray. Prob: {self.gray_prob}') 74 | 75 | self.color_jitter_shift /= 255. 
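    # A minimal, illustrative `opt` for this dataset (placeholder values for a sketch only;
    # the settings actually used for training live in options/train_gfpgan_v1.yml):
    #
    #     opt = dict(
    #         dataroot_gt='datasets/ffhq_512', io_backend=dict(type='disk'), use_hflip=True,
    #         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], out_size=512,
    #         blur_kernel_size=41, kernel_list=['iso', 'aniso'], kernel_prob=[0.5, 0.5],
    #         blur_sigma=[0.1, 10], downsample_range=[0.8, 8], noise_range=[0, 20],
    #         jpeg_range=[60, 100], color_jitter_prob=0.3, color_jitter_shift=20,
    #         color_jitter_pt_prob=0.3, gray_prob=0.01)
    #     dataset = FFHQDegradationDataset(opt)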
76 | 77 | @staticmethod 78 | def color_jitter(img, shift): 79 | jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) 80 | img = img + jitter_val 81 | img = np.clip(img, 0, 1) 82 | return img 83 | 84 | @staticmethod 85 | def color_jitter_pt(img, brightness, contrast, saturation, hue): 86 | fn_idx = torch.randperm(4) 87 | for fn_id in fn_idx: 88 | if fn_id == 0 and brightness is not None: 89 | brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() 90 | img = adjust_brightness(img, brightness_factor) 91 | 92 | if fn_id == 1 and contrast is not None: 93 | contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() 94 | img = adjust_contrast(img, contrast_factor) 95 | 96 | if fn_id == 2 and saturation is not None: 97 | saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() 98 | img = adjust_saturation(img, saturation_factor) 99 | 100 | if fn_id == 3 and hue is not None: 101 | hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() 102 | img = adjust_hue(img, hue_factor) 103 | return img 104 | 105 | def get_component_coordinates(self, index, status): 106 | components_bbox = self.components_list[f'{index:08d}'] 107 | if status[0]: # hflip 108 | # exchange right and left eye 109 | tmp = components_bbox['left_eye'] 110 | components_bbox['left_eye'] = components_bbox['right_eye'] 111 | components_bbox['right_eye'] = tmp 112 | # modify the width coordinate 113 | components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0] 114 | components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0] 115 | components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0] 116 | 117 | # get coordinates 118 | locations = [] 119 | for part in ['left_eye', 'right_eye', 'mouth']: 120 | mean = components_bbox[part][0:2] 121 | half_len = components_bbox[part][2] 122 | if 'eye' in part: 123 | half_len *= self.eye_enlarge_ratio 124 | loc = np.hstack((mean - half_len + 1, mean + half_len)) 125 | loc = torch.from_numpy(loc).float() 126 | locations.append(loc) 127 | return locations 128 | 129 | def __getitem__(self, index): 130 | if self.file_client is None: 131 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 132 | 133 | # load gt image 134 | gt_path = self.paths[index] 135 | img_bytes = self.file_client.get(gt_path) 136 | img_gt = imfrombytes(img_bytes, float32=True) 137 | 138 | # random horizontal flip 139 | img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) 140 | h, w, _ = img_gt.shape 141 | 142 | if self.crop_components: 143 | locations = self.get_component_coordinates(index, status) 144 | loc_left_eye, loc_right_eye, loc_mouth = locations 145 | 146 | # ------------------------ generate lq image ------------------------ # 147 | # blur 148 | kernel = degradations.random_mixed_kernels( 149 | self.kernel_list, 150 | self.kernel_prob, 151 | self.blur_kernel_size, 152 | self.blur_sigma, 153 | self.blur_sigma, [-math.pi, math.pi], 154 | noise_range=None) 155 | img_lq = cv2.filter2D(img_gt, -1, kernel) 156 | # downsample 157 | scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) 158 | img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR) 159 | # noise 160 | if self.noise_range is not None: 161 | img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range) 162 | # jpeg compression 163 | if self.jpeg_range is not None: 164 
| img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range) 165 | 166 | # resize to original size 167 | img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR) 168 | 169 | # random color jitter (only for lq) 170 | if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob): 171 | img_lq = self.color_jitter(img_lq, self.color_jitter_shift) 172 | # random to gray (only for lq) 173 | if self.gray_prob and np.random.uniform() < self.gray_prob: 174 | img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY) 175 | img_lq = np.tile(img_lq[:, :, None], [1, 1, 3]) 176 | if self.opt.get('gt_gray'): 177 | img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY) 178 | img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) 179 | 180 | # BGR to RGB, HWC to CHW, numpy to tensor 181 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 182 | 183 | # random color jitter (pytorch version) (only for lq) 184 | if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob): 185 | brightness = self.opt.get('brightness', (0.5, 1.5)) 186 | contrast = self.opt.get('contrast', (0.5, 1.5)) 187 | saturation = self.opt.get('saturation', (0, 1.5)) 188 | hue = self.opt.get('hue', (-0.1, 0.1)) 189 | img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue) 190 | 191 | # round and clip 192 | img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255. 193 | 194 | # normalize 195 | normalize(img_gt, self.mean, self.std, inplace=True) 196 | normalize(img_lq, self.mean, self.std, inplace=True) 197 | 198 | if self.crop_components: 199 | return_dict = { 200 | 'lq': img_lq, 201 | 'gt': img_gt, 202 | 'gt_path': gt_path, 203 | 'loc_left_eye': loc_left_eye, 204 | 'loc_right_eye': loc_right_eye, 205 | 'loc_mouth': loc_mouth 206 | } 207 | return return_dict 208 | else: 209 | return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path} 210 | 211 | def __len__(self): 212 | return len(self.paths) 213 | -------------------------------------------------------------------------------- /gfpgan/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import model modules for registry 6 | # scan all the files that end with '_model.py' under the model folder 7 | model_folder = osp.dirname(osp.abspath(__file__)) 8 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 9 | # import all the model modules 10 | _model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames] 11 | -------------------------------------------------------------------------------- /gfpgan/models/gfpgan_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os.path as osp 3 | import torch 4 | from basicsr.archs import build_network 5 | from basicsr.losses import build_loss 6 | from basicsr.losses.losses import r1_penalty 7 | from basicsr.metrics import calculate_metric 8 | from basicsr.models.base_model import BaseModel 9 | from basicsr.utils import get_root_logger, imwrite, tensor2img 10 | from basicsr.utils.registry import MODEL_REGISTRY 11 | from collections import OrderedDict 12 | from torch.nn import functional as F 13 | from torchvision.ops import roi_align 14 | from tqdm import tqdm 15 | 16 | 17 | @MODEL_REGISTRY.register() 18 | class 
GFPGANModel(BaseModel): 19 | """GFPGAN model for blind face restoration.""" 20 | 21 | def __init__(self, opt): 22 | super(GFPGANModel, self).__init__(opt) 23 | self.idx = 0 24 | 25 | # define network 26 | self.net_g = build_network(opt['network_g']) 27 | self.net_g = self.model_to_device(self.net_g) 28 | self.print_network(self.net_g) 29 | 30 | # load pretrained model 31 | load_path = self.opt['path'].get('pretrain_network_g', None) 32 | if load_path is not None: 33 | param_key = self.opt['path'].get('param_key_g', 'params') 34 | self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) 35 | 36 | self.log_size = int(math.log(self.opt['network_g']['out_size'], 2)) 37 | 38 | if self.is_train: 39 | self.init_training_settings() 40 | 41 | def init_training_settings(self): 42 | train_opt = self.opt['train'] 43 | 44 | # ----------- define net_d ----------- # 45 | self.net_d = build_network(self.opt['network_d']) 46 | self.net_d = self.model_to_device(self.net_d) 47 | self.print_network(self.net_d) 48 | # load pretrained model 49 | load_path = self.opt['path'].get('pretrain_network_d', None) 50 | if load_path is not None: 51 | self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True)) 52 | 53 | # ----------- define net_g with Exponential Moving Average (EMA) ----------- # 54 | # net_g_ema only used for testing on one GPU and saving 55 | # There is no need to wrap with DistributedDataParallel 56 | self.net_g_ema = build_network(self.opt['network_g']).to(self.device) 57 | # load pretrained model 58 | load_path = self.opt['path'].get('pretrain_network_g', None) 59 | if load_path is not None: 60 | self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') 61 | else: 62 | self.model_ema(0) # copy net_g weight 63 | 64 | self.net_g.train() 65 | self.net_d.train() 66 | self.net_g_ema.eval() 67 | 68 | # ----------- facial components networks ----------- # 69 | if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt): 70 | self.use_facial_disc = True 71 | else: 72 | self.use_facial_disc = False 73 | 74 | if self.use_facial_disc: 75 | # left eye 76 | self.net_d_left_eye = build_network(self.opt['network_d_left_eye']) 77 | self.net_d_left_eye = self.model_to_device(self.net_d_left_eye) 78 | self.print_network(self.net_d_left_eye) 79 | load_path = self.opt['path'].get('pretrain_network_d_left_eye') 80 | if load_path is not None: 81 | self.load_network(self.net_d_left_eye, load_path, True, 'params') 82 | # right eye 83 | self.net_d_right_eye = build_network(self.opt['network_d_right_eye']) 84 | self.net_d_right_eye = self.model_to_device(self.net_d_right_eye) 85 | self.print_network(self.net_d_right_eye) 86 | load_path = self.opt['path'].get('pretrain_network_d_right_eye') 87 | if load_path is not None: 88 | self.load_network(self.net_d_right_eye, load_path, True, 'params') 89 | # mouth 90 | self.net_d_mouth = build_network(self.opt['network_d_mouth']) 91 | self.net_d_mouth = self.model_to_device(self.net_d_mouth) 92 | self.print_network(self.net_d_mouth) 93 | load_path = self.opt['path'].get('pretrain_network_d_mouth') 94 | if load_path is not None: 95 | self.load_network(self.net_d_mouth, load_path, True, 'params') 96 | 97 | self.net_d_left_eye.train() 98 | self.net_d_right_eye.train() 99 | self.net_d_mouth.train() 100 | 101 | # ----------- define facial component gan loss ----------- # 102 | self.cri_component = 
build_loss(train_opt['gan_component_opt']).to(self.device) 103 | 104 | # ----------- define losses ----------- # 105 | if train_opt.get('pixel_opt'): 106 | self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) 107 | else: 108 | self.cri_pix = None 109 | 110 | if train_opt.get('perceptual_opt'): 111 | self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device) 112 | else: 113 | self.cri_perceptual = None 114 | 115 | # L1 loss used in pyramid loss, component style loss and identity loss 116 | self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device) 117 | 118 | # gan loss (wgan) 119 | self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device) 120 | 121 | # ----------- define identity loss ----------- # 122 | if 'network_identity' in self.opt: 123 | self.use_identity = True 124 | else: 125 | self.use_identity = False 126 | 127 | if self.use_identity: 128 | # define identity network 129 | self.network_identity = build_network(self.opt['network_identity']) 130 | self.network_identity = self.model_to_device(self.network_identity) 131 | self.print_network(self.network_identity) 132 | load_path = self.opt['path'].get('pretrain_network_identity') 133 | if load_path is not None: 134 | self.load_network(self.network_identity, load_path, True, None) 135 | self.network_identity.eval() 136 | for param in self.network_identity.parameters(): 137 | param.requires_grad = False 138 | 139 | # regularization weights 140 | self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator 141 | self.net_d_iters = train_opt.get('net_d_iters', 1) 142 | self.net_d_init_iters = train_opt.get('net_d_init_iters', 0) 143 | self.net_d_reg_every = train_opt['net_d_reg_every'] 144 | 145 | # set up optimizers and schedulers 146 | self.setup_optimizers() 147 | self.setup_schedulers() 148 | 149 | def setup_optimizers(self): 150 | train_opt = self.opt['train'] 151 | 152 | # ----------- optimizer g ----------- # 153 | net_g_reg_ratio = 1 154 | normal_params = [] 155 | for _, param in self.net_g.named_parameters(): 156 | normal_params.append(param) 157 | optim_params_g = [{ # add normal params first 158 | 'params': normal_params, 159 | 'lr': train_opt['optim_g']['lr'] 160 | }] 161 | optim_type = train_opt['optim_g'].pop('type') 162 | lr = train_opt['optim_g']['lr'] * net_g_reg_ratio 163 | betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio) 164 | self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas) 165 | self.optimizers.append(self.optimizer_g) 166 | 167 | # ----------- optimizer d ----------- # 168 | net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1) 169 | normal_params = [] 170 | for _, param in self.net_d.named_parameters(): 171 | normal_params.append(param) 172 | optim_params_d = [{ # add normal params first 173 | 'params': normal_params, 174 | 'lr': train_opt['optim_d']['lr'] 175 | }] 176 | optim_type = train_opt['optim_d'].pop('type') 177 | lr = train_opt['optim_d']['lr'] * net_d_reg_ratio 178 | betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio) 179 | self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas) 180 | self.optimizers.append(self.optimizer_d) 181 | 182 | if self.use_facial_disc: 183 | # setup optimizers for facial component discriminators 184 | optim_type = train_opt['optim_component'].pop('type') 185 | lr = train_opt['optim_component']['lr'] 186 | # left eye 187 | self.optimizer_d_left_eye = self.get_optimizer( 188 | optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99)) 189 | 
self.optimizers.append(self.optimizer_d_left_eye) 190 | # right eye 191 | self.optimizer_d_right_eye = self.get_optimizer( 192 | optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99)) 193 | self.optimizers.append(self.optimizer_d_right_eye) 194 | # mouth 195 | self.optimizer_d_mouth = self.get_optimizer( 196 | optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99)) 197 | self.optimizers.append(self.optimizer_d_mouth) 198 | 199 | def feed_data(self, data): 200 | self.lq = data['lq'].to(self.device) 201 | if 'gt' in data: 202 | self.gt = data['gt'].to(self.device) 203 | 204 | if 'loc_left_eye' in data: 205 | # get facial component locations, shape (batch, 4) 206 | self.loc_left_eyes = data['loc_left_eye'] 207 | self.loc_right_eyes = data['loc_right_eye'] 208 | self.loc_mouths = data['loc_mouth'] 209 | 210 | # uncomment to check data 211 | # import torchvision 212 | # if self.opt['rank'] == 0: 213 | # import os 214 | # os.makedirs('tmp/gt', exist_ok=True) 215 | # os.makedirs('tmp/lq', exist_ok=True) 216 | # print(self.idx) 217 | # torchvision.utils.save_image( 218 | # self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1)) 219 | # torchvision.utils.save_image( 220 | # self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1)) 221 | # self.idx = self.idx + 1 222 | 223 | def construct_img_pyramid(self): 224 | pyramid_gt = [self.gt] 225 | down_img = self.gt 226 | for _ in range(0, self.log_size - 3): 227 | down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False) 228 | pyramid_gt.insert(0, down_img) 229 | return pyramid_gt 230 | 231 | def get_roi_regions(self, eye_out_size=80, mouth_out_size=120): 232 | # hard code 233 | face_ratio = int(self.opt['network_g']['out_size'] / 512) 234 | eye_out_size *= face_ratio 235 | mouth_out_size *= face_ratio 236 | 237 | rois_eyes = [] 238 | rois_mouths = [] 239 | for b in range(self.loc_left_eyes.size(0)): # loop for batch size 240 | # left eye and right eye 241 | img_inds = self.loc_left_eyes.new_full((2, 1), b) 242 | bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0) # shape: (2, 4) 243 | rois = torch.cat([img_inds, bbox], dim=-1) # shape: (2, 5) 244 | rois_eyes.append(rois) 245 | # mouse 246 | img_inds = self.loc_left_eyes.new_full((1, 1), b) 247 | rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1) # shape: (1, 5) 248 | rois_mouths.append(rois) 249 | 250 | rois_eyes = torch.cat(rois_eyes, 0).to(self.device) 251 | rois_mouths = torch.cat(rois_mouths, 0).to(self.device) 252 | 253 | # real images 254 | all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio 255 | self.left_eyes_gt = all_eyes[0::2, :, :, :] 256 | self.right_eyes_gt = all_eyes[1::2, :, :, :] 257 | self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio 258 | # output 259 | all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio 260 | self.left_eyes = all_eyes[0::2, :, :, :] 261 | self.right_eyes = all_eyes[1::2, :, :, :] 262 | self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio 263 | 264 | def _gram_mat(self, x): 265 | """Calculate Gram matrix. 266 | 267 | Args: 268 | x (torch.Tensor): Tensor with shape of (n, c, h, w). 269 | 270 | Returns: 271 | torch.Tensor: Gram matrix. 
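Computed as G = F @ F.transpose(1, 2) / (c * h * w), where F is the feature map reshaped to (n, c, h * w).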
272 | """ 273 | n, c, h, w = x.size() 274 | features = x.view(n, c, w * h) 275 | features_t = features.transpose(1, 2) 276 | gram = features.bmm(features_t) / (c * h * w) 277 | return gram 278 | 279 | def gray_resize_for_identity(self, out, size=128): 280 | out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :]) 281 | out_gray = out_gray.unsqueeze(1) 282 | out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False) 283 | return out_gray 284 | 285 | def optimize_parameters(self, current_iter): 286 | # optimize net_g 287 | for p in self.net_d.parameters(): 288 | p.requires_grad = False 289 | self.optimizer_g.zero_grad() 290 | 291 | if self.use_facial_disc: 292 | for p in self.net_d_left_eye.parameters(): 293 | p.requires_grad = False 294 | for p in self.net_d_right_eye.parameters(): 295 | p.requires_grad = False 296 | for p in self.net_d_mouth.parameters(): 297 | p.requires_grad = False 298 | 299 | # image pyramid loss weight 300 | if current_iter < self.opt['train'].get('remove_pyramid_loss', float('inf')): 301 | pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 1) 302 | else: 303 | pyramid_loss_weight = 1e-12 # very small loss 304 | if pyramid_loss_weight > 0: 305 | self.output, out_rgbs = self.net_g(self.lq, return_rgb=True) 306 | pyramid_gt = self.construct_img_pyramid() 307 | else: 308 | self.output, out_rgbs = self.net_g(self.lq, return_rgb=False) 309 | 310 | # get roi-align regions 311 | if self.use_facial_disc: 312 | self.get_roi_regions(eye_out_size=80, mouth_out_size=120) 313 | 314 | l_g_total = 0 315 | loss_dict = OrderedDict() 316 | if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): 317 | # pixel loss 318 | if self.cri_pix: 319 | l_g_pix = self.cri_pix(self.output, self.gt) 320 | l_g_total += l_g_pix 321 | loss_dict['l_g_pix'] = l_g_pix 322 | 323 | # image pyramid loss 324 | if pyramid_loss_weight > 0: 325 | for i in range(0, self.log_size - 2): 326 | l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight 327 | l_g_total += l_pyramid 328 | loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid 329 | 330 | # perceptual loss 331 | if self.cri_perceptual: 332 | l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) 333 | if l_g_percep is not None: 334 | l_g_total += l_g_percep 335 | loss_dict['l_g_percep'] = l_g_percep 336 | if l_g_style is not None: 337 | l_g_total += l_g_style 338 | loss_dict['l_g_style'] = l_g_style 339 | 340 | # gan loss 341 | fake_g_pred = self.net_d(self.output) 342 | l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) 343 | l_g_total += l_g_gan 344 | loss_dict['l_g_gan'] = l_g_gan 345 | 346 | # facial component loss 347 | if self.use_facial_disc: 348 | # left eye 349 | fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True) 350 | l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False) 351 | l_g_total += l_g_gan 352 | loss_dict['l_g_gan_left_eye'] = l_g_gan 353 | # right eye 354 | fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True) 355 | l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False) 356 | l_g_total += l_g_gan 357 | loss_dict['l_g_gan_right_eye'] = l_g_gan 358 | # mouth 359 | fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True) 360 | l_g_gan = self.cri_component(fake_mouth, True, is_disc=False) 361 | l_g_total += l_g_gan 362 | loss_dict['l_g_gan_mouth'] = l_g_gan 363 | 364 | if 
self.opt['train'].get('comp_style_weight', 0) > 0: 365 | # get gt feat 366 | _, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True) 367 | _, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True) 368 | _, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True) 369 | 370 | def _comp_style(feat, feat_gt, criterion): 371 | return criterion(self._gram_mat(feat[0]), self._gram_mat( 372 | feat_gt[0].detach())) * 0.5 + criterion( 373 | self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach())) 374 | 375 | # facial component style loss 376 | comp_style_loss = 0 377 | comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1) 378 | comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1) 379 | comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1) 380 | comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight'] 381 | l_g_total += comp_style_loss 382 | loss_dict['l_g_comp_style_loss'] = comp_style_loss 383 | 384 | # identity loss 385 | if self.use_identity: 386 | identity_weight = self.opt['train']['identity_weight'] 387 | # get gray images and resize 388 | out_gray = self.gray_resize_for_identity(self.output) 389 | gt_gray = self.gray_resize_for_identity(self.gt) 390 | 391 | identity_gt = self.network_identity(gt_gray).detach() 392 | identity_out = self.network_identity(out_gray) 393 | l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight 394 | l_g_total += l_identity 395 | loss_dict['l_identity'] = l_identity 396 | 397 | l_g_total.backward() 398 | self.optimizer_g.step() 399 | 400 | # EMA 401 | self.model_ema(decay=0.5**(32 / (10 * 1000))) 402 | 403 | # ----------- optimize net_d ----------- # 404 | for p in self.net_d.parameters(): 405 | p.requires_grad = True 406 | self.optimizer_d.zero_grad() 407 | if self.use_facial_disc: 408 | for p in self.net_d_left_eye.parameters(): 409 | p.requires_grad = True 410 | for p in self.net_d_right_eye.parameters(): 411 | p.requires_grad = True 412 | for p in self.net_d_mouth.parameters(): 413 | p.requires_grad = True 414 | self.optimizer_d_left_eye.zero_grad() 415 | self.optimizer_d_right_eye.zero_grad() 416 | self.optimizer_d_mouth.zero_grad() 417 | 418 | fake_d_pred = self.net_d(self.output.detach()) 419 | real_d_pred = self.net_d(self.gt) 420 | l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True) 421 | loss_dict['l_d'] = l_d 422 | # In wgan, real_score should be positive and fake_score should be negative 423 | loss_dict['real_score'] = real_d_pred.detach().mean() 424 | loss_dict['fake_score'] = fake_d_pred.detach().mean() 425 | l_d.backward() 426 | 427 | if current_iter % self.net_d_reg_every == 0: 428 | self.gt.requires_grad = True 429 | real_pred = self.net_d(self.gt) 430 | l_d_r1 = r1_penalty(real_pred, self.gt) 431 | l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0]) 432 | loss_dict['l_d_r1'] = l_d_r1.detach().mean() 433 | l_d_r1.backward() 434 | 435 | self.optimizer_d.step() 436 | 437 | if self.use_facial_disc: 438 | # left eye 439 | fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach()) 440 | real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt) 441 | l_d_left_eye = self.cri_component( 442 | real_d_pred, True, is_disc=True) + self.cri_gan( 443 | fake_d_pred, False, is_disc=True) 444 | loss_dict['l_d_left_eye'] = l_d_left_eye 445 | l_d_left_eye.backward() 446 | # right eye 447 | 
fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach()) 448 | real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt) 449 | l_d_right_eye = self.cri_component( 450 | real_d_pred, True, is_disc=True) + self.cri_gan( 451 | fake_d_pred, False, is_disc=True) 452 | loss_dict['l_d_right_eye'] = l_d_right_eye 453 | l_d_right_eye.backward() 454 | # mouth 455 | fake_d_pred, _ = self.net_d_mouth(self.mouths.detach()) 456 | real_d_pred, _ = self.net_d_mouth(self.mouths_gt) 457 | l_d_mouth = self.cri_component( 458 | real_d_pred, True, is_disc=True) + self.cri_gan( 459 | fake_d_pred, False, is_disc=True) 460 | loss_dict['l_d_mouth'] = l_d_mouth 461 | l_d_mouth.backward() 462 | 463 | self.optimizer_d_left_eye.step() 464 | self.optimizer_d_right_eye.step() 465 | self.optimizer_d_mouth.step() 466 | 467 | self.log_dict = self.reduce_loss_dict(loss_dict) 468 | 469 | def test(self): 470 | with torch.no_grad(): 471 | if hasattr(self, 'net_g_ema'): 472 | self.net_g_ema.eval() 473 | self.output, _ = self.net_g_ema(self.lq) 474 | else: 475 | logger = get_root_logger() 476 | logger.warning('Do not have self.net_g_ema, use self.net_g.') 477 | self.net_g.eval() 478 | self.output, _ = self.net_g(self.lq) 479 | self.net_g.train() 480 | 481 | def dist_validation(self, dataloader, current_iter, tb_logger, save_img): 482 | if self.opt['rank'] == 0: 483 | self.nondist_validation(dataloader, current_iter, tb_logger, save_img) 484 | 485 | def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): 486 | dataset_name = dataloader.dataset.opt['name'] 487 | with_metrics = self.opt['val'].get('metrics') is not None 488 | if with_metrics: 489 | self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} 490 | pbar = tqdm(total=len(dataloader), unit='image') 491 | 492 | for idx, val_data in enumerate(dataloader): 493 | img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] 494 | self.feed_data(val_data) 495 | self.test() 496 | 497 | visuals = self.get_current_visuals() 498 | sr_img = tensor2img([visuals['sr']], min_max=(-1, 1)) 499 | gt_img = tensor2img([visuals['gt']], min_max=(-1, 1)) 500 | 501 | if 'gt' in visuals: 502 | gt_img = tensor2img([visuals['gt']], min_max=(-1, 1)) 503 | del self.gt 504 | # tentative for out of GPU memory 505 | del self.lq 506 | del self.output 507 | torch.cuda.empty_cache() 508 | 509 | if save_img: 510 | if self.opt['is_train']: 511 | save_img_path = osp.join(self.opt['path']['visualization'], img_name, 512 | f'{img_name}_{current_iter}.png') 513 | else: 514 | if self.opt['val']['suffix']: 515 | save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, 516 | f'{img_name}_{self.opt["val"]["suffix"]}.png') 517 | else: 518 | save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, 519 | f'{img_name}_{self.opt["name"]}.png') 520 | imwrite(sr_img, save_img_path) 521 | 522 | if with_metrics: 523 | # calculate metrics 524 | for name, opt_ in self.opt['val']['metrics'].items(): 525 | metric_data = dict(img1=sr_img, img2=gt_img) 526 | self.metric_results[name] += calculate_metric(metric_data, opt_) 527 | pbar.update(1) 528 | pbar.set_description(f'Test {img_name}') 529 | pbar.close() 530 | 531 | if with_metrics: 532 | for metric in self.metric_results.keys(): 533 | self.metric_results[metric] /= (idx + 1) 534 | 535 | self._log_validation_metric_values(current_iter, dataset_name, tb_logger) 536 | 537 | def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): 538 | log_str = f'Validation 
{dataset_name}\n' 539 | for metric, value in self.metric_results.items(): 540 | log_str += f'\t # {metric}: {value:.4f}\n' 541 | logger = get_root_logger() 542 | logger.info(log_str) 543 | if tb_logger: 544 | for metric, value in self.metric_results.items(): 545 | tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) 546 | 547 | def get_current_visuals(self): 548 | out_dict = OrderedDict() 549 | out_dict['gt'] = self.gt.detach().cpu() 550 | out_dict['sr'] = self.output.detach().cpu() 551 | return out_dict 552 | 553 | def save(self, epoch, current_iter): 554 | self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) 555 | self.save_network(self.net_d, 'net_d', current_iter) 556 | # save component discriminators 557 | if self.use_facial_disc: 558 | self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter) 559 | self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter) 560 | self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter) 561 | self.save_training_state(epoch, current_iter) 562 | -------------------------------------------------------------------------------- /gfpgan/train.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import os.path as osp 3 | from basicsr.train import train_pipeline 4 | 5 | import gfpgan.archs 6 | import gfpgan.data 7 | import gfpgan.models 8 | 9 | if __name__ == '__main__': 10 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 11 | train_pipeline(root_path) 12 | -------------------------------------------------------------------------------- /gfpgan/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import torch 4 | from basicsr.utils import img2tensor, tensor2img 5 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper 6 | from torch.hub import download_url_to_file, get_dir 7 | from torchvision.transforms.functional import normalize 8 | from urllib.parse import urlparse 9 | 10 | from gfpgan.archs.gfpganv1_arch import GFPGANv1 11 | from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean 12 | 13 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 14 | 15 | 16 | class GFPGANer(): 17 | 18 | def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None): 19 | self.upscale = upscale 20 | self.bg_upsampler = bg_upsampler 21 | 22 | # initialize model 23 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | # initialize the GFP-GAN 25 | if arch == 'clean': 26 | self.gfpgan = GFPGANv1Clean( 27 | out_size=512, 28 | num_style_feat=512, 29 | channel_multiplier=channel_multiplier, 30 | decoder_load_path=None, 31 | fix_decoder=False, 32 | num_mlp=8, 33 | input_is_latent=True, 34 | different_w=True, 35 | narrow=1, 36 | sft_half=True) 37 | else: 38 | self.gfpgan = GFPGANv1( 39 | out_size=512, 40 | num_style_feat=512, 41 | channel_multiplier=channel_multiplier, 42 | decoder_load_path=None, 43 | fix_decoder=True, 44 | num_mlp=8, 45 | input_is_latent=True, 46 | different_w=True, 47 | narrow=1, 48 | sft_half=True) 49 | # initialize face helper 50 | self.face_helper = FaceRestoreHelper( 51 | upscale, 52 | face_size=512, 53 | crop_ratio=(1, 1), 54 | det_model='retinaface_resnet50', 55 | save_ext='png', 56 | device=self.device) 57 | 58 | if model_path.startswith('https://'): 59 | model_path = load_file_from_url(url=model_path, 
model_dir='gfpgan/weights', progress=True, file_name=None) 60 | loadnet = torch.load(model_path) 61 | if 'params_ema' in loadnet: 62 | keyname = 'params_ema' 63 | else: 64 | keyname = 'params' 65 | self.gfpgan.load_state_dict(loadnet[keyname], strict=True) 66 | self.gfpgan.eval() 67 | self.gfpgan = self.gfpgan.to(self.device) 68 | 69 | @torch.no_grad() 70 | def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True): 71 | self.face_helper.clean_all() 72 | 73 | if has_aligned: 74 | img = cv2.resize(img, (512, 512)) 75 | self.face_helper.cropped_faces = [img] 76 | else: 77 | self.face_helper.read_image(img) 78 | # get face landmarks for each face 79 | self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5) 80 | # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels 81 | # align and warp each face 82 | self.face_helper.align_warp_face() 83 | 84 | # face restoration 85 | for cropped_face in self.face_helper.cropped_faces: 86 | # prepare data 87 | cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) 88 | normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) 89 | cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device) 90 | 91 | try: 92 | output = self.gfpgan(cropped_face_t, return_rgb=False)[0] 93 | # convert to image 94 | restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)) 95 | except RuntimeError as error: 96 | print(f'\tFailed inference for GFPGAN: {error}.') 97 | restored_face = cropped_face 98 | 99 | restored_face = restored_face.astype('uint8') 100 | self.face_helper.add_restored_face(restored_face) 101 | 102 | if not has_aligned and paste_back: 103 | 104 | if self.bg_upsampler is not None: 105 | # Now only support RealESRGAN 106 | bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0] 107 | else: 108 | bg_img = None 109 | 110 | self.face_helper.get_inverse_affine(None) 111 | # paste each restored face to the input image 112 | restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img) 113 | return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img 114 | else: 115 | return self.face_helper.cropped_faces, self.face_helper.restored_faces, None 116 | 117 | 118 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 119 | """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 120 | """ 121 | if model_dir is None: 122 | hub_dir = get_dir() 123 | model_dir = os.path.join(hub_dir, 'checkpoints') 124 | 125 | os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) 126 | 127 | parts = urlparse(url) 128 | filename = os.path.basename(parts.path) 129 | if file_name is not None: 130 | filename = file_name 131 | cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) 132 | if not os.path.exists(cached_file): 133 | print(f'Downloading: "{url}" to {cached_file}\n') 134 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 135 | return cached_file 136 | -------------------------------------------------------------------------------- /gfpgan/weights/README.md: -------------------------------------------------------------------------------- 1 | # Weights 2 | 3 | Put the downloaded weights to this folder. 
4 | -------------------------------------------------------------------------------- /inference_gfpgan.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import glob 4 | import numpy as np 5 | import os 6 | import torch 7 | from basicsr.utils import imwrite 8 | 9 | from gfpgan import GFPGANer 10 | 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument('--upscale', type=int, default=2) 16 | parser.add_argument('--arch', type=str, default='clean') 17 | parser.add_argument('--channel', type=int, default=2) 18 | parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth') 19 | parser.add_argument('--bg_upsampler', type=str, default='realesrgan') 20 | parser.add_argument('--bg_tile', type=int, default=400) 21 | parser.add_argument('--test_path', type=str, default='inputs/whole_imgs') 22 | parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces') 23 | parser.add_argument('--only_center_face', action='store_true') 24 | parser.add_argument('--aligned', action='store_true') 25 | parser.add_argument('--paste_back', action='store_false') 26 | parser.add_argument('--save_root', type=str, default='results') 27 | parser.add_argument( 28 | '--ext', 29 | type=str, 30 | default='auto', 31 | help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs') 32 | args = parser.parse_args() 33 | 34 | args = parser.parse_args() 35 | if args.test_path.endswith('/'): 36 | args.test_path = args.test_path[:-1] 37 | os.makedirs(args.save_root, exist_ok=True) 38 | 39 | # background upsampler 40 | if args.bg_upsampler == 'realesrgan': 41 | if not torch.cuda.is_available(): # CPU 42 | import warnings 43 | warnings.warn('The unoptimized RealESRGAN is very slow on CPU. We do not use it. 
' 44 | 'If you really want to use it, please modify the corresponding codes.') 45 | bg_upsampler = None 46 | else: 47 | from realesrgan import RealESRGANer 48 | bg_upsampler = RealESRGANer( 49 | scale=2, 50 | model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', 51 | tile=args.bg_tile, 52 | tile_pad=10, 53 | pre_pad=0, 54 | half=True) # need to set False in CPU mode 55 | else: 56 | bg_upsampler = None 57 | # set up GFPGAN restorer 58 | restorer = GFPGANer( 59 | model_path=args.model_path, 60 | upscale=args.upscale, 61 | arch=args.arch, 62 | channel_multiplier=args.channel, 63 | bg_upsampler=bg_upsampler) 64 | 65 | img_list = sorted(glob.glob(os.path.join(args.test_path, '*'))) 66 | for img_path in img_list: 67 | # read image 68 | img_name = os.path.basename(img_path) 69 | print(f'Processing {img_name} ...') 70 | basename, ext = os.path.splitext(img_name) 71 | input_img = cv2.imread(img_path, cv2.IMREAD_COLOR) 72 | 73 | cropped_faces, restored_faces, restored_img = restorer.enhance( 74 | input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back) 75 | 76 | # save faces 77 | for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)): 78 | # save cropped face 79 | save_crop_path = os.path.join(args.save_root, 'cropped_faces', f'{basename}_{idx:02d}.png') 80 | imwrite(cropped_face, save_crop_path) 81 | # save restored face 82 | if args.suffix is not None: 83 | save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png' 84 | else: 85 | save_face_name = f'{basename}_{idx:02d}.png' 86 | save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name) 87 | imwrite(restored_face, save_restore_path) 88 | # save cmp image 89 | cmp_img = np.concatenate((cropped_face, restored_face), axis=1) 90 | imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png')) 91 | 92 | # save restored img 93 | if restored_img is not None: 94 | if args.ext == 'auto': 95 | extension = ext[1:] 96 | else: 97 | extension = args.ext 98 | 99 | if args.suffix is not None: 100 | save_restore_path = os.path.join(args.save_root, 'restored_imgs', 101 | f'{basename}_{args.suffix}.{extension}') 102 | else: 103 | save_restore_path = os.path.join(args.save_root, 'restored_imgs', f'{basename}.{extension}') 104 | imwrite(restored_img, save_restore_path) 105 | 106 | print(f'Results are in the [{args.save_root}] folder.') 107 | 108 | 109 | if __name__ == '__main__': 110 | main() 111 | -------------------------------------------------------------------------------- /insightface_func/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/insightface_func/__init__.py -------------------------------------------------------------------------------- /insightface_func/face_detect_crop.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Organization : insightface.ai 3 | # @Author : Jia Guo 4 | # @Time : 2021-05-04 5 | # @Function : 6 | 7 | 8 | from __future__ import division 9 | import collections 10 | import numpy as np 11 | import glob 12 | import os 13 | import os.path as osp 14 | from numpy.linalg import norm 15 | from insightface.model_zoo import model_zoo 16 | from insightface_func.utils import face_align 17 | 18 | __all__ = ['Face_detect_crop', 'Face'] 19 | 20 | Face = 
collections.namedtuple('Face', [ 21 | 'bbox', 'kps', 'det_score', 'embedding', 'gender', 'age', 22 | 'embedding_norm', 'normed_embedding', 23 | 'landmark' 24 | ]) 25 | 26 | Face.__new__.__defaults__ = (None, ) * len(Face._fields) 27 | 28 | 29 | class Face_detect_crop: 30 | def __init__(self, name, root='~/.insightface_func/models'): 31 | self.models = {} 32 | root = os.path.expanduser(root) 33 | onnx_files = glob.glob(osp.join(root, name, '*.onnx')) 34 | onnx_files = sorted(onnx_files) 35 | for onnx_file in onnx_files: 36 | if onnx_file.find('_selfgen_')>0: 37 | #print('ignore:', onnx_file) 38 | continue 39 | model = model_zoo.get_model(onnx_file) 40 | if model.taskname not in self.models: 41 | print('find model:', onnx_file, model.taskname) 42 | self.models[model.taskname] = model 43 | else: 44 | print('duplicated model task type, ignore:', onnx_file, model.taskname) 45 | del model 46 | assert 'detection' in self.models 47 | self.det_model = self.models['detection'] 48 | 49 | 50 | def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)): 51 | self.det_thresh = det_thresh 52 | assert det_size is not None 53 | print('set det-size:', det_size) 54 | self.det_size = det_size 55 | for taskname, model in self.models.items(): 56 | if taskname=='detection': 57 | model.prepare(ctx_id, input_size=det_size) 58 | else: 59 | model.prepare(ctx_id) 60 | 61 | def get(self, img, crop_size, max_num=0, mode = 'None'): 62 | bboxes, kpss = self.det_model.detect(img, 63 | threshold=self.det_thresh, 64 | max_num=max_num, 65 | metric='default') 66 | if bboxes.shape[0] == 0: 67 | return [] 68 | ret = [] 69 | if mode == 'Both': 70 | for i in range(bboxes.shape[0]): 71 | kps = None 72 | if kpss is not None: 73 | kps = kpss[i] 74 | aimg_None,aimg_arface = face_align.norm_crop(img, kps,crop_size,mode =mode) 75 | return [aimg_None,aimg_arface] 76 | 77 | else: 78 | for i in range(bboxes.shape[0]): 79 | kps = None 80 | if kpss is not None: 81 | kps = kpss[i] 82 | aimg = face_align.norm_crop(img, kps,crop_size,mode =mode) 83 | return [aimg] 84 | 85 | def draw_on(self, img, faces): 86 | import cv2 87 | for i in range(len(faces)): 88 | face = faces[i] 89 | box = face.bbox.astype(np.int) 90 | color = (0, 0, 255) 91 | cv2.rectangle(img, (box[0], box[1]), (box[2], box[3]), color, 2) 92 | if face.kps is not None: 93 | kps = face.kps.astype(np.int) 94 | #print(landmark.shape) 95 | for l in range(kps.shape[0]): 96 | color = (0, 0, 255) 97 | if l == 0 or l == 3: 98 | color = (0, 255, 0) 99 | cv2.circle(img, (kps[l][0], kps[l][1]), 1, color, 100 | 2) 101 | return img 102 | 103 | -------------------------------------------------------------------------------- /insightface_func/face_detect_crop_ffhq_newarcAlign.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-15 19:42:42 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-16 15:35:02 7 | Description: 8 | ''' 9 | 10 | from __future__ import division 11 | import collections 12 | import numpy as np 13 | import glob 14 | import os 15 | import os.path as osp 16 | from insightface.model_zoo import model_zoo 17 | from insightface_func.utils import face_align_ffhqandnewarc as face_align 18 | 19 | __all__ = ['Face_detect_crop', 'Face'] 20 | 21 | Face = collections.namedtuple('Face', [ 22 | 'bbox', 'kps', 'det_score', 'embedding', 'gender', 'age', 23 | 'embedding_norm', 'normed_embedding', 24 | 'landmark' 25 | ]) 26 | 27 | Face.__new__.__defaults__ = (None, ) * 
len(Face._fields) 28 | 29 | 30 | class Face_detect_crop: 31 | def __init__(self, name, root='~/.insightface_func/models'): 32 | self.models = {} 33 | root = os.path.expanduser(root) 34 | onnx_files = glob.glob(osp.join(root, name, '*.onnx')) 35 | onnx_files = sorted(onnx_files) 36 | for onnx_file in onnx_files: 37 | if onnx_file.find('_selfgen_')>0: 38 | #print('ignore:', onnx_file) 39 | continue 40 | model = model_zoo.get_model(onnx_file) 41 | if model.taskname not in self.models: 42 | print('find model:', onnx_file, model.taskname) 43 | self.models[model.taskname] = model 44 | else: 45 | print('duplicated model task type, ignore:', onnx_file, model.taskname) 46 | del model 47 | assert 'detection' in self.models 48 | self.det_model = self.models['detection'] 49 | 50 | 51 | def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)): 52 | self.det_thresh = det_thresh 53 | assert det_size is not None 54 | print('set det-size:', det_size) 55 | self.det_size = det_size 56 | for taskname, model in self.models.items(): 57 | if taskname=='detection': 58 | model.prepare(ctx_id, input_size=det_size) 59 | else: 60 | model.prepare(ctx_id) 61 | 62 | def get(self, img, crop_size, max_num=0, mode = 'ffhq'): 63 | bboxes, kpss = self.det_model.detect(img, 64 | threshold=self.det_thresh, 65 | max_num=max_num, 66 | metric='default') 67 | if bboxes.shape[0] == 0: 68 | return [] 69 | ret = [] 70 | if mode == 'Both': 71 | for i in range(bboxes.shape[0]): 72 | kps = None 73 | if kpss is not None: 74 | kps = kpss[i] 75 | aimg_ffhq,aimg_None = face_align.norm_crop(img, kps,crop_size,mode =mode) 76 | return [aimg_ffhq,aimg_None] 77 | 78 | else: 79 | for i in range(bboxes.shape[0]): 80 | kps = None 81 | if kpss is not None: 82 | kps = kpss[i] 83 | aimg = face_align.norm_crop(img, kps,crop_size,mode =mode) 84 | return [aimg] 85 | 86 | def draw_on(self, img, faces): 87 | import cv2 88 | for i in range(len(faces)): 89 | face = faces[i] 90 | box = face.bbox.astype(np.int) 91 | color = (0, 0, 255) 92 | cv2.rectangle(img, (box[0], box[1]), (box[2], box[3]), color, 2) 93 | if face.kps is not None: 94 | kps = face.kps.astype(np.int) 95 | #print(landmark.shape) 96 | for l in range(kps.shape[0]): 97 | color = (0, 0, 255) 98 | if l == 0 or l == 3: 99 | color = (0, 255, 0) 100 | cv2.circle(img, (kps[l][0], kps[l][1]), 1, color, 101 | 2) 102 | return img 103 | 104 | -------------------------------------------------------------------------------- /insightface_func/utils/face_align.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Organization : insightface.ai 3 | # @Author : Jia Guo 4 | # @Time : 2021-05-04 5 | # @Function : 6 | 7 | import cv2 8 | import numpy as np 9 | from skimage import transform as trans 10 | 11 | src1 = np.array([[51.642, 50.115], [57.617, 49.990], [35.740, 69.007], 12 | [51.157, 89.050], [57.025, 89.702]], 13 | dtype=np.float32) 14 | #<--left 15 | src2 = np.array([[45.031, 50.118], [65.568, 50.872], [39.677, 68.111], 16 | [45.177, 86.190], [64.246, 86.758]], 17 | dtype=np.float32) 18 | 19 | #---frontal 20 | src3 = np.array([[39.730, 51.138], [72.270, 51.138], [56.000, 68.493], 21 | [42.463, 87.010], [69.537, 87.010]], 22 | dtype=np.float32) 23 | 24 | #-->right 25 | src4 = np.array([[46.845, 50.872], [67.382, 50.118], [72.737, 68.111], 26 | [48.167, 86.758], [67.236, 86.190]], 27 | dtype=np.float32) 28 | 29 | #-->right profile 30 | src5 = np.array([[54.796, 49.990], [60.771, 50.115], [76.673, 69.007], 31 | [55.388, 89.702], [61.257, 
89.050]], 32 | dtype=np.float32) 33 | 34 | src = np.array([src1, src2, src3, src4, src5]) 35 | src_map = src 36 | 37 | arcface_src = np.array( 38 | [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], 39 | [41.5493, 92.3655], [70.7299, 92.2041]], 40 | dtype=np.float32) 41 | 42 | arcface_src = np.expand_dims(arcface_src, axis=0) 43 | 44 | # In[66]: 45 | 46 | 47 | # lmk is prediction; src is template 48 | def estimate_norm(lmk, image_size=112, mode='arcface'): 49 | assert lmk.shape == (5, 2) 50 | tform = trans.SimilarityTransform() 51 | lmk_tran = np.insert(lmk, 2, values=np.ones(5), axis=1) 52 | min_M = [] 53 | min_index = [] 54 | min_error = float('inf') 55 | if mode == 'arcface': 56 | assert image_size == 112 57 | src = arcface_src 58 | else: 59 | src = src_map * image_size / 112 60 | for i in np.arange(src.shape[0]): 61 | tform.estimate(lmk, src[i]) 62 | M = tform.params[0:2, :] 63 | results = np.dot(M, lmk_tran.T) 64 | results = results.T 65 | error = np.sum(np.sqrt(np.sum((results - src[i])**2, axis=1))) 66 | # print(error) 67 | if error < min_error: 68 | min_error = error 69 | min_M = M 70 | min_index = i 71 | return min_M, min_index 72 | 73 | 74 | def norm_crop(img, landmark, image_size=112, mode='arcface'): 75 | if mode == 'Both': 76 | M_None, _ = estimate_norm(landmark, image_size, mode = 'None') 77 | M_arcface, _ = estimate_norm(landmark, 112, mode='arcface') 78 | warped_None = cv2.warpAffine(img, M_None, (image_size, image_size), borderValue=0.0) 79 | warped_arcface = cv2.warpAffine(img, M_arcface, (112, 112), borderValue=0.0) 80 | return warped_None, warped_arcface 81 | else: 82 | M, pose_index = estimate_norm(landmark, image_size, mode) 83 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) 84 | return warped 85 | 86 | def square_crop(im, S): 87 | if im.shape[0] > im.shape[1]: 88 | height = S 89 | width = int(float(im.shape[1]) / im.shape[0] * S) 90 | scale = float(S) / im.shape[0] 91 | else: 92 | width = S 93 | height = int(float(im.shape[0]) / im.shape[1] * S) 94 | scale = float(S) / im.shape[1] 95 | resized_im = cv2.resize(im, (width, height)) 96 | det_im = np.zeros((S, S, 3), dtype=np.uint8) 97 | det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im 98 | return det_im, scale 99 | 100 | 101 | def transform(data, center, output_size, scale, rotation): 102 | scale_ratio = scale 103 | rot = float(rotation) * np.pi / 180.0 104 | #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) 105 | t1 = trans.SimilarityTransform(scale=scale_ratio) 106 | cx = center[0] * scale_ratio 107 | cy = center[1] * scale_ratio 108 | t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) 109 | t3 = trans.SimilarityTransform(rotation=rot) 110 | t4 = trans.SimilarityTransform(translation=(output_size / 2, 111 | output_size / 2)) 112 | t = t1 + t2 + t3 + t4 113 | M = t.params[0:2] 114 | cropped = cv2.warpAffine(data, 115 | M, (output_size, output_size), 116 | borderValue=0.0) 117 | return cropped, M 118 | 119 | 120 | def trans_points2d(pts, M): 121 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 122 | for i in range(pts.shape[0]): 123 | pt = pts[i] 124 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 125 | new_pt = np.dot(M, new_pt) 126 | #print('new_pt', new_pt.shape, new_pt) 127 | new_pts[i] = new_pt[0:2] 128 | 129 | return new_pts 130 | 131 | 132 | def trans_points3d(pts, M): 133 | scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) 134 | #print(scale) 135 | new_pts = np.zeros(shape=pts.shape, 
dtype=np.float32) 136 | for i in range(pts.shape[0]): 137 | pt = pts[i] 138 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 139 | new_pt = np.dot(M, new_pt) 140 | #print('new_pt', new_pt.shape, new_pt) 141 | new_pts[i][0:2] = new_pt[0:2] 142 | new_pts[i][2] = pts[i][2] * scale 143 | 144 | return new_pts 145 | 146 | 147 | def trans_points(pts, M): 148 | if pts.shape[1] == 2: 149 | return trans_points2d(pts, M) 150 | else: 151 | return trans_points3d(pts, M) 152 | 153 | -------------------------------------------------------------------------------- /insightface_func/utils/face_align_ffhqandnewarc.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-15 19:42:42 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-15 20:01:47 7 | Description: 8 | ''' 9 | 10 | import cv2 11 | import numpy as np 12 | from skimage import transform as trans 13 | 14 | src1 = np.array([[51.642, 50.115], [57.617, 49.990], [35.740, 69.007], 15 | [51.157, 89.050], [57.025, 89.702]], 16 | dtype=np.float32) 17 | #<--left 18 | src2 = np.array([[45.031, 50.118], [65.568, 50.872], [39.677, 68.111], 19 | [45.177, 86.190], [64.246, 86.758]], 20 | dtype=np.float32) 21 | 22 | #---frontal 23 | src3 = np.array([[39.730, 51.138], [72.270, 51.138], [56.000, 68.493], 24 | [42.463, 87.010], [69.537, 87.010]], 25 | dtype=np.float32) 26 | 27 | #-->right 28 | src4 = np.array([[46.845, 50.872], [67.382, 50.118], [72.737, 68.111], 29 | [48.167, 86.758], [67.236, 86.190]], 30 | dtype=np.float32) 31 | 32 | #-->right profile 33 | src5 = np.array([[54.796, 49.990], [60.771, 50.115], [76.673, 69.007], 34 | [55.388, 89.702], [61.257, 89.050]], 35 | dtype=np.float32) 36 | 37 | src = np.array([src1, src2, src3, src4, src5]) 38 | src_map = src 39 | 40 | ffhq_src = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935], 41 | [201.26117, 371.41043], [313.08905, 371.15118]]) 42 | ffhq_src = np.expand_dims(ffhq_src, axis=0) 43 | 44 | # arcface_src = np.array( 45 | # [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], 46 | # [41.5493, 92.3655], [70.7299, 92.2041]], 47 | # dtype=np.float32) 48 | 49 | # arcface_src = np.expand_dims(arcface_src, axis=0) 50 | 51 | # In[66]: 52 | 53 | 54 | # lmk is prediction; src is template 55 | def estimate_norm(lmk, image_size=112, mode='ffhq'): 56 | assert lmk.shape == (5, 2) 57 | tform = trans.SimilarityTransform() 58 | lmk_tran = np.insert(lmk, 2, values=np.ones(5), axis=1) 59 | min_M = [] 60 | min_index = [] 61 | min_error = float('inf') 62 | if mode == 'ffhq': 63 | # assert image_size == 112 64 | src = ffhq_src * image_size / 512 65 | else: 66 | src = src_map * image_size / 112 67 | for i in np.arange(src.shape[0]): 68 | tform.estimate(lmk, src[i]) 69 | M = tform.params[0:2, :] 70 | results = np.dot(M, lmk_tran.T) 71 | results = results.T 72 | error = np.sum(np.sqrt(np.sum((results - src[i])**2, axis=1))) 73 | # print(error) 74 | if error < min_error: 75 | min_error = error 76 | min_M = M 77 | min_index = i 78 | return min_M, min_index 79 | 80 | 81 | def norm_crop(img, landmark, image_size=112, mode='ffhq'): 82 | if mode == 'Both': 83 | M_None, _ = estimate_norm(landmark, image_size, mode = 'newarc') 84 | M_ffhq, _ = estimate_norm(landmark, image_size, mode='ffhq') 85 | warped_None = cv2.warpAffine(img, M_None, (image_size, image_size), borderValue=0.0) 86 | warped_ffhq = cv2.warpAffine(img, M_ffhq, (image_size, image_size), borderValue=0.0) 87 | 
return warped_ffhq, warped_None 88 | else: 89 | M, pose_index = estimate_norm(landmark, image_size, mode) 90 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0) 91 | return warped 92 | 93 | def square_crop(im, S): 94 | if im.shape[0] > im.shape[1]: 95 | height = S 96 | width = int(float(im.shape[1]) / im.shape[0] * S) 97 | scale = float(S) / im.shape[0] 98 | else: 99 | width = S 100 | height = int(float(im.shape[0]) / im.shape[1] * S) 101 | scale = float(S) / im.shape[1] 102 | resized_im = cv2.resize(im, (width, height)) 103 | det_im = np.zeros((S, S, 3), dtype=np.uint8) 104 | det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im 105 | return det_im, scale 106 | 107 | 108 | def transform(data, center, output_size, scale, rotation): 109 | scale_ratio = scale 110 | rot = float(rotation) * np.pi / 180.0 111 | #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio) 112 | t1 = trans.SimilarityTransform(scale=scale_ratio) 113 | cx = center[0] * scale_ratio 114 | cy = center[1] * scale_ratio 115 | t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy)) 116 | t3 = trans.SimilarityTransform(rotation=rot) 117 | t4 = trans.SimilarityTransform(translation=(output_size / 2, 118 | output_size / 2)) 119 | t = t1 + t2 + t3 + t4 120 | M = t.params[0:2] 121 | cropped = cv2.warpAffine(data, 122 | M, (output_size, output_size), 123 | borderValue=0.0) 124 | return cropped, M 125 | 126 | 127 | def trans_points2d(pts, M): 128 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 129 | for i in range(pts.shape[0]): 130 | pt = pts[i] 131 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 132 | new_pt = np.dot(M, new_pt) 133 | #print('new_pt', new_pt.shape, new_pt) 134 | new_pts[i] = new_pt[0:2] 135 | 136 | return new_pts 137 | 138 | 139 | def trans_points3d(pts, M): 140 | scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1]) 141 | #print(scale) 142 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32) 143 | for i in range(pts.shape[0]): 144 | pt = pts[i] 145 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32) 146 | new_pt = np.dot(M, new_pt) 147 | #print('new_pt', new_pt.shape, new_pt) 148 | new_pts[i][0:2] = new_pt[0:2] 149 | new_pts[i][2] = pts[i][2] * scale 150 | 151 | return new_pts 152 | 153 | 154 | def trans_points(pts, M): 155 | if pts.shape[1] == 2: 156 | return trans_points2d(pts, M) 157 | else: 158 | return trans_points3d(pts, M) 159 | 160 | -------------------------------------------------------------------------------- /options/train_gfpgan_v1.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_GFPGANv1_512 3 | model_type: GFPGANModel 4 | num_gpu: 4 5 | manual_seed: 0 6 | 7 | # dataset and data loader settings 8 | datasets: 9 | train: 10 | name: FFHQ 11 | type: FFHQDegradationDataset 12 | # dataroot_gt: datasets/ffhq/ffhq_512.lmdb 13 | dataroot_gt: datasets/ffhq/ffhq_512 14 | io_backend: 15 | # type: lmdb 16 | type: disk 17 | 18 | use_hflip: true 19 | mean: [0.5, 0.5, 0.5] 20 | std: [0.5, 0.5, 0.5] 21 | out_size: 512 22 | 23 | blur_kernel_size: 41 24 | kernel_list: ['iso', 'aniso'] 25 | kernel_prob: [0.5, 0.5] 26 | blur_sigma: [0.1, 10] 27 | downsample_range: [0.8, 8] 28 | noise_range: [0, 20] 29 | jpeg_range: [60, 100] 30 | 31 | # color jitter and gray 32 | color_jitter_prob: 0.3 33 | color_jitter_shift: 20 34 | color_jitter_pt_prob: 0.3 35 | gray_prob: 0.01 36 | 37 | # If you do not want colorization, please set 38 | # 
color_jitter_prob: ~ 39 | # color_jitter_pt_prob: ~ 40 | # gray_prob: 0.01 41 | # gt_gray: True 42 | 43 | crop_components: true 44 | component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth 45 | eye_enlarge_ratio: 1.4 46 | 47 | # data loader 48 | use_shuffle: true 49 | num_worker_per_gpu: 6 50 | batch_size_per_gpu: 3 51 | dataset_enlarge_ratio: 1 52 | prefetch_mode: ~ 53 | 54 | val: 55 | # Please modify accordingly to use your own validation 56 | # Or comment the val block if do not need validation during training 57 | name: validation 58 | type: PairedImageDataset 59 | dataroot_lq: datasets/faces/validation/input 60 | dataroot_gt: datasets/faces/validation/reference 61 | io_backend: 62 | type: disk 63 | mean: [0.5, 0.5, 0.5] 64 | std: [0.5, 0.5, 0.5] 65 | scale: 1 66 | 67 | # network structures 68 | network_g: 69 | type: GFPGANv1 70 | out_size: 512 71 | num_style_feat: 512 72 | channel_multiplier: 1 73 | resample_kernel: [1, 3, 3, 1] 74 | decoder_load_path: experiments/pretrained_models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth 75 | fix_decoder: true 76 | num_mlp: 8 77 | lr_mlp: 0.01 78 | input_is_latent: true 79 | different_w: true 80 | narrow: 1 81 | sft_half: true 82 | 83 | network_d: 84 | type: StyleGAN2Discriminator 85 | out_size: 512 86 | channel_multiplier: 1 87 | resample_kernel: [1, 3, 3, 1] 88 | 89 | network_d_left_eye: 90 | type: FacialComponentDiscriminator 91 | 92 | network_d_right_eye: 93 | type: FacialComponentDiscriminator 94 | 95 | network_d_mouth: 96 | type: FacialComponentDiscriminator 97 | 98 | network_identity: 99 | type: ResNetArcFace 100 | block: IRBlock 101 | layers: [2, 2, 2, 2] 102 | use_se: False 103 | 104 | # path 105 | path: 106 | pretrain_network_g: ~ 107 | param_key_g: params_ema 108 | strict_load_g: ~ 109 | pretrain_network_d: ~ 110 | pretrain_network_d_left_eye: ~ 111 | pretrain_network_d_right_eye: ~ 112 | pretrain_network_d_mouth: ~ 113 | pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth 114 | # resume 115 | resume_state: ~ 116 | ignore_resume_networks: ['network_identity'] 117 | 118 | # training settings 119 | train: 120 | optim_g: 121 | type: Adam 122 | lr: !!float 2e-3 123 | optim_d: 124 | type: Adam 125 | lr: !!float 2e-3 126 | optim_component: 127 | type: Adam 128 | lr: !!float 2e-3 129 | 130 | scheduler: 131 | type: MultiStepLR 132 | milestones: [600000, 700000] 133 | gamma: 0.5 134 | 135 | total_iter: 800000 136 | warmup_iter: -1 # no warm up 137 | 138 | # losses 139 | # pixel loss 140 | pixel_opt: 141 | type: L1Loss 142 | loss_weight: !!float 1e-1 143 | reduction: mean 144 | # L1 loss used in pyramid loss, component style loss and identity loss 145 | L1_opt: 146 | type: L1Loss 147 | loss_weight: 1 148 | reduction: mean 149 | 150 | # image pyramid loss 151 | pyramid_loss_weight: 1 152 | remove_pyramid_loss: 50000 153 | # perceptual loss (content and style losses) 154 | perceptual_opt: 155 | type: PerceptualLoss 156 | layer_weights: 157 | # before relu 158 | 'conv1_2': 0.1 159 | 'conv2_2': 0.1 160 | 'conv3_4': 1 161 | 'conv4_4': 1 162 | 'conv5_4': 1 163 | vgg_type: vgg19 164 | use_input_norm: true 165 | perceptual_weight: !!float 1 166 | style_weight: 50 167 | range_norm: true 168 | criterion: l1 169 | # gan loss 170 | gan_opt: 171 | type: GANLoss 172 | gan_type: wgan_softplus 173 | loss_weight: !!float 1e-1 174 | # r1 regularization for discriminator 175 | r1_reg_weight: 10 176 | # facial component loss 177 | gan_component_opt: 178 | type: GANLoss 179 | gan_type: vanilla 180 | real_label_val: 
1.0 181 | fake_label_val: 0.0 182 | loss_weight: !!float 1 183 | comp_style_weight: 200 184 | # identity loss 185 | identity_weight: 10 186 | 187 | net_d_iters: 1 188 | net_d_init_iters: 0 189 | net_d_reg_every: 16 190 | 191 | # validation settings 192 | val: 193 | val_freq: !!float 5e3 194 | save_img: true 195 | 196 | metrics: 197 | psnr: # metric name, can be arbitrary 198 | type: calculate_psnr 199 | crop_border: 0 200 | test_y_channel: false 201 | 202 | # logging settings 203 | logger: 204 | print_freq: 100 205 | save_checkpoint_freq: !!float 5e3 206 | use_tb_logger: true 207 | wandb: 208 | project: ~ 209 | resume_id: ~ 210 | 211 | # dist training settings 212 | dist_params: 213 | backend: nccl 214 | port: 29500 215 | 216 | find_unused_parameters: true 217 | -------------------------------------------------------------------------------- /options/train_gfpgan_v1_simple.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_GFPGANv1_512_simple 3 | model_type: GFPGANModel 4 | num_gpu: 4 5 | manual_seed: 0 6 | 7 | # dataset and data loader settings 8 | datasets: 9 | train: 10 | name: FFHQ 11 | type: FFHQDegradationDataset 12 | # dataroot_gt: datasets/ffhq/ffhq_512.lmdb 13 | dataroot_gt: datasets/ffhq/ffhq_512 14 | io_backend: 15 | # type: lmdb 16 | type: disk 17 | 18 | use_hflip: true 19 | mean: [0.5, 0.5, 0.5] 20 | std: [0.5, 0.5, 0.5] 21 | out_size: 512 22 | 23 | blur_kernel_size: 41 24 | kernel_list: ['iso', 'aniso'] 25 | kernel_prob: [0.5, 0.5] 26 | blur_sigma: [0.1, 10] 27 | downsample_range: [0.8, 8] 28 | noise_range: [0, 20] 29 | jpeg_range: [60, 100] 30 | 31 | # color jitter and gray 32 | color_jitter_prob: 0.3 33 | color_jitter_shift: 20 34 | color_jitter_pt_prob: 0.3 35 | gray_prob: 0.01 36 | 37 | # If you do not want colorization, please set 38 | # color_jitter_prob: ~ 39 | # color_jitter_pt_prob: ~ 40 | # gray_prob: 0.01 41 | # gt_gray: True 42 | 43 | # crop_components: false 44 | # component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth 45 | # eye_enlarge_ratio: 1.4 46 | 47 | # data loader 48 | use_shuffle: true 49 | num_worker_per_gpu: 6 50 | batch_size_per_gpu: 3 51 | dataset_enlarge_ratio: 1 52 | prefetch_mode: ~ 53 | 54 | val: 55 | # Please modify accordingly to use your own validation 56 | # Or comment the val block if do not need validation during training 57 | name: validation 58 | type: PairedImageDataset 59 | dataroot_lq: datasets/faces/validation/input 60 | dataroot_gt: datasets/faces/validation/reference 61 | io_backend: 62 | type: disk 63 | mean: [0.5, 0.5, 0.5] 64 | std: [0.5, 0.5, 0.5] 65 | scale: 1 66 | 67 | # network structures 68 | network_g: 69 | type: GFPGANv1 70 | out_size: 512 71 | num_style_feat: 512 72 | channel_multiplier: 1 73 | resample_kernel: [1, 3, 3, 1] 74 | decoder_load_path: experiments/pretrained_models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth 75 | fix_decoder: true 76 | num_mlp: 8 77 | lr_mlp: 0.01 78 | input_is_latent: true 79 | different_w: true 80 | narrow: 1 81 | sft_half: true 82 | 83 | network_d: 84 | type: StyleGAN2Discriminator 85 | out_size: 512 86 | channel_multiplier: 1 87 | resample_kernel: [1, 3, 3, 1] 88 | 89 | # network_d_left_eye: 90 | # type: FacialComponentDiscriminator 91 | 92 | # network_d_right_eye: 93 | # type: FacialComponentDiscriminator 94 | 95 | # network_d_mouth: 96 | # type: FacialComponentDiscriminator 97 | 98 | network_identity: 99 | type: ResNetArcFace 100 | block: IRBlock 101 | layers: [2, 2, 2, 2] 102 | use_se: 
False 103 | 104 | # path 105 | path: 106 | pretrain_network_g: ~ 107 | param_key_g: params_ema 108 | strict_load_g: ~ 109 | pretrain_network_d: ~ 110 | # pretrain_network_d_left_eye: ~ 111 | # pretrain_network_d_right_eye: ~ 112 | # pretrain_network_d_mouth: ~ 113 | pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth 114 | # resume 115 | resume_state: ~ 116 | ignore_resume_networks: ['network_identity'] 117 | 118 | # training settings 119 | train: 120 | optim_g: 121 | type: Adam 122 | lr: !!float 2e-3 123 | optim_d: 124 | type: Adam 125 | lr: !!float 2e-3 126 | optim_component: 127 | type: Adam 128 | lr: !!float 2e-3 129 | 130 | scheduler: 131 | type: MultiStepLR 132 | milestones: [600000, 700000] 133 | gamma: 0.5 134 | 135 | total_iter: 800000 136 | warmup_iter: -1 # no warm up 137 | 138 | # losses 139 | # pixel loss 140 | pixel_opt: 141 | type: L1Loss 142 | loss_weight: !!float 1e-1 143 | reduction: mean 144 | # L1 loss used in pyramid loss, component style loss and identity loss 145 | L1_opt: 146 | type: L1Loss 147 | loss_weight: 1 148 | reduction: mean 149 | 150 | # image pyramid loss 151 | pyramid_loss_weight: 1 152 | remove_pyramid_loss: 50000 153 | # perceptual loss (content and style losses) 154 | perceptual_opt: 155 | type: PerceptualLoss 156 | layer_weights: 157 | # before relu 158 | 'conv1_2': 0.1 159 | 'conv2_2': 0.1 160 | 'conv3_4': 1 161 | 'conv4_4': 1 162 | 'conv5_4': 1 163 | vgg_type: vgg19 164 | use_input_norm: true 165 | perceptual_weight: !!float 1 166 | style_weight: 50 167 | range_norm: true 168 | criterion: l1 169 | # gan loss 170 | gan_opt: 171 | type: GANLoss 172 | gan_type: wgan_softplus 173 | loss_weight: !!float 1e-1 174 | # r1 regularization for discriminator 175 | r1_reg_weight: 10 176 | # facial component loss 177 | # gan_component_opt: 178 | # type: GANLoss 179 | # gan_type: vanilla 180 | # real_label_val: 1.0 181 | # fake_label_val: 0.0 182 | # loss_weight: !!float 1 183 | # comp_style_weight: 200 184 | # identity loss 185 | identity_weight: 10 186 | 187 | net_d_iters: 1 188 | net_d_init_iters: 0 189 | net_d_reg_every: 16 190 | 191 | # validation settings 192 | val: 193 | val_freq: !!float 5e3 194 | save_img: true 195 | 196 | metrics: 197 | psnr: # metric name, can be arbitrary 198 | type: calculate_psnr 199 | crop_border: 0 200 | test_y_channel: false 201 | 202 | # logging settings 203 | logger: 204 | print_freq: 100 205 | save_checkpoint_freq: !!float 5e3 206 | use_tb_logger: true 207 | wandb: 208 | project: ~ 209 | resume_id: ~ 210 | 211 | # dist training settings 212 | dist_params: 213 | backend: nccl 214 | port: 29500 215 | 216 | find_unused_parameters: true 217 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7 2 | numpy<1.21 # numba requires numpy<1.21,>=1.17 3 | opencv-python 4 | torchvision 5 | scipy 6 | tqdm 7 | basicsr>=1.3.4.0 8 | facexlib>=0.2.0.3 9 | lmdb 10 | pyyaml 11 | tb-nightly 12 | yapf 13 | -------------------------------------------------------------------------------- /scripts/crop_align_vggface2_FFHQalign.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-16 15:34:14 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-19 16:14:01 7 | Description: 8 | ''' 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from 
__future__ import print_function 12 | import os 13 | import argparse 14 | import cv2 15 | import glob 16 | from tqdm import tqdm 17 | from insightface_func.face_detect_crop_ffhq_newarcAlign import Face_detect_crop 18 | import argparse 19 | 20 | def align_image_dir(dir_name_tmp): 21 | ori_path_tmp = os.path.join(input_dir, dir_name_tmp) 22 | image_filenames = glob.glob(os.path.join(ori_path_tmp,'*')) 23 | save_dir_ffhqalign = os.path.join(output_dir_ffhqalign,dir_name_tmp) 24 | if not os.path.exists(save_dir_ffhqalign): 25 | os.makedirs(save_dir_ffhqalign) 26 | 27 | 28 | for file in image_filenames: 29 | image_file = os.path.basename(file) 30 | 31 | image_file_name_ffhqalign = os.path.join(save_dir_ffhqalign, image_file) 32 | if os.path.exists(image_file_name_ffhqalign): 33 | continue 34 | 35 | face_img = cv2.imread(file) 36 | if face_img.shape[0]<250 or face_img.shape[1]<250: 37 | continue 38 | ret = app.get(face_img,crop_size,mode=mode) 39 | if len(ret)!=0 : 40 | cv2.imwrite(image_file_name_ffhqalign, ret[0]) 41 | else: 42 | continue 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser() 47 | 48 | parser.add_argument('--input_dir',type=str,default = '/Data/VGGface2/train') 49 | parser.add_argument('--output_dir_ffhqalign',type=str,default = '/Data/VGGface2_FFHQalign') 50 | parser.add_argument('--crop_size',type=int,default = 256) 51 | parser.add_argument('--mode',type=str,default = 'ffhq',choices=['ffhq','newarc','both']) 52 | 53 | args = parser.parse_args() 54 | input_dir = args.input_dir 55 | output_dir_ffhqalign = args.output_dir_ffhqalign 56 | crop_size = args.crop_size 57 | mode = args.mode 58 | 59 | app = Face_detect_crop(name='antelope', root='./insightface_func/models') 60 | 61 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(320,320)) 62 | 63 | dirs = sorted(os.listdir(input_dir)) 64 | handle_dir_list = dirs 65 | for handle_dir_list_tmp in tqdm(handle_dir_list): 66 | align_image_dir(handle_dir_list_tmp) 67 | 68 | -------------------------------------------------------------------------------- /scripts/crop_align_vggface2_FFHQalignandNewarcalign.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-15 19:42:42 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-19 16:17:54 7 | Description: 8 | ''' 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | import os 13 | import argparse 14 | import cv2 15 | import glob 16 | from tqdm import tqdm 17 | from insightface_func.face_detect_crop_ffhq_newarcAlign import Face_detect_crop 18 | import argparse 19 | 20 | def align_image_dir(dir_name_tmp): 21 | ori_path_tmp = os.path.join(input_dir, dir_name_tmp) 22 | image_filenames = glob.glob(os.path.join(ori_path_tmp,'*')) 23 | save_dir_newarcalign = os.path.join(output_dir_newarcalign,dir_name_tmp) 24 | save_dir_ffhqalign = os.path.join(output_dir_ffhqalign,dir_name_tmp) 25 | if not os.path.exists(save_dir_newarcalign): 26 | os.makedirs(save_dir_newarcalign) 27 | if not os.path.exists(save_dir_ffhqalign): 28 | os.makedirs(save_dir_ffhqalign) 29 | 30 | 31 | for file in image_filenames: 32 | image_file = os.path.basename(file) 33 | 34 | image_file_name_newarcalign = os.path.join(save_dir_newarcalign, image_file) 35 | image_file_name_ffhqalign = os.path.join(save_dir_ffhqalign, image_file) 36 | if os.path.exists(image_file_name_newarcalign) and 
os.path.exists(image_file_name_ffhqalign): 37 | continue 38 | 39 | face_img = cv2.imread(file) 40 | if face_img.shape[0]<250 or face_img.shape[1]<250: 41 | continue 42 | ret = app.get(face_img,crop_size,mode=mode) 43 | if len(ret)!=0 : 44 | cv2.imwrite(image_file_name_ffhqalign, ret[0]) 45 | cv2.imwrite(image_file_name_newarcalign, ret[1]) 46 | else: 47 | continue 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | 53 | parser.add_argument('--input_dir',type=str,default = '/home/gdp/harddisk/Data1/VGGface2/train') 54 | parser.add_argument('--output_dir_ffhqalign',type=str,default = '/home/gdp/harddisk/Data1/VGGface2_ffhq_align') 55 | parser.add_argument('--output_dir_newarcalign',type=str,default = '/home/gdp/harddisk/Data1/VGGface2_newarc_align') 56 | parser.add_argument('--crop_size',type=int,default = 256) 57 | parser.add_argument('--mode',type=str,default = 'Both') 58 | 59 | args = parser.parse_args() 60 | input_dir = args.input_dir 61 | output_dir_newarcalign = args.output_dir_newarcalign 62 | output_dir_ffhqalign = args.output_dir_ffhqalign 63 | crop_size = args.crop_size 64 | mode = args.mode 65 | 66 | app = Face_detect_crop(name='antelope', root='./insightface_func/models') 67 | 68 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(320,320)) 69 | 70 | dirs = sorted(os.listdir(input_dir)) 71 | handle_dir_list = dirs 72 | for handle_dir_list_tmp in tqdm(handle_dir_list): 73 | align_image_dir(handle_dir_list_tmp) 74 | 75 | -------------------------------------------------------------------------------- /scripts/inference_gfpgan_forvggface2.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-16 19:30:52 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-20 17:57:55 7 | Description: 8 | ''' 9 | import os 10 | import torch 11 | import argparse 12 | 13 | from tqdm import tqdm 14 | from vggface_dataset import getLoader 15 | from basicsr.utils import imwrite, tensor2img 16 | from gfpgan.archs.gfpganv1_arch import GFPGANv1 17 | from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean 18 | import platform 19 | 20 | def main_worker(args): 21 | arch = args.arch 22 | 23 | with torch.no_grad(): 24 | # initialize the GFP-GAN 25 | if arch == 'clean': 26 | gfpgan = GFPGANv1Clean( 27 | out_size=512, 28 | num_style_feat=512, 29 | channel_multiplier=args.channel, 30 | decoder_load_path=None, 31 | fix_decoder=False, 32 | num_mlp=8, 33 | input_is_latent=True, 34 | different_w=True, 35 | narrow=1, 36 | sft_half=True) 37 | else: 38 | gfpgan = GFPGANv1( 39 | out_size=512, 40 | num_style_feat=512, 41 | channel_multiplier=args.channel, 42 | decoder_load_path=None, 43 | fix_decoder=True, 44 | num_mlp=8, 45 | input_is_latent=True, 46 | different_w=True, 47 | narrow=1, 48 | sft_half=True) 49 | 50 | loadnet = torch.load(args.model_path) 51 | if 'params_ema' in loadnet: 52 | keyname = 'params_ema' 53 | else: 54 | keyname = 'params' 55 | gfpgan.load_state_dict(loadnet[keyname], strict=True) 56 | gfpgan.eval() 57 | gfpgan.cuda() 58 | 59 | test_dataloader = getLoader(args.input_path, 512, args.batchSize, 8) 60 | 61 | print(len(test_dataloader)) 62 | 63 | 64 | for images,filenames in tqdm(test_dataloader): 65 | images = images.cuda() 66 | 67 | output_batch = gfpgan(images, return_rgb=False)[0] 68 | 69 | for tmp_index in range(len(output_batch)): 70 | tmp_filename = filenames[tmp_index] 71 | 72 | split_leave = 
tmp_filename.split(args.input_path)[-1].split(split_name) 73 | restored_face = output_batch[tmp_index] 74 | restored_face = tensor2img(restored_face, rgb2bgr=True, min_max=(-1, 1)) 75 | restored_face = restored_face.astype('uint8') 76 | 77 | sub_dir = os.path.join(args.save_dir, split_leave[-2]) 78 | os.makedirs(sub_dir, exist_ok=True) 79 | 80 | save_path_tmp = os.path.join(sub_dir, split_leave[-1]) 81 | 82 | imwrite(restored_face, save_path_tmp) 83 | 84 | 85 | if __name__ == '__main__': 86 | parser = argparse.ArgumentParser() 87 | 88 | parser.add_argument('--arch', type=str, default='clean') 89 | parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth') 90 | parser.add_argument('--input_path', type=str, default='/Data/VGGface2_FFHQalign') 91 | parser.add_argument('--sft_half', default = False, action='store_true') 92 | parser.add_argument('--batchSize', type=int, default = 8) 93 | parser.add_argument('--save_dir', type=str, default = ' ') 94 | parser.add_argument('--channel', type=int, default=2) 95 | 96 | args = parser.parse_args() 97 | 98 | if platform.system().lower() == 'windows': 99 | split_name = '\\' 100 | elif platform.system().lower() == 'linux': 101 | split_name = '/' 102 | os.makedirs(args.save_dir, exist_ok=True) 103 | main_worker(args) 104 | -------------------------------------------------------------------------------- /scripts/vggface_dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Naiyuan liu 3 | Github: https://github.com/NNNNAI 4 | Date: 2021-11-16 19:32:18 5 | LastEditors: Naiyuan liu 6 | LastEditTime: 2021-11-16 19:35:12 7 | Description: 8 | ''' 9 | import os 10 | from PIL import Image 11 | from torch.utils import data 12 | from torchvision import transforms as T 13 | import glob 14 | 15 | 16 | 17 | class TotalDataset(data.Dataset): 18 | 19 | def __init__(self,image_dir, 20 | content_transform): 21 | self.image_dir = image_dir 22 | 23 | self.content_transform= content_transform 24 | self.dataset = [] 25 | self.mean = [0.5, 0.5, 0.5] 26 | self.std = [0.5, 0.5, 0.5] 27 | self.preprocess() 28 | self.num_images = len(self.dataset) 29 | 30 | def preprocess(self): 31 | additional_pattern = '*/*' 32 | self.dataset.extend(sorted(glob.glob(os.path.join(self.image_dir, additional_pattern), recursive=False))) 33 | 34 | print('Finished preprocessing the VGGFACE2 dataset...') 35 | 36 | 37 | def __getitem__(self, index): 38 | """Return single image.""" 39 | dataset = self.dataset 40 | 41 | src_filename1 = dataset[index] 42 | 43 | src_image1 = self.content_transform(Image.open(src_filename1)) 44 | 45 | 46 | return src_image1, src_filename1 47 | 48 | 49 | def __len__(self): 50 | """Return the number of images.""" 51 | return self.num_images 52 | 53 | def getLoader(c_image_dir, ResizeSize=512, batch_size=16, num_workers=8): 54 | """Build and return a data loader.""" 55 | c_transforms = [] 56 | 57 | 58 | c_transforms.append(T.Resize([ResizeSize,ResizeSize])) 59 | c_transforms.append(T.ToTensor()) 60 | c_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))) 61 | 62 | c_transforms = T.Compose(c_transforms) 63 | 64 | content_dataset = TotalDataset(c_image_dir, c_transforms) 65 | 66 | 67 | sampler = None 68 | content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size, 69 | drop_last=False,num_workers=num_workers,sampler=sampler,pin_memory=True) 70 | return content_data_loader 71 | 72 | def denorm(x): 73 | out = (x + 1) / 2 74 | 
return out.clamp_(0, 1) 75 | 76 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # line break before binary operator (W503) 4 | W503, 5 | # line break after binary operator (W504) 6 | W504, 7 | max-line-length=120 8 | 9 | [yapf] 10 | based_on_style = pep8 11 | column_limit = 120 12 | blank_line_before_nested_class_or_def = true 13 | split_before_expression_after_opening_paren = true 14 | 15 | [isort] 16 | line_length = 120 17 | multi_line_output = 0 18 | known_standard_library = pkg_resources,setuptools 19 | known_first_party = gfpgan 20 | known_third_party = basicsr,cv2,facexlib,numpy,torch,torchvision,tqdm 21 | no_lines_before = STDLIB,LOCALFOLDER 22 | default_section = THIRDPARTY 23 | 24 | [codespell] 25 | skip = .git,./docs/build 26 | count = 27 | quiet-level = 3 28 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | import os 6 | import subprocess 7 | import time 8 | 9 | version_file = 'gfpgan/version.py' 10 | 11 | 12 | def readme(): 13 | with open('README.md', encoding='utf-8') as f: 14 | content = f.read() 15 | return content 16 | 17 | 18 | def get_git_hash(): 19 | 20 | def _minimal_ext_cmd(cmd): 21 | # construct minimal environment 22 | env = {} 23 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 24 | v = os.environ.get(k) 25 | if v is not None: 26 | env[k] = v 27 | # LANGUAGE is used on win32 28 | env['LANGUAGE'] = 'C' 29 | env['LANG'] = 'C' 30 | env['LC_ALL'] = 'C' 31 | out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 32 | return out 33 | 34 | try: 35 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 36 | sha = out.strip().decode('ascii') 37 | except OSError: 38 | sha = 'unknown' 39 | 40 | return sha 41 | 42 | 43 | def get_hash(): 44 | if os.path.exists('.git'): 45 | sha = get_git_hash()[:7] 46 | else: 47 | sha = 'unknown' 48 | 49 | return sha 50 | 51 | 52 | def write_version_py(): 53 | content = """# GENERATED VERSION FILE 54 | # TIME: {} 55 | __version__ = '{}' 56 | __gitsha__ = '{}' 57 | version_info = ({}) 58 | """ 59 | sha = get_hash() 60 | with open('VERSION', 'r') as f: 61 | SHORT_VERSION = f.read().strip() 62 | VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 63 | 64 | version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) 65 | with open(version_file, 'w') as f: 66 | f.write(version_file_str) 67 | 68 | 69 | def get_version(): 70 | with open(version_file, 'r') as f: 71 | exec(compile(f.read(), version_file, 'exec')) 72 | return locals()['__version__'] 73 | 74 | 75 | def get_requirements(filename='requirements.txt'): 76 | here = os.path.dirname(os.path.realpath(__file__)) 77 | with open(os.path.join(here, filename), 'r') as f: 78 | requires = [line.replace('\n', '') for line in f.readlines()] 79 | return requires 80 | 81 | 82 | if __name__ == '__main__': 83 | write_version_py() 84 | setup( 85 | name='gfpgan', 86 | version=get_version(), 87 | description='GFPGAN aims at developing Practical Algorithms for Real-world Face Restoration', 88 | long_description=readme(), 89 | long_description_content_type='text/markdown', 90 | author='Xintao Wang', 91 | author_email='xintao.wang@outlook.com', 92 | keywords='computer 
vision, pytorch, image restoration, super-resolution, face restoration, gan, gfpgan', 93 | url='https://github.com/TencentARC/GFPGAN', 94 | include_package_data=True, 95 | packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), 96 | classifiers=[ 97 | 'Development Status :: 4 - Beta', 98 | 'License :: OSI Approved :: Apache Software License', 99 | 'Operating System :: OS Independent', 100 | 'Programming Language :: Python :: 3', 101 | 'Programming Language :: Python :: 3.7', 102 | 'Programming Language :: Python :: 3.8', 103 | ], 104 | license='Apache License Version 2.0', 105 | setup_requires=['cython', 'numpy'], 106 | install_requires=get_requirements(), 107 | zip_safe=False) 108 | --------------------------------------------------------------------------------
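
A minimal usage sketch of the GFPGANer wrapper defined in gfpgan/utils.py, mirroring the per-image flow of inference_gfpgan.py above; the checkpoint and image paths are illustrative placeholders, not files shipped with this repository, and no background upsampler is configured.

# Minimal sketch: restore faces in one image via the GFPGANer wrapper from gfpgan/utils.py.
# Paths below are placeholders; pass any checkpoint that matches the chosen arch.
import cv2
from basicsr.utils import imwrite

from gfpgan import GFPGANer

restorer = GFPGANer(
    model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth',  # placeholder checkpoint path
    upscale=2,
    arch='clean',
    channel_multiplier=2,
    bg_upsampler=None)  # with no RealESRGAN upsampler, the background is pasted back unchanged

img = cv2.imread('inputs/whole_imgs/example.jpg', cv2.IMREAD_COLOR)  # placeholder input image
cropped_faces, restored_faces, restored_img = restorer.enhance(
    img, has_aligned=False, only_center_face=False, paste_back=True)

# enhance() returns the detected crops, the restored crops, and (when paste_back is set)
# the full image with restored faces warped back in.
if restored_img is not None:
    imwrite(restored_img, 'results/example_restored.png')

For batch processing, pre-aligned inputs, or a RealESRGAN background upsampler, inference_gfpgan.py above exposes the corresponding command-line flags.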