├── .gitignore ├── LICENSE ├── NOTICE ├── README.md ├── assets ├── afhq_dataset.jpg ├── afhq_interpolation.gif ├── afhqv2_teaser2.jpg ├── celebahq_interpolation.gif ├── representative │ ├── afhq │ │ ├── ref │ │ │ ├── cat │ │ │ │ ├── flickr_cat_000495.jpg │ │ │ │ ├── flickr_cat_000557.jpg │ │ │ │ ├── pixabay_cat_000355.jpg │ │ │ │ ├── pixabay_cat_000491.jpg │ │ │ │ ├── pixabay_cat_000535.jpg │ │ │ │ ├── pixabay_cat_000623.jpg │ │ │ │ ├── pixabay_cat_000730.jpg │ │ │ │ ├── pixabay_cat_001479.jpg │ │ │ │ ├── pixabay_cat_001699.jpg │ │ │ │ └── pixabay_cat_003046.jpg │ │ │ ├── dog │ │ │ │ ├── flickr_dog_001072.jpg │ │ │ │ ├── pixabay_dog_000121.jpg │ │ │ │ ├── pixabay_dog_000322.jpg │ │ │ │ ├── pixabay_dog_000357.jpg │ │ │ │ ├── pixabay_dog_000409.jpg │ │ │ │ ├── pixabay_dog_000799.jpg │ │ │ │ ├── pixabay_dog_000890.jpg │ │ │ │ └── pixabay_dog_001082.jpg │ │ │ └── wild │ │ │ │ ├── flickr_wild_000731.jpg │ │ │ │ ├── flickr_wild_001223.jpg │ │ │ │ ├── flickr_wild_002020.jpg │ │ │ │ ├── flickr_wild_002092.jpg │ │ │ │ ├── flickr_wild_002933.jpg │ │ │ │ ├── flickr_wild_003137.jpg │ │ │ │ ├── flickr_wild_003355.jpg │ │ │ │ ├── flickr_wild_003796.jpg │ │ │ │ ├── flickr_wild_003969.jpg │ │ │ │ └── pixabay_wild_000637.jpg │ │ └── src │ │ │ ├── cat │ │ │ ├── flickr_cat_000253.jpg │ │ │ ├── pixabay_cat_000181.jpg │ │ │ ├── pixabay_cat_000241.jpg │ │ │ ├── pixabay_cat_000276.jpg │ │ │ └── pixabay_cat_004826.jpg │ │ │ ├── dog │ │ │ ├── flickr_dog_000094.jpg │ │ │ ├── pixabay_dog_000321.jpg │ │ │ ├── pixabay_dog_000322.jpg │ │ │ ├── pixabay_dog_001082.jpg │ │ │ └── pixabay_dog_002066.jpg │ │ │ └── wild │ │ │ ├── flickr_wild_000432.jpg │ │ │ ├── flickr_wild_000814.jpg │ │ │ ├── flickr_wild_002036.jpg │ │ │ ├── flickr_wild_002159.jpg │ │ │ └── pixabay_wild_000558.jpg │ ├── celeba_hq │ │ ├── ref │ │ │ ├── female │ │ │ │ ├── 015248.jpg │ │ │ │ ├── 030321.jpg │ │ │ │ ├── 031796.jpg │ │ │ │ ├── 036619.jpg │ │ │ │ ├── 042373.jpg │ │ │ │ ├── 048197.jpg │ │ │ │ ├── 052599.jpg │ │ │ │ ├── 058150.jpg │ │ │ │ ├── 058225.jpg │ │ │ │ ├── 058881.jpg │ │ │ │ ├── 063109.jpg │ │ │ │ ├── 064119.jpg │ │ │ │ ├── 064307.jpg │ │ │ │ ├── 074075.jpg │ │ │ │ ├── 074934.jpg │ │ │ │ ├── 076551.jpg │ │ │ │ ├── 081680.jpg │ │ │ │ ├── 081871.jpg │ │ │ │ ├── 084913.jpg │ │ │ │ ├── 086986.jpg │ │ │ │ ├── 113393.jpg │ │ │ │ ├── 135626.jpg │ │ │ │ ├── 140613.jpg │ │ │ │ ├── 142595.jpg │ │ │ │ └── 195650.jpg │ │ │ └── male │ │ │ │ ├── 012712.jpg │ │ │ │ ├── 020167.jpg │ │ │ │ ├── 021612.jpg │ │ │ │ ├── 036367.jpg │ │ │ │ ├── 037023.jpg │ │ │ │ ├── 038919.jpg │ │ │ │ ├── 047763.jpg │ │ │ │ ├── 060259.jpg │ │ │ │ ├── 067791.jpg │ │ │ │ ├── 077921.jpg │ │ │ │ ├── 083510.jpg │ │ │ │ ├── 094805.jpg │ │ │ │ ├── 116032.jpg │ │ │ │ ├── 118017.jpg │ │ │ │ ├── 137590.jpg │ │ │ │ ├── 145842.jpg │ │ │ │ ├── 153793.jpg │ │ │ │ ├── 156498.jpg │ │ │ │ ├── 164930.jpg │ │ │ │ ├── 189498.jpg │ │ │ │ └── 191084.jpg │ │ └── src │ │ │ ├── female │ │ │ ├── 039913.jpg │ │ │ ├── 051340.jpg │ │ │ ├── 069067.jpg │ │ │ ├── 091623.jpg │ │ │ └── 172559.jpg │ │ │ └── male │ │ │ ├── 005735.jpg │ │ │ ├── 006930.jpg │ │ │ ├── 016387.jpg │ │ │ ├── 191300.jpg │ │ │ └── 196930.jpg │ └── custom │ │ ├── female │ │ └── custom_female.jpg │ │ └── male │ │ └── custom_male.jpg ├── teaser.jpg └── youtube_video.jpg ├── core ├── __init__.py ├── checkpoint.py ├── data_loader.py ├── model.py ├── solver.py ├── utils.py └── wing.py ├── download.sh ├── main.py └── metrics ├── __init__.py ├── eval.py ├── fid.py ├── lpips.py └── lpips_weights.ckpt /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020-present, NAVER Corp. 2 | All rights reserved. 3 | 4 | 5 | Attribution-NonCommercial 4.0 International 6 | 7 | ======================================================================= 8 | 9 | Creative Commons Corporation ("Creative Commons") is not a law firm and 10 | does not provide legal services or legal advice. Distribution of 11 | Creative Commons public licenses does not create a lawyer-client or 12 | other relationship. Creative Commons makes its licenses and related 13 | information available on an "as-is" basis. Creative Commons gives no 14 | warranties regarding its licenses, any material licensed under their 15 | terms and conditions, or any related information. Creative Commons 16 | disclaims all liability for damages resulting from their use to the 17 | fullest extent possible. 18 | 19 | Using Creative Commons Public Licenses 20 | 21 | Creative Commons public licenses provide a standard set of terms and 22 | conditions that creators and other rights holders may use to share 23 | original works of authorship and other material subject to copyright 24 | and certain other rights specified in the public license below. The 25 | following considerations are for informational purposes only, are not 26 | exhaustive, and do not form part of our licenses. 27 | 28 | Considerations for licensors: Our public licenses are 29 | intended for use by those authorized to give the public 30 | permission to use material in ways otherwise restricted by 31 | copyright and certain other rights. Our licenses are 32 | irrevocable. 
Licensors should read and understand the terms 33 | and conditions of the license they choose before applying it. 34 | Licensors should also secure all rights necessary before 35 | applying our licenses so that the public can reuse the 36 | material as expected. Licensors should clearly mark any 37 | material not subject to the license. This includes other CC- 38 | licensed material, or material used under an exception or 39 | limitation to copyright. More considerations for licensors: 40 | wiki.creativecommons.org/Considerations_for_licensors 41 | 42 | Considerations for the public: By using one of our public 43 | licenses, a licensor grants the public permission to use the 44 | licensed material under specified terms and conditions. If 45 | the licensor's permission is not necessary for any reason--for 46 | example, because of any applicable exception or limitation to 47 | copyright--then that use is not regulated by the license. Our 48 | licenses grant only permissions under copyright and certain 49 | other rights that a licensor has authority to grant. Use of 50 | the licensed material may still be restricted for other 51 | reasons, including because others have copyright or other 52 | rights in the material. A licensor may make special requests, 53 | such as asking that all changes be marked or described. 54 | Although not required by our licenses, you are encouraged to 55 | respect those requests where reasonable. More_considerations 56 | for the public: 57 | wiki.creativecommons.org/Considerations_for_licensees 58 | 59 | ======================================================================= 60 | 61 | Creative Commons Attribution-NonCommercial 4.0 International Public 62 | License 63 | 64 | By exercising the Licensed Rights (defined below), You accept and agree 65 | to be bound by the terms and conditions of this Creative Commons 66 | Attribution-NonCommercial 4.0 International Public License ("Public 67 | License"). To the extent this Public License may be interpreted as a 68 | contract, You are granted the Licensed Rights in consideration of Your 69 | acceptance of these terms and conditions, and the Licensor grants You 70 | such rights in consideration of benefits the Licensor receives from 71 | making the Licensed Material available under these terms and 72 | conditions. 73 | 74 | 75 | Section 1 -- Definitions. 76 | 77 | a. Adapted Material means material subject to Copyright and Similar 78 | Rights that is derived from or based upon the Licensed Material 79 | and in which the Licensed Material is translated, altered, 80 | arranged, transformed, or otherwise modified in a manner requiring 81 | permission under the Copyright and Similar Rights held by the 82 | Licensor. For purposes of this Public License, where the Licensed 83 | Material is a musical work, performance, or sound recording, 84 | Adapted Material is always produced where the Licensed Material is 85 | synched in timed relation with a moving image. 86 | 87 | b. Adapter's License means the license You apply to Your Copyright 88 | and Similar Rights in Your contributions to Adapted Material in 89 | accordance with the terms and conditions of this Public License. 90 | 91 | c. Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. 
For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | d. Effective Technological Measures means those measures that, in the 99 | absence of proper authority, may not be circumvented under laws 100 | fulfilling obligations under Article 11 of the WIPO Copyright 101 | Treaty adopted on December 20, 1996, and/or similar international 102 | agreements. 103 | 104 | e. Exceptions and Limitations means fair use, fair dealing, and/or 105 | any other exception or limitation to Copyright and Similar Rights 106 | that applies to Your use of the Licensed Material. 107 | 108 | f. Licensed Material means the artistic or literary work, database, 109 | or other material to which the Licensor applied this Public 110 | License. 111 | 112 | g. Licensed Rights means the rights granted to You subject to the 113 | terms and conditions of this Public License, which are limited to 114 | all Copyright and Similar Rights that apply to Your use of the 115 | Licensed Material and that the Licensor has authority to license. 116 | 117 | h. Licensor means the individual(s) or entity(ies) granting rights 118 | under this Public License. 119 | 120 | i. NonCommercial means not primarily intended for or directed towards 121 | commercial advantage or monetary compensation. For purposes of 122 | this Public License, the exchange of the Licensed Material for 123 | other material subject to Copyright and Similar Rights by digital 124 | file-sharing or similar means is NonCommercial provided there is 125 | no payment of monetary compensation in connection with the 126 | exchange. 127 | 128 | j. Share means to provide material to the public by any means or 129 | process that requires permission under the Licensed Rights, such 130 | as reproduction, public display, public performance, distribution, 131 | dissemination, communication, or importation, and to make material 132 | available to the public including in ways that members of the 133 | public may access the material from a place and at a time 134 | individually chosen by them. 135 | 136 | k. Sui Generis Database Rights means rights other than copyright 137 | resulting from Directive 96/9/EC of the European Parliament and of 138 | the Council of 11 March 1996 on the legal protection of databases, 139 | as amended and/or succeeded, as well as other essentially 140 | equivalent rights anywhere in the world. 141 | 142 | l. You means the individual or entity exercising the Licensed Rights 143 | under this Public License. Your has a corresponding meaning. 144 | 145 | 146 | Section 2 -- Scope. 147 | 148 | a. License grant. 149 | 150 | 1. Subject to the terms and conditions of this Public License, 151 | the Licensor hereby grants You a worldwide, royalty-free, 152 | non-sublicensable, non-exclusive, irrevocable license to 153 | exercise the Licensed Rights in the Licensed Material to: 154 | 155 | a. reproduce and Share the Licensed Material, in whole or 156 | in part, for NonCommercial purposes only; and 157 | 158 | b. produce, reproduce, and Share Adapted Material for 159 | NonCommercial purposes only. 160 | 161 | 2. Exceptions and Limitations. For the avoidance of doubt, where 162 | Exceptions and Limitations apply to Your use, this Public 163 | License does not apply, and You do not need to comply with 164 | its terms and conditions. 165 | 166 | 3. Term. The term of this Public License is specified in Section 167 | 6(a). 168 | 169 | 4. Media and formats; technical modifications allowed. 
The 170 | Licensor authorizes You to exercise the Licensed Rights in 171 | all media and formats whether now known or hereafter created, 172 | and to make technical modifications necessary to do so. The 173 | Licensor waives and/or agrees not to assert any right or 174 | authority to forbid You from making technical modifications 175 | necessary to exercise the Licensed Rights, including 176 | technical modifications necessary to circumvent Effective 177 | Technological Measures. For purposes of this Public License, 178 | simply making modifications authorized by this Section 2(a) 179 | (4) never produces Adapted Material. 180 | 181 | 5. Downstream recipients. 182 | 183 | a. Offer from the Licensor -- Licensed Material. Every 184 | recipient of the Licensed Material automatically 185 | receives an offer from the Licensor to exercise the 186 | Licensed Rights under the terms and conditions of this 187 | Public License. 188 | 189 | b. No downstream restrictions. You may not offer or impose 190 | any additional or different terms or conditions on, or 191 | apply any Effective Technological Measures to, the 192 | Licensed Material if doing so restricts exercise of the 193 | Licensed Rights by any recipient of the Licensed 194 | Material. 195 | 196 | 6. No endorsement. Nothing in this Public License constitutes or 197 | may be construed as permission to assert or imply that You 198 | are, or that Your use of the Licensed Material is, connected 199 | with, or sponsored, endorsed, or granted official status by, 200 | the Licensor or others designated to receive attribution as 201 | provided in Section 3(a)(1)(A)(i). 202 | 203 | b. Other rights. 204 | 205 | 1. Moral rights, such as the right of integrity, are not 206 | licensed under this Public License, nor are publicity, 207 | privacy, and/or other similar personality rights; however, to 208 | the extent possible, the Licensor waives and/or agrees not to 209 | assert any such rights held by the Licensor to the limited 210 | extent necessary to allow You to exercise the Licensed 211 | Rights, but not otherwise. 212 | 213 | 2. Patent and trademark rights are not licensed under this 214 | Public License. 215 | 216 | 3. To the extent possible, the Licensor waives any right to 217 | collect royalties from You for the exercise of the Licensed 218 | Rights, whether directly or through a collecting society 219 | under any voluntary or waivable statutory or compulsory 220 | licensing scheme. In all other cases the Licensor expressly 221 | reserves any right to collect such royalties, including when 222 | the Licensed Material is used other than for NonCommercial 223 | purposes. 224 | 225 | 226 | Section 3 -- License Conditions. 227 | 228 | Your exercise of the Licensed Rights is expressly made subject to the 229 | following conditions. 230 | 231 | a. Attribution. 232 | 233 | 1. If You Share the Licensed Material (including in modified 234 | form), You must: 235 | 236 | a. retain the following if it is supplied by the Licensor 237 | with the Licensed Material: 238 | 239 | i. identification of the creator(s) of the Licensed 240 | Material and any others designated to receive 241 | attribution, in any reasonable manner requested by 242 | the Licensor (including by pseudonym if 243 | designated); 244 | 245 | ii. a copyright notice; 246 | 247 | iii. a notice that refers to this Public License; 248 | 249 | iv. a notice that refers to the disclaimer of 250 | warranties; 251 | 252 | v. 
a URI or hyperlink to the Licensed Material to the 253 | extent reasonably practicable; 254 | 255 | b. indicate if You modified the Licensed Material and 256 | retain an indication of any previous modifications; and 257 | 258 | c. indicate the Licensed Material is licensed under this 259 | Public License, and include the text of, or the URI or 260 | hyperlink to, this Public License. 261 | 262 | 2. You may satisfy the conditions in Section 3(a)(1) in any 263 | reasonable manner based on the medium, means, and context in 264 | which You Share the Licensed Material. For example, it may be 265 | reasonable to satisfy the conditions by providing a URI or 266 | hyperlink to a resource that includes the required 267 | information. 268 | 269 | 3. If requested by the Licensor, You must remove any of the 270 | information required by Section 3(a)(1)(A) to the extent 271 | reasonably practicable. 272 | 273 | 4. If You Share Adapted Material You produce, the Adapter's 274 | License You apply must not prevent recipients of the Adapted 275 | Material from complying with this Public License. 276 | 277 | 278 | Section 4 -- Sui Generis Database Rights. 279 | 280 | Where the Licensed Rights include Sui Generis Database Rights that 281 | apply to Your use of the Licensed Material: 282 | 283 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 284 | to extract, reuse, reproduce, and Share all or a substantial 285 | portion of the contents of the database for NonCommercial purposes 286 | only; 287 | 288 | b. if You include all or a substantial portion of the database 289 | contents in a database in which You have Sui Generis Database 290 | Rights, then the database in which You have Sui Generis Database 291 | Rights (but not its individual contents) is Adapted Material; and 292 | 293 | c. You must comply with the conditions in Section 3(a) if You Share 294 | all or a substantial portion of the contents of the database. 295 | 296 | For the avoidance of doubt, this Section 4 supplements and does not 297 | replace Your obligations under this Public License where the Licensed 298 | Rights include other Copyright and Similar Rights. 299 | 300 | 301 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 302 | 303 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 304 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 305 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 306 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 307 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 308 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 309 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 310 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 311 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 312 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 313 | 314 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 315 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 316 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 317 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 318 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 319 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 320 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 321 | DAMAGES. 
WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 322 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 323 | 324 | c. The disclaimer of warranties and limitation of liability provided 325 | above shall be interpreted in a manner that, to the extent 326 | possible, most closely approximates an absolute disclaimer and 327 | waiver of all liability. 328 | 329 | 330 | Section 6 -- Term and Termination. 331 | 332 | a. This Public License applies for the term of the Copyright and 333 | Similar Rights licensed here. However, if You fail to comply with 334 | this Public License, then Your rights under this Public License 335 | terminate automatically. 336 | 337 | b. Where Your right to use the Licensed Material has terminated under 338 | Section 6(a), it reinstates: 339 | 340 | 1. automatically as of the date the violation is cured, provided 341 | it is cured within 30 days of Your discovery of the 342 | violation; or 343 | 344 | 2. upon express reinstatement by the Licensor. 345 | 346 | For the avoidance of doubt, this Section 6(b) does not affect any 347 | right the Licensor may have to seek remedies for Your violations 348 | of this Public License. 349 | 350 | c. For the avoidance of doubt, the Licensor may also offer the 351 | Licensed Material under separate terms or conditions or stop 352 | distributing the Licensed Material at any time; however, doing so 353 | will not terminate this Public License. 354 | 355 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 356 | License. 357 | 358 | 359 | Section 7 -- Other Terms and Conditions. 360 | 361 | a. The Licensor shall not be bound by any additional or different 362 | terms or conditions communicated by You unless expressly agreed. 363 | 364 | b. Any arrangements, understandings, or agreements regarding the 365 | Licensed Material not stated herein are separate from and 366 | independent of the terms and conditions of this Public License. 367 | 368 | 369 | Section 8 -- Interpretation. 370 | 371 | a. For the avoidance of doubt, this Public License does not, and 372 | shall not be interpreted to, reduce, limit, restrict, or impose 373 | conditions on any use of the Licensed Material that could lawfully 374 | be made without permission under this Public License. 375 | 376 | b. To the extent possible, if any provision of this Public License is 377 | deemed unenforceable, it shall be automatically reformed to the 378 | minimum extent necessary to make it enforceable. If the provision 379 | cannot be reformed, it shall be severed from this Public License 380 | without affecting the enforceability of the remaining terms and 381 | conditions. 382 | 383 | c. No term or condition of this Public License will be waived and no 384 | failure to comply consented to unless expressly agreed to by the 385 | Licensor. 386 | 387 | d. Nothing in this Public License constitutes or may be interpreted 388 | as a limitation upon, or waiver of, any privileges and immunities 389 | that apply to the Licensor or You, including from the legal 390 | processes of any jurisdiction or authority. 391 | 392 | ======================================================================= 393 | 394 | Creative Commons is not a party to its public 395 | licenses. Notwithstanding, Creative Commons may elect to apply one of 396 | its public licenses to material it publishes and in those instances 397 | will be considered the "Licensor." The text of the Creative Commons 398 | public licenses is dedicated to the public domain under the CC0 Public 399 | Domain Dedication. 
Except for the limited purpose of indicating that 400 | material is shared under a Creative Commons public license or as 401 | otherwise permitted by the Creative Commons policies published at 402 | creativecommons.org/policies, Creative Commons does not authorize the 403 | use of the trademark "Creative Commons" or any other trademark or logo 404 | of Creative Commons without its prior written consent including, 405 | without limitation, in connection with any unauthorized modifications 406 | to any of its public licenses or any other arrangements, 407 | understandings, or agreements concerning use of licensed material. For 408 | the avoidance of doubt, this paragraph does not form part of the 409 | public licenses. 410 | 411 | Creative Commons may be contacted at creativecommons.org. 412 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | StarGAN v2 2 | 3 | Copyright (c) 2020-present NAVER Corp. 4 | All rights reserved. 5 | 6 | This work is licensed under the Creative Commons Attribution-NonCommercial 7 | 4.0 International License. To view a copy of this license, visit 8 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 9 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 10 | 11 | -------------------------------------------------------------------------------------- 12 | 13 | This project contains subcomponents with separate copyright notices and license terms. 14 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 15 | 16 | ===== 17 | 18 | 1adrianb/face-alignment 19 | https://github.com/1adrianb/face-alignment 20 | 21 | 22 | BSD 3-Clause License 23 | 24 | Copyright (c) 2017, Adrian Bulat 25 | All rights reserved. 26 | 27 | Redistribution and use in source and binary forms, with or without 28 | modification, are permitted provided that the following conditions are met: 29 | 30 | * Redistributions of source code must retain the above copyright notice, this 31 | list of conditions and the following disclaimer. 32 | 33 | * Redistributions in binary form must reproduce the above copyright notice, 34 | this list of conditions and the following disclaimer in the documentation 35 | and/or other materials provided with the distribution. 36 | 37 | * Neither the name of the copyright holder nor the names of its 38 | contributors may be used to endorse or promote products derived from 39 | this software without specific prior written permission. 40 | 41 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 42 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 43 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 44 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 45 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 46 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 47 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 48 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 49 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 50 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
51 | 52 | ===== 53 | 54 | protossw512/AdaptiveWingLoss 55 | https://github.com/protossw512/AdaptiveWingLoss 56 | 57 | 58 | [ICCV 2019] Adaptive Wing Loss for Robust Face Alignment via Heatmap Regression - Official Implementation 59 | 60 | 61 | Licensed under the Apache License, Version 2.0 (the "License"); 62 | you may not use this file except in compliance with the License. 63 | You may obtain a copy of the License at 64 | 65 | http://www.apache.org/licenses/LICENSE-2.0 66 | 67 | Unless required by applicable law or agreed to in writing, software 68 | distributed under the License is distributed on an "AS IS" BASIS, 69 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 70 | See the License for the specific language governing permissions and 71 | limitations under the License. 72 | 73 | --- 74 | 75 | author = {Wang, Xinyao and Bo, Liefeng and Fuxin, Li}, 76 | title = {Adaptive Wing Loss for Robust Face Alignment via Heatmap Regression}, 77 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)}, 78 | month = {October}, 79 | year = {2019} 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## StarGAN v2 - Official PyTorch Implementation 3 | 4 |

5 | 6 | > **StarGAN v2: Diverse Image Synthesis for Multiple Domains**
7 | > [Yunjey Choi](https://github.com/yunjey)\*, [Youngjung Uh](https://github.com/youngjung)\*, [Jaejun Yoo](http://jaejunyoo.blogspot.com/search/label/kr)\*, [Jung-Woo Ha](https://www.facebook.com/jungwoo.ha.921)
8 | > In CVPR 2020. (* indicates equal contribution)
9 | 10 | > Paper: https://arxiv.org/abs/1912.01865
11 | > Video: https://youtu.be/0EVh5Ki4dIY
12 | 13 | > **Abstract:** *A good image-to-image translation model should learn a mapping between different visual domains while satisfying the following properties: 1) diversity of generated images and 2) scalability over multiple domains. Existing methods address either of the issues, having limited diversity or multiple models for all domains. We propose StarGAN v2, a single framework that tackles both and shows significantly improved results over the baselines. Experiments on CelebA-HQ and a new animal faces dataset (AFHQ) validate our superiority in terms of visual quality, diversity, and scalability. To better assess image-to-image translation models, we release AFHQ, high-quality animal faces with large inter- and intra-domain variations. The code, pre-trained models, and dataset are available at clovaai/stargan-v2.* 14 | 15 | ## Teaser video 16 | Click the figure to watch the teaser video.
17 | 18 | [![IMAGE ALT TEXT HERE](assets/youtube_video.jpg)](https://youtu.be/0EVh5Ki4dIY) 19 | 20 | ## TensorFlow implementation 21 | The TensorFlow implementation of StarGAN v2 by our team member junho can be found at [clovaai/stargan-v2-tensorflow](https://github.com/clovaai/stargan-v2-tensorflow). 22 | 23 | ## Software installation 24 | Clone this repository: 25 | 26 | ```bash 27 | git clone https://github.com/clovaai/stargan-v2.git 28 | cd stargan-v2/ 29 | ``` 30 | 31 | Install the dependencies: 32 | ```bash 33 | conda create -n stargan-v2 python=3.6.7 34 | conda activate stargan-v2 35 | conda install -y pytorch=1.4.0 torchvision=0.5.0 cudatoolkit=10.0 -c pytorch 36 | conda install x264=='1!152.20180717' ffmpeg=4.0.2 -c conda-forge 37 | pip install opencv-python==4.1.2.30 ffmpeg-python==0.2.0 scikit-image==0.16.2 38 | pip install pillow==7.0.0 scipy==1.2.1 tqdm==4.43.0 munch==2.5.0 39 | ``` 40 | 41 | ## Datasets and pre-trained networks 42 | We provide a script to download datasets used in StarGAN v2 and the corresponding pre-trained networks. The datasets and network checkpoints will be downloaded and stored in the `data` and `expr/checkpoints` directories, respectively. 43 | 44 | CelebA-HQ. To download the [CelebA-HQ](https://drive.google.com/drive/folders/0B4qLcYyJmiz0TXY1NG02bzZVRGs) dataset and the pre-trained network, run the following commands: 45 | ```bash 46 | bash download.sh celeba-hq-dataset 47 | bash download.sh pretrained-network-celeba-hq 48 | bash download.sh wing 49 | ``` 50 | 51 | AFHQ. To download the [AFHQ](https://github.com/clovaai/stargan-v2/blob/master/README.md#animal-faces-hq-dataset-afhq) dataset and the pre-trained network, run the following commands: 52 | ```bash 53 | bash download.sh afhq-dataset 54 | bash download.sh pretrained-network-afhq 55 | ``` 56 | 57 | 58 | ## Generating interpolation videos 59 | After downloading the pre-trained networks, you can synthesize output images reflecting diverse styles (e.g., hairstyle) of reference images. The following commands will save generated images and interpolation videos to the `expr/results` directory. 60 | 61 | 62 | CelebA-HQ. To generate images and interpolation videos, run the following command: 63 | ```bash 64 | python main.py --mode sample --num_domains 2 --resume_iter 100000 --w_hpf 1 \ 65 | --checkpoint_dir expr/checkpoints/celeba_hq \ 66 | --result_dir expr/results/celeba_hq \ 67 | --src_dir assets/representative/celeba_hq/src \ 68 | --ref_dir assets/representative/celeba_hq/ref 69 | ``` 70 | 71 | To transform a custom image, first crop the image manually so that the proportion of the image occupied by the face is similar to that of CelebA-HQ. Then, run the following command for additional fine rotation and cropping. All custom images in the `inp_dir` directory will be aligned and stored in the `out_dir` directory. 72 | 73 | ```bash 74 | python main.py --mode align \ 75 | --inp_dir assets/representative/custom/female \ 76 | --out_dir assets/representative/celeba_hq/src/female 77 | ``` 78 | 79 | 80 |
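The repository also ships a custom male example under `assets/representative/custom/male` (see the directory tree above). Assuming the same `--mode align` options, a minimal counterpart command for that folder would be:

```bash
# Hypothetical counterpart to the female example above: it aligns the provided
# custom male image and writes the result into the male source folder.
python main.py --mode align \
               --inp_dir assets/representative/custom/male \
               --out_dir assets/representative/celeba_hq/src/male
```

Because the aligned images land in `assets/representative/celeba_hq/src/male`, the CelebA-HQ sampling command above should pick them up as additional source images on the next run.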

81 | 82 | 83 | AFHQ. To generate images and interpolation videos, run the following command: 84 | ```bash 85 | python main.py --mode sample --num_domains 3 --resume_iter 100000 --w_hpf 0 \ 86 | --checkpoint_dir expr/checkpoints/afhq \ 87 | --result_dir expr/results/afhq \ 88 | --src_dir assets/representative/afhq/src \ 89 | --ref_dir assets/representative/afhq/ref 90 | ``` 91 | 92 |

93 | 94 | ## Evaluation metrics 95 | To evaluate StarGAN v2 using [Fréchet Inception Distance (FID)](https://arxiv.org/abs/1706.08500) and [Learned Perceptual Image Patch Similarity (LPIPS)](https://arxiv.org/abs/1801.03924), run the following commands: 96 | 97 | 98 | ```bash 99 | # celeba-hq 100 | python main.py --mode eval --num_domains 2 --w_hpf 1 \ 101 | --resume_iter 100000 \ 102 | --train_img_dir data/celeba_hq/train \ 103 | --val_img_dir data/celeba_hq/val \ 104 | --checkpoint_dir expr/checkpoints/celeba_hq \ 105 | --eval_dir expr/eval/celeba_hq 106 | 107 | # afhq 108 | python main.py --mode eval --num_domains 3 --w_hpf 0 \ 109 | --resume_iter 100000 \ 110 | --train_img_dir data/afhq/train \ 111 | --val_img_dir data/afhq/val \ 112 | --checkpoint_dir expr/checkpoints/afhq \ 113 | --eval_dir expr/eval/afhq 114 | ``` 115 | 116 | Note that the evaluation metrics are calculated using random latent vectors or reference images, both of which are selected by the [seed number](https://github.com/clovaai/stargan-v2/blob/master/main.py#L35). In the paper, we reported the average of values from 10 measurements using different seed numbers. The following table shows the calculated values for both latent-guided and reference-guided synthesis. 117 | 118 | | Dataset | FID (latent) | LPIPS (latent) | FID (reference) | LPIPS (reference) | Elapsed time | 119 | | :---------- | :------------: | :----: | :-----: | :----: | :----------:| 120 | | `celeba-hq` | 13.73 ± 0.06 | 0.4515 ± 0.0006 | 23.84 ± 0.03 | 0.3880 ± 0.0001 | 49min 51s 121 | | `afhq` | 16.18 ± 0.15 | 0.4501 ± 0.0007 | 19.78 ± 0.01 | 0.4315 ± 0.0002 | 64min 49s 122 | 123 | 124 | 125 | ## Training networks 126 | To train StarGAN v2 from scratch, run the following commands. Generated images and network checkpoints will be stored in the `expr/samples` and `expr/checkpoints` directories, respectively. Training takes about three days on a single Tesla V100 GPU. Please see [here](https://github.com/clovaai/stargan-v2/blob/master/main.py#L86-L179) for training arguments and a description of them. 127 | 128 | ```bash 129 | # celeba-hq 130 | python main.py --mode train --num_domains 2 --w_hpf 1 \ 131 | --lambda_reg 1 --lambda_sty 1 --lambda_ds 1 --lambda_cyc 1 \ 132 | --train_img_dir data/celeba_hq/train \ 133 | --val_img_dir data/celeba_hq/val 134 | 135 | # afhq 136 | python main.py --mode train --num_domains 3 --w_hpf 0 \ 137 | --lambda_reg 1 --lambda_sty 1 --lambda_ds 2 --lambda_cyc 1 \ 138 | --train_img_dir data/afhq/train \ 139 | --val_img_dir data/afhq/val 140 | ``` 141 | 142 | ## Animal Faces-HQ dataset (AFHQ) 143 | 144 |

145 | 146 | We release a new dataset of animal faces, Animal Faces-HQ (AFHQ), consisting of 15,000 high-quality images at 512×512 resolution. The figure above shows example images of the AFHQ dataset. The dataset includes three domains of cat, dog, and wildlife, each providing about 5,000 images. By having multiple (three) domains and diverse images of various breeds per domain, AFHQ poses a challenging image-to-image translation problem. For each domain, we select 500 images as a test set and provide all remaining images as a training set. To download the dataset, run the following command: 147 | 148 | ```bash 149 | bash download.sh afhq-dataset 150 | ``` 151 | 152 | 153 | **[Update: 2021.07.01]** We have rebuilt the original AFHQ dataset using high-quality resize filtering (i.e., Lanczos resampling). Please see the [clean FID paper](https://arxiv.org/abs/2104.11222), which brings attention to the unfortunate state of downsampling in common software libraries. We thank the [Alias-Free GAN](https://nvlabs.github.io/alias-free-gan/) authors for their suggestion and contribution to the updated AFHQ dataset. If you use the updated dataset, we recommend citing not only our paper but also theirs. 154 | 155 | The differences from the original dataset are as follows: 156 | * We resize the images using Lanczos resampling instead of nearest-neighbor downsampling. 157 | * About 2% of the original images have been removed, so the set now has 15,803 images, whereas the original had 16,130. 158 | * Images are saved in PNG format to avoid compression artifacts. This makes the files larger than the originals, but it's worth it. 159 | 160 | 161 | To download the updated dataset, run the following command: 162 | 163 | ```bash 164 | bash download.sh afhq-v2-dataset 165 | ``` 166 | 167 |
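As a quick sanity check after the download finishes, you can count the extracted files. The snippet below is a rough sketch that assumes the archive unpacks somewhere under `data/`; the exact folder name (e.g., `data/afhq_v2`) depends on `download.sh`, so adjust the path to match your local copy.

```bash
# Count the extracted AFHQ v2 images; the expected total is 15,803 PNG files.
# Adjust data/afhq_v2 if the archive unpacks to a different folder.
find data/afhq_v2 -type f -name '*.png' | wc -l
```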

168 | 169 | 170 | 171 | ## License 172 | The source code, pre-trained models, and dataset are available under the [Creative Commons BY-NC 4.0](https://github.com/clovaai/stargan-v2/blob/master/LICENSE) license by NAVER Corporation. You can **use, copy, transform and build upon** the material for **non-commercial purposes** as long as you give **appropriate credit** by citing our paper, and indicate if changes were made. 173 | 174 | For business inquiries, please contact clova-jobs@navercorp.com.
175 | For technical and other inquires, please contact yunjey.choi@navercorp.com. 176 | 177 | 178 | ## Citation 179 | If you find this work useful for your research, please cite our paper: 180 | 181 | ``` 182 | @inproceedings{choi2020starganv2, 183 | title={StarGAN v2: Diverse Image Synthesis for Multiple Domains}, 184 | author={Yunjey Choi and Youngjung Uh and Jaejun Yoo and Jung-Woo Ha}, 185 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 186 | year={2020} 187 | } 188 | ``` 189 | 190 | ## Acknowledgements 191 | We would like to thank the full-time and visiting Clova AI Research (now NAVER AI Lab) members for their valuable feedback and an early review: especially Seongjoon Oh, Junsuk Choe, Muhammad Ferjad Naeem, and Kyungjune Baek. We also thank Alias-Free GAN authors for their contribution to the updated AFHQ dataset. 192 | -------------------------------------------------------------------------------- /assets/afhq_dataset.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/afhq_dataset.jpg -------------------------------------------------------------------------------- /assets/afhq_interpolation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/afhq_interpolation.gif -------------------------------------------------------------------------------- /assets/afhqv2_teaser2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/afhqv2_teaser2.jpg -------------------------------------------------------------------------------- /assets/celebahq_interpolation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/celebahq_interpolation.gif -------------------------------------------------------------------------------- /assets/representative/afhq/ref/cat/flickr_cat_000495.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/flickr_cat_000495.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/cat/flickr_cat_000557.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/flickr_cat_000557.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/cat/pixabay_cat_000355.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_000355.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/cat/pixabay_cat_000491.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_000491.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/cat/pixabay_cat_000535.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_000535.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/cat/pixabay_cat_000623.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_000623.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/cat/pixabay_cat_000730.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_000730.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/cat/pixabay_cat_001479.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_001479.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/cat/pixabay_cat_001699.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_001699.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/cat/pixabay_cat_003046.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_003046.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/dog/flickr_dog_001072.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/flickr_dog_001072.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/dog/pixabay_dog_000121.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/pixabay_dog_000121.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/dog/pixabay_dog_000322.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/pixabay_dog_000322.jpg 
-------------------------------------------------------------------------------- /assets/representative/afhq/ref/dog/pixabay_dog_000357.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/pixabay_dog_000357.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/dog/pixabay_dog_000409.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/pixabay_dog_000409.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/dog/pixabay_dog_000799.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/pixabay_dog_000799.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/dog/pixabay_dog_000890.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/pixabay_dog_000890.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/dog/pixabay_dog_001082.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/pixabay_dog_001082.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/wild/flickr_wild_000731.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_000731.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/wild/flickr_wild_001223.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_001223.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/wild/flickr_wild_002020.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_002020.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/wild/flickr_wild_002092.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_002092.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/wild/flickr_wild_002933.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_002933.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/wild/flickr_wild_003137.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_003137.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/wild/flickr_wild_003355.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_003355.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/wild/flickr_wild_003796.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_003796.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/wild/flickr_wild_003969.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_003969.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/ref/wild/pixabay_wild_000637.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/pixabay_wild_000637.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/src/cat/flickr_cat_000253.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/cat/flickr_cat_000253.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/src/cat/pixabay_cat_000181.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/cat/pixabay_cat_000181.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/src/cat/pixabay_cat_000241.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/cat/pixabay_cat_000241.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/src/cat/pixabay_cat_000276.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/cat/pixabay_cat_000276.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/src/cat/pixabay_cat_004826.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/cat/pixabay_cat_004826.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/src/dog/flickr_dog_000094.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/dog/flickr_dog_000094.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/src/dog/pixabay_dog_000321.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/dog/pixabay_dog_000321.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/src/dog/pixabay_dog_000322.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/dog/pixabay_dog_000322.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/src/dog/pixabay_dog_001082.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/dog/pixabay_dog_001082.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/src/dog/pixabay_dog_002066.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/dog/pixabay_dog_002066.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/src/wild/flickr_wild_000432.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/wild/flickr_wild_000432.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/src/wild/flickr_wild_000814.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/wild/flickr_wild_000814.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/src/wild/flickr_wild_002036.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/wild/flickr_wild_002036.jpg 
-------------------------------------------------------------------------------- /assets/representative/afhq/src/wild/flickr_wild_002159.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/wild/flickr_wild_002159.jpg -------------------------------------------------------------------------------- /assets/representative/afhq/src/wild/pixabay_wild_000558.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/wild/pixabay_wild_000558.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/015248.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/015248.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/030321.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/030321.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/031796.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/031796.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/036619.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/036619.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/042373.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/042373.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/048197.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/048197.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/052599.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/052599.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/058150.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/058150.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/058225.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/058225.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/058881.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/058881.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/063109.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/063109.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/064119.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/064119.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/064307.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/064307.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/074075.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/074075.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/074934.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/074934.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/076551.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/076551.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/081680.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/081680.jpg -------------------------------------------------------------------------------- 
/assets/representative/celeba_hq/ref/female/081871.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/081871.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/084913.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/084913.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/086986.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/086986.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/113393.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/113393.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/135626.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/135626.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/140613.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/140613.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/142595.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/142595.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/female/195650.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/195650.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/012712.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/012712.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/020167.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/020167.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/021612.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/021612.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/036367.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/036367.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/037023.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/037023.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/038919.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/038919.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/047763.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/047763.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/060259.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/060259.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/067791.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/067791.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/077921.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/077921.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/083510.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/083510.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/094805.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/094805.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/116032.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/116032.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/118017.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/118017.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/137590.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/137590.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/145842.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/145842.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/153793.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/153793.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/156498.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/156498.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/164930.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/164930.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/189498.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/189498.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/ref/male/191084.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/191084.jpg 
-------------------------------------------------------------------------------- /assets/representative/celeba_hq/src/female/039913.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/female/039913.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/src/female/051340.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/female/051340.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/src/female/069067.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/female/069067.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/src/female/091623.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/female/091623.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/src/female/172559.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/female/172559.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/src/male/005735.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/male/005735.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/src/male/006930.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/male/006930.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/src/male/016387.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/male/016387.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/src/male/191300.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/male/191300.jpg -------------------------------------------------------------------------------- /assets/representative/celeba_hq/src/male/196930.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/male/196930.jpg -------------------------------------------------------------------------------- /assets/representative/custom/female/custom_female.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/custom/female/custom_female.jpg -------------------------------------------------------------------------------- /assets/representative/custom/male/custom_male.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/custom/male/custom_male.jpg -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/teaser.jpg -------------------------------------------------------------------------------- /assets/youtube_video.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/youtube_video.jpg -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/core/__init__.py -------------------------------------------------------------------------------- /core/checkpoint.py: -------------------------------------------------------------------------------- 1 | """ 2 | StarGAN v2 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | import os 12 | import torch 13 | 14 | 15 | class CheckpointIO(object): 16 | def __init__(self, fname_template, data_parallel=False, **kwargs): 17 | os.makedirs(os.path.dirname(fname_template), exist_ok=True) 18 | self.fname_template = fname_template 19 | self.module_dict = kwargs 20 | self.data_parallel = data_parallel 21 | 22 | def register(self, **kwargs): 23 | self.module_dict.update(kwargs) 24 | 25 | def save(self, step): 26 | fname = self.fname_template.format(step) 27 | print('Saving checkpoint into %s...' % fname) 28 | outdict = {} 29 | for name, module in self.module_dict.items(): 30 | if self.data_parallel: 31 | outdict[name] = module.module.state_dict() 32 | else: 33 | outdict[name] = module.state_dict() 34 | 35 | torch.save(outdict, fname) 36 | 37 | def load(self, step): 38 | fname = self.fname_template.format(step) 39 | assert os.path.exists(fname), fname + ' does not exist!' 40 | print('Loading checkpoint from %s...' 
% fname) 41 | if torch.cuda.is_available(): 42 | module_dict = torch.load(fname) 43 | else: 44 | module_dict = torch.load(fname, map_location=torch.device('cpu')) 45 | 46 | for name, module in self.module_dict.items(): 47 | if self.data_parallel: 48 | module.module.load_state_dict(module_dict[name]) 49 | else: 50 | module.load_state_dict(module_dict[name]) 51 | -------------------------------------------------------------------------------- /core/data_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | StarGAN v2 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | from pathlib import Path 12 | from itertools import chain 13 | import os 14 | import random 15 | 16 | from munch import Munch 17 | from PIL import Image 18 | import numpy as np 19 | 20 | import torch 21 | from torch.utils import data 22 | from torch.utils.data.sampler import WeightedRandomSampler 23 | from torchvision import transforms 24 | from torchvision.datasets import ImageFolder 25 | 26 | 27 | def listdir(dname): 28 | fnames = list(chain(*[list(Path(dname).rglob('*.' + ext)) 29 | for ext in ['png', 'jpg', 'jpeg', 'JPG']])) 30 | return fnames 31 | 32 | 33 | class DefaultDataset(data.Dataset): 34 | def __init__(self, root, transform=None): 35 | self.samples = listdir(root) 36 | self.samples.sort() 37 | self.transform = transform 38 | self.targets = None 39 | 40 | def __getitem__(self, index): 41 | fname = self.samples[index] 42 | img = Image.open(fname).convert('RGB') 43 | if self.transform is not None: 44 | img = self.transform(img) 45 | return img 46 | 47 | def __len__(self): 48 | return len(self.samples) 49 | 50 | 51 | class ReferenceDataset(data.Dataset): 52 | def __init__(self, root, transform=None): 53 | self.samples, self.targets = self._make_dataset(root) 54 | self.transform = transform 55 | 56 | def _make_dataset(self, root): 57 | domains = os.listdir(root) 58 | fnames, fnames2, labels = [], [], [] 59 | for idx, domain in enumerate(sorted(domains)): 60 | class_dir = os.path.join(root, domain) 61 | cls_fnames = listdir(class_dir) 62 | fnames += cls_fnames 63 | fnames2 += random.sample(cls_fnames, len(cls_fnames)) 64 | labels += [idx] * len(cls_fnames) 65 | return list(zip(fnames, fnames2)), labels 66 | 67 | def __getitem__(self, index): 68 | fname, fname2 = self.samples[index] 69 | label = self.targets[index] 70 | img = Image.open(fname).convert('RGB') 71 | img2 = Image.open(fname2).convert('RGB') 72 | if self.transform is not None: 73 | img = self.transform(img) 74 | img2 = self.transform(img2) 75 | return img, img2, label 76 | 77 | def __len__(self): 78 | return len(self.targets) 79 | 80 | 81 | def _make_balanced_sampler(labels): 82 | class_counts = np.bincount(labels) 83 | class_weights = 1. / class_counts 84 | weights = class_weights[labels] 85 | return WeightedRandomSampler(weights, len(weights)) 86 | 87 | 88 | def get_train_loader(root, which='source', img_size=256, 89 | batch_size=8, prob=0.5, num_workers=4): 90 | print('Preparing DataLoader to fetch %s images ' 91 | 'during the training phase...' 
% which) 92 | 93 | crop = transforms.RandomResizedCrop( 94 | img_size, scale=[0.8, 1.0], ratio=[0.9, 1.1]) 95 | rand_crop = transforms.Lambda( 96 | lambda x: crop(x) if random.random() < prob else x) 97 | 98 | transform = transforms.Compose([ 99 | rand_crop, 100 | transforms.Resize([img_size, img_size]), 101 | transforms.RandomHorizontalFlip(), 102 | transforms.ToTensor(), 103 | transforms.Normalize(mean=[0.5, 0.5, 0.5], 104 | std=[0.5, 0.5, 0.5]), 105 | ]) 106 | 107 | if which == 'source': 108 | dataset = ImageFolder(root, transform) 109 | elif which == 'reference': 110 | dataset = ReferenceDataset(root, transform) 111 | else: 112 | raise NotImplementedError 113 | 114 | sampler = _make_balanced_sampler(dataset.targets) 115 | return data.DataLoader(dataset=dataset, 116 | batch_size=batch_size, 117 | sampler=sampler, 118 | num_workers=num_workers, 119 | pin_memory=True, 120 | drop_last=True) 121 | 122 | 123 | def get_eval_loader(root, img_size=256, batch_size=32, 124 | imagenet_normalize=True, shuffle=True, 125 | num_workers=4, drop_last=False): 126 | print('Preparing DataLoader for the evaluation phase...') 127 | if imagenet_normalize: 128 | height, width = 299, 299 129 | mean = [0.485, 0.456, 0.406] 130 | std = [0.229, 0.224, 0.225] 131 | else: 132 | height, width = img_size, img_size 133 | mean = [0.5, 0.5, 0.5] 134 | std = [0.5, 0.5, 0.5] 135 | 136 | transform = transforms.Compose([ 137 | transforms.Resize([img_size, img_size]), 138 | transforms.Resize([height, width]), 139 | transforms.ToTensor(), 140 | transforms.Normalize(mean=mean, std=std) 141 | ]) 142 | 143 | dataset = DefaultDataset(root, transform=transform) 144 | return data.DataLoader(dataset=dataset, 145 | batch_size=batch_size, 146 | shuffle=shuffle, 147 | num_workers=num_workers, 148 | pin_memory=True, 149 | drop_last=drop_last) 150 | 151 | 152 | def get_test_loader(root, img_size=256, batch_size=32, 153 | shuffle=True, num_workers=4): 154 | print('Preparing DataLoader for the generation phase...') 155 | transform = transforms.Compose([ 156 | transforms.Resize([img_size, img_size]), 157 | transforms.ToTensor(), 158 | transforms.Normalize(mean=[0.5, 0.5, 0.5], 159 | std=[0.5, 0.5, 0.5]), 160 | ]) 161 | 162 | dataset = ImageFolder(root, transform) 163 | return data.DataLoader(dataset=dataset, 164 | batch_size=batch_size, 165 | shuffle=shuffle, 166 | num_workers=num_workers, 167 | pin_memory=True) 168 | 169 | 170 | class InputFetcher: 171 | def __init__(self, loader, loader_ref=None, latent_dim=16, mode=''): 172 | self.loader = loader 173 | self.loader_ref = loader_ref 174 | self.latent_dim = latent_dim 175 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 176 | self.mode = mode 177 | 178 | def _fetch_inputs(self): 179 | try: 180 | x, y = next(self.iter) 181 | except (AttributeError, StopIteration): 182 | self.iter = iter(self.loader) 183 | x, y = next(self.iter) 184 | return x, y 185 | 186 | def _fetch_refs(self): 187 | try: 188 | x, x2, y = next(self.iter_ref) 189 | except (AttributeError, StopIteration): 190 | self.iter_ref = iter(self.loader_ref) 191 | x, x2, y = next(self.iter_ref) 192 | return x, x2, y 193 | 194 | def __next__(self): 195 | x, y = self._fetch_inputs() 196 | if self.mode == 'train': 197 | x_ref, x_ref2, y_ref = self._fetch_refs() 198 | z_trg = torch.randn(x.size(0), self.latent_dim) 199 | z_trg2 = torch.randn(x.size(0), self.latent_dim) 200 | inputs = Munch(x_src=x, y_src=y, y_ref=y_ref, 201 | x_ref=x_ref, x_ref2=x_ref2, 202 | z_trg=z_trg, z_trg2=z_trg2) 203 | elif self.mode == 
'val': 204 | x_ref, y_ref = self._fetch_inputs() 205 | inputs = Munch(x_src=x, y_src=y, 206 | x_ref=x_ref, y_ref=y_ref) 207 | elif self.mode == 'test': 208 | inputs = Munch(x=x, y=y) 209 | else: 210 | raise NotImplementedError 211 | 212 | return Munch({k: v.to(self.device) 213 | for k, v in inputs.items()}) -------------------------------------------------------------------------------- /core/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | StarGAN v2 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | import copy 12 | import math 13 | 14 | from munch import Munch 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | from core.wing import FAN 21 | 22 | 23 | class ResBlk(nn.Module): 24 | def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), 25 | normalize=False, downsample=False): 26 | super().__init__() 27 | self.actv = actv 28 | self.normalize = normalize 29 | self.downsample = downsample 30 | self.learned_sc = dim_in != dim_out 31 | self._build_weights(dim_in, dim_out) 32 | 33 | def _build_weights(self, dim_in, dim_out): 34 | self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) 35 | self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) 36 | if self.normalize: 37 | self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) 38 | self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) 39 | if self.learned_sc: 40 | self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) 41 | 42 | def _shortcut(self, x): 43 | if self.learned_sc: 44 | x = self.conv1x1(x) 45 | if self.downsample: 46 | x = F.avg_pool2d(x, 2) 47 | return x 48 | 49 | def _residual(self, x): 50 | if self.normalize: 51 | x = self.norm1(x) 52 | x = self.actv(x) 53 | x = self.conv1(x) 54 | if self.downsample: 55 | x = F.avg_pool2d(x, 2) 56 | if self.normalize: 57 | x = self.norm2(x) 58 | x = self.actv(x) 59 | x = self.conv2(x) 60 | return x 61 | 62 | def forward(self, x): 63 | x = self._shortcut(x) + self._residual(x) 64 | return x / math.sqrt(2) # unit variance 65 | 66 | 67 | class AdaIN(nn.Module): 68 | def __init__(self, style_dim, num_features): 69 | super().__init__() 70 | self.norm = nn.InstanceNorm2d(num_features, affine=False) 71 | self.fc = nn.Linear(style_dim, num_features*2) 72 | 73 | def forward(self, x, s): 74 | h = self.fc(s) 75 | h = h.view(h.size(0), h.size(1), 1, 1) 76 | gamma, beta = torch.chunk(h, chunks=2, dim=1) 77 | return (1 + gamma) * self.norm(x) + beta 78 | 79 | 80 | class AdainResBlk(nn.Module): 81 | def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=0, 82 | actv=nn.LeakyReLU(0.2), upsample=False): 83 | super().__init__() 84 | self.w_hpf = w_hpf 85 | self.actv = actv 86 | self.upsample = upsample 87 | self.learned_sc = dim_in != dim_out 88 | self._build_weights(dim_in, dim_out, style_dim) 89 | 90 | def _build_weights(self, dim_in, dim_out, style_dim=64): 91 | self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) 92 | self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) 93 | self.norm1 = AdaIN(style_dim, dim_in) 94 | self.norm2 = AdaIN(style_dim, dim_out) 95 | if self.learned_sc: 96 | self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) 97 | 98 | def _shortcut(self, x): 99 | if self.upsample: 100 | x = 
F.interpolate(x, scale_factor=2, mode='nearest') 101 | if self.learned_sc: 102 | x = self.conv1x1(x) 103 | return x 104 | 105 | def _residual(self, x, s): 106 | x = self.norm1(x, s) 107 | x = self.actv(x) 108 | if self.upsample: 109 | x = F.interpolate(x, scale_factor=2, mode='nearest') 110 | x = self.conv1(x) 111 | x = self.norm2(x, s) 112 | x = self.actv(x) 113 | x = self.conv2(x) 114 | return x 115 | 116 | def forward(self, x, s): 117 | out = self._residual(x, s) 118 | if self.w_hpf == 0: 119 | out = (out + self._shortcut(x)) / math.sqrt(2) 120 | return out 121 | 122 | 123 | class HighPass(nn.Module): 124 | def __init__(self, w_hpf, device): 125 | super(HighPass, self).__init__() 126 | self.register_buffer('filter', 127 | torch.tensor([[-1, -1, -1], 128 | [-1, 8., -1], 129 | [-1, -1, -1]]) / w_hpf) 130 | 131 | def forward(self, x): 132 | filter = self.filter.unsqueeze(0).unsqueeze(1).repeat(x.size(1), 1, 1, 1) 133 | return F.conv2d(x, filter, padding=1, groups=x.size(1)) 134 | 135 | 136 | class Generator(nn.Module): 137 | def __init__(self, img_size=256, style_dim=64, max_conv_dim=512, w_hpf=1): 138 | super().__init__() 139 | dim_in = 2**14 // img_size 140 | self.img_size = img_size 141 | self.from_rgb = nn.Conv2d(3, dim_in, 3, 1, 1) 142 | self.encode = nn.ModuleList() 143 | self.decode = nn.ModuleList() 144 | self.to_rgb = nn.Sequential( 145 | nn.InstanceNorm2d(dim_in, affine=True), 146 | nn.LeakyReLU(0.2), 147 | nn.Conv2d(dim_in, 3, 1, 1, 0)) 148 | 149 | # down/up-sampling blocks 150 | repeat_num = int(np.log2(img_size)) - 4 151 | if w_hpf > 0: 152 | repeat_num += 1 153 | for _ in range(repeat_num): 154 | dim_out = min(dim_in*2, max_conv_dim) 155 | self.encode.append( 156 | ResBlk(dim_in, dim_out, normalize=True, downsample=True)) 157 | self.decode.insert( 158 | 0, AdainResBlk(dim_out, dim_in, style_dim, 159 | w_hpf=w_hpf, upsample=True)) # stack-like 160 | dim_in = dim_out 161 | 162 | # bottleneck blocks 163 | for _ in range(2): 164 | self.encode.append( 165 | ResBlk(dim_out, dim_out, normalize=True)) 166 | self.decode.insert( 167 | 0, AdainResBlk(dim_out, dim_out, style_dim, w_hpf=w_hpf)) 168 | 169 | if w_hpf > 0: 170 | device = torch.device( 171 | 'cuda' if torch.cuda.is_available() else 'cpu') 172 | self.hpf = HighPass(w_hpf, device) 173 | 174 | def forward(self, x, s, masks=None): 175 | x = self.from_rgb(x) 176 | cache = {} 177 | for block in self.encode: 178 | if (masks is not None) and (x.size(2) in [32, 64, 128]): 179 | cache[x.size(2)] = x 180 | x = block(x) 181 | for block in self.decode: 182 | x = block(x, s) 183 | if (masks is not None) and (x.size(2) in [32, 64, 128]): 184 | mask = masks[0] if x.size(2) in [32] else masks[1] 185 | mask = F.interpolate(mask, size=x.size(2), mode='bilinear') 186 | x = x + self.hpf(mask * cache[x.size(2)]) 187 | return self.to_rgb(x) 188 | 189 | 190 | class MappingNetwork(nn.Module): 191 | def __init__(self, latent_dim=16, style_dim=64, num_domains=2): 192 | super().__init__() 193 | layers = [] 194 | layers += [nn.Linear(latent_dim, 512)] 195 | layers += [nn.ReLU()] 196 | for _ in range(3): 197 | layers += [nn.Linear(512, 512)] 198 | layers += [nn.ReLU()] 199 | self.shared = nn.Sequential(*layers) 200 | 201 | self.unshared = nn.ModuleList() 202 | for _ in range(num_domains): 203 | self.unshared += [nn.Sequential(nn.Linear(512, 512), 204 | nn.ReLU(), 205 | nn.Linear(512, 512), 206 | nn.ReLU(), 207 | nn.Linear(512, 512), 208 | nn.ReLU(), 209 | nn.Linear(512, style_dim))] 210 | 211 | def forward(self, z, y): 212 | h = self.shared(z) 213 | out = 
[] 214 | for layer in self.unshared: 215 | out += [layer(h)] 216 | out = torch.stack(out, dim=1) # (batch, num_domains, style_dim) 217 | idx = torch.LongTensor(range(y.size(0))).to(y.device) 218 | s = out[idx, y] # (batch, style_dim) 219 | return s 220 | 221 | 222 | class StyleEncoder(nn.Module): 223 | def __init__(self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512): 224 | super().__init__() 225 | dim_in = 2**14 // img_size 226 | blocks = [] 227 | blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)] 228 | 229 | repeat_num = int(np.log2(img_size)) - 2 230 | for _ in range(repeat_num): 231 | dim_out = min(dim_in*2, max_conv_dim) 232 | blocks += [ResBlk(dim_in, dim_out, downsample=True)] 233 | dim_in = dim_out 234 | 235 | blocks += [nn.LeakyReLU(0.2)] 236 | blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)] 237 | blocks += [nn.LeakyReLU(0.2)] 238 | self.shared = nn.Sequential(*blocks) 239 | 240 | self.unshared = nn.ModuleList() 241 | for _ in range(num_domains): 242 | self.unshared += [nn.Linear(dim_out, style_dim)] 243 | 244 | def forward(self, x, y): 245 | h = self.shared(x) 246 | h = h.view(h.size(0), -1) 247 | out = [] 248 | for layer in self.unshared: 249 | out += [layer(h)] 250 | out = torch.stack(out, dim=1) # (batch, num_domains, style_dim) 251 | idx = torch.LongTensor(range(y.size(0))).to(y.device) 252 | s = out[idx, y] # (batch, style_dim) 253 | return s 254 | 255 | 256 | class Discriminator(nn.Module): 257 | def __init__(self, img_size=256, num_domains=2, max_conv_dim=512): 258 | super().__init__() 259 | dim_in = 2**14 // img_size 260 | blocks = [] 261 | blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)] 262 | 263 | repeat_num = int(np.log2(img_size)) - 2 264 | for _ in range(repeat_num): 265 | dim_out = min(dim_in*2, max_conv_dim) 266 | blocks += [ResBlk(dim_in, dim_out, downsample=True)] 267 | dim_in = dim_out 268 | 269 | blocks += [nn.LeakyReLU(0.2)] 270 | blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)] 271 | blocks += [nn.LeakyReLU(0.2)] 272 | blocks += [nn.Conv2d(dim_out, num_domains, 1, 1, 0)] 273 | self.main = nn.Sequential(*blocks) 274 | 275 | def forward(self, x, y): 276 | out = self.main(x) 277 | out = out.view(out.size(0), -1) # (batch, num_domains) 278 | idx = torch.LongTensor(range(y.size(0))).to(y.device) 279 | out = out[idx, y] # (batch) 280 | return out 281 | 282 | 283 | def build_model(args): 284 | generator = nn.DataParallel(Generator(args.img_size, args.style_dim, w_hpf=args.w_hpf)) 285 | mapping_network = nn.DataParallel(MappingNetwork(args.latent_dim, args.style_dim, args.num_domains)) 286 | style_encoder = nn.DataParallel(StyleEncoder(args.img_size, args.style_dim, args.num_domains)) 287 | discriminator = nn.DataParallel(Discriminator(args.img_size, args.num_domains)) 288 | generator_ema = copy.deepcopy(generator) 289 | mapping_network_ema = copy.deepcopy(mapping_network) 290 | style_encoder_ema = copy.deepcopy(style_encoder) 291 | 292 | nets = Munch(generator=generator, 293 | mapping_network=mapping_network, 294 | style_encoder=style_encoder, 295 | discriminator=discriminator) 296 | nets_ema = Munch(generator=generator_ema, 297 | mapping_network=mapping_network_ema, 298 | style_encoder=style_encoder_ema) 299 | 300 | if args.w_hpf > 0: 301 | fan = nn.DataParallel(FAN(fname_pretrained=args.wing_path).eval()) 302 | fan.get_heatmap = fan.module.get_heatmap 303 | nets.fan = fan 304 | nets_ema.fan = fan 305 | 306 | return nets, nets_ema 307 | -------------------------------------------------------------------------------- /core/solver.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | StarGAN v2 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | import os 12 | from os.path import join as ospj 13 | import time 14 | import datetime 15 | from munch import Munch 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | from core.model import build_model 22 | from core.checkpoint import CheckpointIO 23 | from core.data_loader import InputFetcher 24 | import core.utils as utils 25 | from metrics.eval import calculate_metrics 26 | 27 | 28 | class Solver(nn.Module): 29 | def __init__(self, args): 30 | super().__init__() 31 | self.args = args 32 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 33 | 34 | self.nets, self.nets_ema = build_model(args) 35 | # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device) 36 | for name, module in self.nets.items(): 37 | utils.print_network(module, name) 38 | setattr(self, name, module) 39 | for name, module in self.nets_ema.items(): 40 | setattr(self, name + '_ema', module) 41 | 42 | if args.mode == 'train': 43 | self.optims = Munch() 44 | for net in self.nets.keys(): 45 | if net == 'fan': 46 | continue 47 | self.optims[net] = torch.optim.Adam( 48 | params=self.nets[net].parameters(), 49 | lr=args.f_lr if net == 'mapping_network' else args.lr, 50 | betas=[args.beta1, args.beta2], 51 | weight_decay=args.weight_decay) 52 | 53 | self.ckptios = [ 54 | CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets.ckpt'), data_parallel=True, **self.nets), 55 | CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), data_parallel=True, **self.nets_ema), 56 | CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_optims.ckpt'), **self.optims)] 57 | else: 58 | self.ckptios = [CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), data_parallel=True, **self.nets_ema)] 59 | 60 | self.to(self.device) 61 | for name, network in self.named_children(): 62 | # Do not initialize the FAN parameters 63 | if ('ema' not in name) and ('fan' not in name): 64 | print('Initializing %s...' 
% name) 65 | network.apply(utils.he_init) 66 | 67 | def _save_checkpoint(self, step): 68 | for ckptio in self.ckptios: 69 | ckptio.save(step) 70 | 71 | def _load_checkpoint(self, step): 72 | for ckptio in self.ckptios: 73 | ckptio.load(step) 74 | 75 | def _reset_grad(self): 76 | for optim in self.optims.values(): 77 | optim.zero_grad() 78 | 79 | def train(self, loaders): 80 | args = self.args 81 | nets = self.nets 82 | nets_ema = self.nets_ema 83 | optims = self.optims 84 | 85 | # fetch random validation images for debugging 86 | fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim, 'train') 87 | fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val') 88 | inputs_val = next(fetcher_val) 89 | 90 | # resume training if necessary 91 | if args.resume_iter > 0: 92 | self._load_checkpoint(args.resume_iter) 93 | 94 | # remember the initial value of ds weight 95 | initial_lambda_ds = args.lambda_ds 96 | 97 | print('Start training...') 98 | start_time = time.time() 99 | for i in range(args.resume_iter, args.total_iters): 100 | # fetch images and labels 101 | inputs = next(fetcher) 102 | x_real, y_org = inputs.x_src, inputs.y_src 103 | x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref 104 | z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2 105 | 106 | masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None 107 | 108 | # train the discriminator 109 | d_loss, d_losses_latent = compute_d_loss( 110 | nets, args, x_real, y_org, y_trg, z_trg=z_trg, masks=masks) 111 | self._reset_grad() 112 | d_loss.backward() 113 | optims.discriminator.step() 114 | 115 | d_loss, d_losses_ref = compute_d_loss( 116 | nets, args, x_real, y_org, y_trg, x_ref=x_ref, masks=masks) 117 | self._reset_grad() 118 | d_loss.backward() 119 | optims.discriminator.step() 120 | 121 | # train the generator 122 | g_loss, g_losses_latent = compute_g_loss( 123 | nets, args, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], masks=masks) 124 | self._reset_grad() 125 | g_loss.backward() 126 | optims.generator.step() 127 | optims.mapping_network.step() 128 | optims.style_encoder.step() 129 | 130 | g_loss, g_losses_ref = compute_g_loss( 131 | nets, args, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks) 132 | self._reset_grad() 133 | g_loss.backward() 134 | optims.generator.step() 135 | 136 | # compute moving average of network parameters 137 | moving_average(nets.generator, nets_ema.generator, beta=0.999) 138 | moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999) 139 | moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999) 140 | 141 | # decay weight for diversity sensitive loss 142 | if args.lambda_ds > 0: 143 | args.lambda_ds -= (initial_lambda_ds / args.ds_iter) 144 | 145 | # print out log info 146 | if (i+1) % args.print_every == 0: 147 | elapsed = time.time() - start_time 148 | elapsed = str(datetime.timedelta(seconds=elapsed))[:-7] 149 | log = "Elapsed time [%s], Iteration [%i/%i], " % (elapsed, i+1, args.total_iters) 150 | all_losses = dict() 151 | for loss, prefix in zip([d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref], 152 | ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']): 153 | for key, value in loss.items(): 154 | all_losses[prefix + key] = value 155 | all_losses['G/lambda_ds'] = args.lambda_ds 156 | log += ' '.join(['%s: [%.4f]' % (key, value) for key, value in all_losses.items()]) 157 | print(log) 158 | 159 | # generate images for debugging 160 | if (i+1) % args.sample_every == 0: 161 | os.makedirs(args.sample_dir, exist_ok=True) 162 | 
utils.debug_image(nets_ema, args, inputs=inputs_val, step=i+1) 163 | 164 | # save model checkpoints 165 | if (i+1) % args.save_every == 0: 166 | self._save_checkpoint(step=i+1) 167 | 168 | # compute FID and LPIPS if necessary 169 | if (i+1) % args.eval_every == 0: 170 | calculate_metrics(nets_ema, args, i+1, mode='latent') 171 | calculate_metrics(nets_ema, args, i+1, mode='reference') 172 | 173 | @torch.no_grad() 174 | def sample(self, loaders): 175 | args = self.args 176 | nets_ema = self.nets_ema 177 | os.makedirs(args.result_dir, exist_ok=True) 178 | self._load_checkpoint(args.resume_iter) 179 | 180 | src = next(InputFetcher(loaders.src, None, args.latent_dim, 'test')) 181 | ref = next(InputFetcher(loaders.ref, None, args.latent_dim, 'test')) 182 | 183 | fname = ospj(args.result_dir, 'reference.jpg') 184 | print('Working on {}...'.format(fname)) 185 | utils.translate_using_reference(nets_ema, args, src.x, ref.x, ref.y, fname) 186 | 187 | fname = ospj(args.result_dir, 'video_ref.mp4') 188 | print('Working on {}...'.format(fname)) 189 | utils.video_ref(nets_ema, args, src.x, ref.x, ref.y, fname) 190 | 191 | @torch.no_grad() 192 | def evaluate(self): 193 | args = self.args 194 | nets_ema = self.nets_ema 195 | resume_iter = args.resume_iter 196 | self._load_checkpoint(args.resume_iter) 197 | calculate_metrics(nets_ema, args, step=resume_iter, mode='latent') 198 | calculate_metrics(nets_ema, args, step=resume_iter, mode='reference') 199 | 200 | 201 | def compute_d_loss(nets, args, x_real, y_org, y_trg, z_trg=None, x_ref=None, masks=None): 202 | assert (z_trg is None) != (x_ref is None) 203 | # with real images 204 | x_real.requires_grad_() 205 | out = nets.discriminator(x_real, y_org) 206 | loss_real = adv_loss(out, 1) 207 | loss_reg = r1_reg(out, x_real) 208 | 209 | # with fake images 210 | with torch.no_grad(): 211 | if z_trg is not None: 212 | s_trg = nets.mapping_network(z_trg, y_trg) 213 | else: # x_ref is not None 214 | s_trg = nets.style_encoder(x_ref, y_trg) 215 | 216 | x_fake = nets.generator(x_real, s_trg, masks=masks) 217 | out = nets.discriminator(x_fake, y_trg) 218 | loss_fake = adv_loss(out, 0) 219 | 220 | loss = loss_real + loss_fake + args.lambda_reg * loss_reg 221 | return loss, Munch(real=loss_real.item(), 222 | fake=loss_fake.item(), 223 | reg=loss_reg.item()) 224 | 225 | 226 | def compute_g_loss(nets, args, x_real, y_org, y_trg, z_trgs=None, x_refs=None, masks=None): 227 | assert (z_trgs is None) != (x_refs is None) 228 | if z_trgs is not None: 229 | z_trg, z_trg2 = z_trgs 230 | if x_refs is not None: 231 | x_ref, x_ref2 = x_refs 232 | 233 | # adversarial loss 234 | if z_trgs is not None: 235 | s_trg = nets.mapping_network(z_trg, y_trg) 236 | else: 237 | s_trg = nets.style_encoder(x_ref, y_trg) 238 | 239 | x_fake = nets.generator(x_real, s_trg, masks=masks) 240 | out = nets.discriminator(x_fake, y_trg) 241 | loss_adv = adv_loss(out, 1) 242 | 243 | # style reconstruction loss 244 | s_pred = nets.style_encoder(x_fake, y_trg) 245 | loss_sty = torch.mean(torch.abs(s_pred - s_trg)) 246 | 247 | # diversity sensitive loss 248 | if z_trgs is not None: 249 | s_trg2 = nets.mapping_network(z_trg2, y_trg) 250 | else: 251 | s_trg2 = nets.style_encoder(x_ref2, y_trg) 252 | x_fake2 = nets.generator(x_real, s_trg2, masks=masks) 253 | x_fake2 = x_fake2.detach() 254 | loss_ds = torch.mean(torch.abs(x_fake - x_fake2)) 255 | 256 | # cycle-consistency loss 257 | masks = nets.fan.get_heatmap(x_fake) if args.w_hpf > 0 else None 258 | s_org = nets.style_encoder(x_real, y_org) 259 | x_rec = 
nets.generator(x_fake, s_org, masks=masks) 260 | loss_cyc = torch.mean(torch.abs(x_rec - x_real)) 261 | 262 | loss = loss_adv + args.lambda_sty * loss_sty \ 263 | - args.lambda_ds * loss_ds + args.lambda_cyc * loss_cyc 264 | return loss, Munch(adv=loss_adv.item(), 265 | sty=loss_sty.item(), 266 | ds=loss_ds.item(), 267 | cyc=loss_cyc.item()) 268 | 269 | 270 | def moving_average(model, model_test, beta=0.999): 271 | for param, param_test in zip(model.parameters(), model_test.parameters()): 272 | param_test.data = torch.lerp(param.data, param_test.data, beta) 273 | 274 | 275 | def adv_loss(logits, target): 276 | assert target in [1, 0] 277 | targets = torch.full_like(logits, fill_value=target) 278 | loss = F.binary_cross_entropy_with_logits(logits, targets) 279 | return loss 280 | 281 | 282 | def r1_reg(d_out, x_in): 283 | # zero-centered gradient penalty for real images 284 | batch_size = x_in.size(0) 285 | grad_dout = torch.autograd.grad( 286 | outputs=d_out.sum(), inputs=x_in, 287 | create_graph=True, retain_graph=True, only_inputs=True 288 | )[0] 289 | grad_dout2 = grad_dout.pow(2) 290 | assert(grad_dout2.size() == x_in.size()) 291 | reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0) 292 | return reg -------------------------------------------------------------------------------- /core/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | StarGAN v2 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | import os 12 | from os.path import join as ospj 13 | import json 14 | import glob 15 | from shutil import copyfile 16 | 17 | from tqdm import tqdm 18 | import ffmpeg 19 | 20 | import numpy as np 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | import torchvision 25 | import torchvision.utils as vutils 26 | 27 | 28 | def save_json(json_file, filename): 29 | with open(filename, 'w') as f: 30 | json.dump(json_file, f, indent=4, sort_keys=False) 31 | 32 | 33 | def print_network(network, name): 34 | num_params = 0 35 | for p in network.parameters(): 36 | num_params += p.numel() 37 | # print(network) 38 | print("Number of parameters of %s: %i" % (name, num_params)) 39 | 40 | 41 | def he_init(module): 42 | if isinstance(module, nn.Conv2d): 43 | nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu') 44 | if module.bias is not None: 45 | nn.init.constant_(module.bias, 0) 46 | if isinstance(module, nn.Linear): 47 | nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu') 48 | if module.bias is not None: 49 | nn.init.constant_(module.bias, 0) 50 | 51 | 52 | def denormalize(x): 53 | out = (x + 1) / 2 54 | return out.clamp_(0, 1) 55 | 56 | 57 | def save_image(x, ncol, filename): 58 | x = denormalize(x) 59 | vutils.save_image(x.cpu(), filename, nrow=ncol, padding=0) 60 | 61 | 62 | @torch.no_grad() 63 | def translate_and_reconstruct(nets, args, x_src, y_src, x_ref, y_ref, filename): 64 | N, C, H, W = x_src.size() 65 | s_ref = nets.style_encoder(x_ref, y_ref) 66 | masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None 67 | x_fake = nets.generator(x_src, s_ref, masks=masks) 68 | s_src = nets.style_encoder(x_src, y_src) 69 | masks = nets.fan.get_heatmap(x_fake) if args.w_hpf > 0 
else None 70 | x_rec = nets.generator(x_fake, s_src, masks=masks) 71 | x_concat = [x_src, x_ref, x_fake, x_rec] 72 | x_concat = torch.cat(x_concat, dim=0) 73 | save_image(x_concat, N, filename) 74 | del x_concat 75 | 76 | 77 | @torch.no_grad() 78 | def translate_using_latent(nets, args, x_src, y_trg_list, z_trg_list, psi, filename): 79 | N, C, H, W = x_src.size() 80 | latent_dim = z_trg_list[0].size(1) 81 | x_concat = [x_src] 82 | masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None 83 | 84 | for i, y_trg in enumerate(y_trg_list): 85 | z_many = torch.randn(10000, latent_dim).to(x_src.device) 86 | y_many = torch.LongTensor(10000).to(x_src.device).fill_(y_trg[0]) 87 | s_many = nets.mapping_network(z_many, y_many) 88 | s_avg = torch.mean(s_many, dim=0, keepdim=True) 89 | s_avg = s_avg.repeat(N, 1) 90 | 91 | for z_trg in z_trg_list: 92 | s_trg = nets.mapping_network(z_trg, y_trg) 93 | s_trg = torch.lerp(s_avg, s_trg, psi) 94 | x_fake = nets.generator(x_src, s_trg, masks=masks) 95 | x_concat += [x_fake] 96 | 97 | x_concat = torch.cat(x_concat, dim=0) 98 | save_image(x_concat, N, filename) 99 | 100 | 101 | @torch.no_grad() 102 | def translate_using_reference(nets, args, x_src, x_ref, y_ref, filename): 103 | N, C, H, W = x_src.size() 104 | wb = torch.ones(1, C, H, W).to(x_src.device) 105 | x_src_with_wb = torch.cat([wb, x_src], dim=0) 106 | 107 | masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None 108 | s_ref = nets.style_encoder(x_ref, y_ref) 109 | s_ref_list = s_ref.unsqueeze(1).repeat(1, N, 1) 110 | x_concat = [x_src_with_wb] 111 | for i, s_ref in enumerate(s_ref_list): 112 | x_fake = nets.generator(x_src, s_ref, masks=masks) 113 | x_fake_with_ref = torch.cat([x_ref[i:i+1], x_fake], dim=0) 114 | x_concat += [x_fake_with_ref] 115 | 116 | x_concat = torch.cat(x_concat, dim=0) 117 | save_image(x_concat, N+1, filename) 118 | del x_concat 119 | 120 | 121 | @torch.no_grad() 122 | def debug_image(nets, args, inputs, step): 123 | x_src, y_src = inputs.x_src, inputs.y_src 124 | x_ref, y_ref = inputs.x_ref, inputs.y_ref 125 | 126 | device = inputs.x_src.device 127 | N = inputs.x_src.size(0) 128 | 129 | # translate and reconstruct (reference-guided) 130 | filename = ospj(args.sample_dir, '%06d_cycle_consistency.jpg' % (step)) 131 | translate_and_reconstruct(nets, args, x_src, y_src, x_ref, y_ref, filename) 132 | 133 | # latent-guided image synthesis 134 | y_trg_list = [torch.tensor(y).repeat(N).to(device) 135 | for y in range(min(args.num_domains, 5))] 136 | z_trg_list = torch.randn(args.num_outs_per_domain, 1, args.latent_dim).repeat(1, N, 1).to(device) 137 | for psi in [0.5, 0.7, 1.0]: 138 | filename = ospj(args.sample_dir, '%06d_latent_psi_%.1f.jpg' % (step, psi)) 139 | translate_using_latent(nets, args, x_src, y_trg_list, z_trg_list, psi, filename) 140 | 141 | # reference-guided image synthesis 142 | filename = ospj(args.sample_dir, '%06d_reference.jpg' % (step)) 143 | translate_using_reference(nets, args, x_src, x_ref, y_ref, filename) 144 | 145 | 146 | # ======================= # 147 | # Video-related functions # 148 | # ======================= # 149 | 150 | 151 | def sigmoid(x, w=1): 152 | return 1. 
/ (1 + np.exp(-w * x)) 153 | 154 | 155 | def get_alphas(start=-5, end=5, step=0.5, len_tail=10): 156 | return [0] + [sigmoid(alpha) for alpha in np.arange(start, end, step)] + [1] * len_tail 157 | 158 | 159 | def interpolate(nets, args, x_src, s_prev, s_next): 160 | ''' returns T x C x H x W ''' 161 | B = x_src.size(0) 162 | frames = [] 163 | masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None 164 | alphas = get_alphas() 165 | 166 | for alpha in alphas: 167 | s_ref = torch.lerp(s_prev, s_next, alpha) 168 | x_fake = nets.generator(x_src, s_ref, masks=masks) 169 | entries = torch.cat([x_src.cpu(), x_fake.cpu()], dim=2) 170 | frame = torchvision.utils.make_grid(entries, nrow=B, padding=0, pad_value=-1).unsqueeze(0) 171 | frames.append(frame) 172 | frames = torch.cat(frames) 173 | return frames 174 | 175 | 176 | def slide(entries, margin=32): 177 | """Returns a sliding reference window. 178 | Args: 179 | entries: a list containing two reference images, x_prev and x_next, 180 | both of which have a shape of (1, 3, 256, 256) 181 | Returns: 182 | canvas: output slide of shape (num_frames, 3, 256*2, 256+margin) 183 | """ 184 | _, C, H, W = entries[0].shape 185 | alphas = get_alphas() 186 | T = len(alphas) # number of frames 187 | 188 | canvas = - torch.ones((T, C, H*2, W + margin)) 189 | merged = torch.cat(entries, dim=2) # (1, 3, 512, 256) 190 | for t, alpha in enumerate(alphas): 191 | top = int(H * (1 - alpha)) # top, bottom for canvas 192 | bottom = H * 2 193 | m_top = 0 # top, bottom for merged 194 | m_bottom = 2 * H - top 195 | canvas[t, :, top:bottom, :W] = merged[:, :, m_top:m_bottom, :] 196 | return canvas 197 | 198 | 199 | @torch.no_grad() 200 | def video_ref(nets, args, x_src, x_ref, y_ref, fname): 201 | video = [] 202 | s_ref = nets.style_encoder(x_ref, y_ref) 203 | s_prev = None 204 | for data_next in tqdm(zip(x_ref, y_ref, s_ref), 'video_ref', len(x_ref)): 205 | x_next, y_next, s_next = [d.unsqueeze(0) for d in data_next] 206 | if s_prev is None: 207 | x_prev, y_prev, s_prev = x_next, y_next, s_next 208 | continue 209 | if y_prev != y_next: 210 | x_prev, y_prev, s_prev = x_next, y_next, s_next 211 | continue 212 | 213 | interpolated = interpolate(nets, args, x_src, s_prev, s_next) 214 | entries = [x_prev, x_next] 215 | slided = slide(entries) # (T, C, 256*2, 256) 216 | frames = torch.cat([slided, interpolated], dim=3).cpu() # (T, C, 256*2, 256*(batch+1)) 217 | video.append(frames) 218 | x_prev, y_prev, s_prev = x_next, y_next, s_next 219 | 220 | # append the last frame 10 times 221 | for _ in range(10): 222 | video.append(frames[-1:]) 223 | video = tensor2ndarray255(torch.cat(video)) 224 | save_video(fname, video) 225 | 226 | 227 | @torch.no_grad() 228 | def video_latent(nets, args, x_src, y_list, z_list, psi, fname): 229 | latent_dim = z_list[0].size(1) 230 | s_list = [] 231 | for i, y_trg in enumerate(y_list): 232 | z_many = torch.randn(10000, latent_dim).to(x_src.device) 233 | y_many = torch.LongTensor(10000).to(x_src.device).fill_(y_trg[0]) 234 | s_many = nets.mapping_network(z_many, y_many) 235 | s_avg = torch.mean(s_many, dim=0, keepdim=True) 236 | s_avg = s_avg.repeat(x_src.size(0), 1) 237 | 238 | for z_trg in z_list: 239 | s_trg = nets.mapping_network(z_trg, y_trg) 240 | s_trg = torch.lerp(s_avg, s_trg, psi) 241 | s_list.append(s_trg) 242 | 243 | s_prev = None 244 | video = [] 245 | # interpolate between consecutive style codes 246 | for idx_ref, s_next in enumerate(tqdm(s_list, 'video_latent', len(s_list))): 247 | if s_prev is None: 248 | s_prev = s_next 249 | continue 250 | if idx_ref %
len(z_list) == 0: 251 | s_prev = s_next 252 | continue 253 | frames = interpolate(nets, args, x_src, s_prev, s_next).cpu() 254 | video.append(frames) 255 | s_prev = s_next 256 | for _ in range(10): 257 | video.append(frames[-1:]) 258 | video = tensor2ndarray255(torch.cat(video)) 259 | save_video(fname, video) 260 | 261 | 262 | def save_video(fname, images, output_fps=30, vcodec='libx264', filters=''): 263 | assert isinstance(images, np.ndarray), "images should be np.array: NHWC" 264 | num_frames, height, width, channels = images.shape 265 | stream = ffmpeg.input('pipe:', format='rawvideo', 266 | pix_fmt='rgb24', s='{}x{}'.format(width, height)) 267 | stream = ffmpeg.filter(stream, 'setpts', '2*PTS') # 2*PTS is for slower playback 268 | stream = ffmpeg.output(stream, fname, pix_fmt='yuv420p', vcodec=vcodec, r=output_fps) 269 | stream = ffmpeg.overwrite_output(stream) 270 | process = ffmpeg.run_async(stream, pipe_stdin=True) 271 | for frame in tqdm(images, desc='writing video to %s' % fname): 272 | process.stdin.write(frame.astype(np.uint8).tobytes()) 273 | process.stdin.close() 274 | process.wait() 275 | 276 | 277 | def tensor2ndarray255(images): 278 | images = torch.clamp(images * 0.5 + 0.5, 0, 1) 279 | return images.cpu().numpy().transpose(0, 2, 3, 1) * 255 -------------------------------------------------------------------------------- /core/wing.py: -------------------------------------------------------------------------------- 1 | """ 2 | StarGAN v2 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
9 | 10 | Lines (19 to 80) were adapted from https://github.com/1adrianb/face-alignment 11 | Lines (83 to 235) were adapted from https://github.com/protossw512/AdaptiveWingLoss 12 | """ 13 | 14 | from collections import namedtuple 15 | from copy import deepcopy 16 | from functools import partial 17 | 18 | from munch import Munch 19 | import numpy as np 20 | import cv2 21 | from skimage.filters import gaussian 22 | import torch 23 | import torch.nn as nn 24 | import torch.nn.functional as F 25 | 26 | 27 | def get_preds_fromhm(hm): 28 | max, idx = torch.max( 29 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) 30 | idx += 1 31 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() 32 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) 33 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) 34 | 35 | for i in range(preds.size(0)): 36 | for j in range(preds.size(1)): 37 | hm_ = hm[i, j, :] 38 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 39 | if pX > 0 and pX < 63 and pY > 0 and pY < 63: 40 | diff = torch.FloatTensor( 41 | [hm_[pY, pX + 1] - hm_[pY, pX - 1], 42 | hm_[pY + 1, pX] - hm_[pY - 1, pX]]) 43 | preds[i, j].add_(diff.sign_().mul_(.25)) 44 | 45 | preds.add_(-0.5) 46 | return preds 47 | 48 | 49 | class HourGlass(nn.Module): 50 | def __init__(self, num_modules, depth, num_features, first_one=False): 51 | super(HourGlass, self).__init__() 52 | self.num_modules = num_modules 53 | self.depth = depth 54 | self.features = num_features 55 | self.coordconv = CoordConvTh(64, 64, True, True, 256, first_one, 56 | out_channels=256, 57 | kernel_size=1, stride=1, padding=0) 58 | self._generate_network(self.depth) 59 | 60 | def _generate_network(self, level): 61 | self.add_module('b1_' + str(level), ConvBlock(256, 256)) 62 | self.add_module('b2_' + str(level), ConvBlock(256, 256)) 63 | if level > 1: 64 | self._generate_network(level - 1) 65 | else: 66 | self.add_module('b2_plus_' + str(level), ConvBlock(256, 256)) 67 | self.add_module('b3_' + str(level), ConvBlock(256, 256)) 68 | 69 | def _forward(self, level, inp): 70 | up1 = inp 71 | up1 = self._modules['b1_' + str(level)](up1) 72 | low1 = F.avg_pool2d(inp, 2, stride=2) 73 | low1 = self._modules['b2_' + str(level)](low1) 74 | 75 | if level > 1: 76 | low2 = self._forward(level - 1, low1) 77 | else: 78 | low2 = low1 79 | low2 = self._modules['b2_plus_' + str(level)](low2) 80 | low3 = low2 81 | low3 = self._modules['b3_' + str(level)](low3) 82 | up2 = F.interpolate(low3, scale_factor=2, mode='nearest') 83 | 84 | return up1 + up2 85 | 86 | def forward(self, x, heatmap): 87 | x, last_channel = self.coordconv(x, heatmap) 88 | return self._forward(self.depth, x), last_channel 89 | 90 | 91 | class AddCoordsTh(nn.Module): 92 | def __init__(self, height=64, width=64, with_r=False, with_boundary=False): 93 | super(AddCoordsTh, self).__init__() 94 | self.with_r = with_r 95 | self.with_boundary = with_boundary 96 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 97 | 98 | with torch.no_grad(): 99 | x_coords = torch.arange(height).unsqueeze(1).expand(height, width).float() 100 | y_coords = torch.arange(width).unsqueeze(0).expand(height, width).float() 101 | x_coords = (x_coords / (height - 1)) * 2 - 1 102 | y_coords = (y_coords / (width - 1)) * 2 - 1 103 | coords = torch.stack([x_coords, y_coords], dim=0) # (2, height, width) 104 | 105 | if self.with_r: 106 | rr = torch.sqrt(torch.pow(x_coords, 2) + torch.pow(y_coords, 2)) # (height, width) 107 | rr = (rr / torch.max(rr)).unsqueeze(0) 
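                # rr is each pixel's distance from the image centre (computed from the
                # coordinate planes already normalized to [-1, 1]) and is rescaled to [0, 1];
                # concatenated with the two coordinate planes it gives the following
                # convolutions explicit position information (the CoordConv idea used below).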
108 | coords = torch.cat([coords, rr], dim=0) 109 | 110 | self.coords = coords.unsqueeze(0).to(device) # (1, 2 or 3, height, width) 111 | self.x_coords = x_coords.to(device) 112 | self.y_coords = y_coords.to(device) 113 | 114 | def forward(self, x, heatmap=None): 115 | """ 116 | x: (batch, c, x_dim, y_dim) 117 | """ 118 | coords = self.coords.repeat(x.size(0), 1, 1, 1) 119 | 120 | if self.with_boundary and heatmap is not None: 121 | boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0) 122 | zero_tensor = torch.zeros_like(self.x_coords) 123 | xx_boundary_channel = torch.where(boundary_channel > 0.05, self.x_coords, zero_tensor).to(zero_tensor.device) 124 | yy_boundary_channel = torch.where(boundary_channel > 0.05, self.y_coords, zero_tensor).to(zero_tensor.device) 125 | coords = torch.cat([coords, xx_boundary_channel, yy_boundary_channel], dim=1) 126 | 127 | x_and_coords = torch.cat([x, coords], dim=1) 128 | return x_and_coords 129 | 130 | 131 | class CoordConvTh(nn.Module): 132 | """CoordConv layer as in the paper.""" 133 | def __init__(self, height, width, with_r, with_boundary, 134 | in_channels, first_one=False, *args, **kwargs): 135 | super(CoordConvTh, self).__init__() 136 | self.addcoords = AddCoordsTh(height, width, with_r, with_boundary) 137 | in_channels += 2 138 | if with_r: 139 | in_channels += 1 140 | if with_boundary and not first_one: 141 | in_channels += 2 142 | self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs) 143 | 144 | def forward(self, input_tensor, heatmap=None): 145 | ret = self.addcoords(input_tensor, heatmap) 146 | last_channel = ret[:, -2:, :, :] 147 | ret = self.conv(ret) 148 | return ret, last_channel 149 | 150 | 151 | class ConvBlock(nn.Module): 152 | def __init__(self, in_planes, out_planes): 153 | super(ConvBlock, self).__init__() 154 | self.bn1 = nn.BatchNorm2d(in_planes) 155 | conv3x3 = partial(nn.Conv2d, kernel_size=3, stride=1, padding=1, bias=False, dilation=1) 156 | self.conv1 = conv3x3(in_planes, int(out_planes / 2)) 157 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) 158 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) 159 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) 160 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4)) 161 | 162 | self.downsample = None 163 | if in_planes != out_planes: 164 | self.downsample = nn.Sequential(nn.BatchNorm2d(in_planes), 165 | nn.ReLU(True), 166 | nn.Conv2d(in_planes, out_planes, 1, 1, bias=False)) 167 | 168 | def forward(self, x): 169 | residual = x 170 | 171 | out1 = self.bn1(x) 172 | out1 = F.relu(out1, True) 173 | out1 = self.conv1(out1) 174 | 175 | out2 = self.bn2(out1) 176 | out2 = F.relu(out2, True) 177 | out2 = self.conv2(out2) 178 | 179 | out3 = self.bn3(out2) 180 | out3 = F.relu(out3, True) 181 | out3 = self.conv3(out3) 182 | 183 | out3 = torch.cat((out1, out2, out3), 1) 184 | if self.downsample is not None: 185 | residual = self.downsample(residual) 186 | out3 += residual 187 | return out3 188 | 189 | 190 | class FAN(nn.Module): 191 | def __init__(self, num_modules=1, end_relu=False, num_landmarks=98, fname_pretrained=None): 192 | super(FAN, self).__init__() 193 | self.num_modules = num_modules 194 | self.end_relu = end_relu 195 | 196 | # Base part 197 | self.conv1 = CoordConvTh(256, 256, True, False, 198 | in_channels=3, out_channels=64, 199 | kernel_size=7, stride=2, padding=3) 200 | self.bn1 = nn.BatchNorm2d(64) 201 | self.conv2 = ConvBlock(64, 128) 202 | self.conv3 = ConvBlock(128, 128) 203 | self.conv4 = ConvBlock(128, 256) 204 | 205 | # Stacking 
part 206 | self.add_module('m0', HourGlass(1, 4, 256, first_one=True)) 207 | self.add_module('top_m_0', ConvBlock(256, 256)) 208 | self.add_module('conv_last0', nn.Conv2d(256, 256, 1, 1, 0)) 209 | self.add_module('bn_end0', nn.BatchNorm2d(256)) 210 | self.add_module('l0', nn.Conv2d(256, num_landmarks+1, 1, 1, 0)) 211 | 212 | if fname_pretrained is not None: 213 | self.load_pretrained_weights(fname_pretrained) 214 | 215 | def load_pretrained_weights(self, fname): 216 | if torch.cuda.is_available(): 217 | checkpoint = torch.load(fname) 218 | else: 219 | checkpoint = torch.load(fname, map_location=torch.device('cpu')) 220 | model_weights = self.state_dict() 221 | model_weights.update({k: v for k, v in checkpoint['state_dict'].items() 222 | if k in model_weights}) 223 | self.load_state_dict(model_weights) 224 | 225 | def forward(self, x): 226 | x, _ = self.conv1(x) 227 | x = F.relu(self.bn1(x), True) 228 | x = F.avg_pool2d(self.conv2(x), 2, stride=2) 229 | x = self.conv3(x) 230 | x = self.conv4(x) 231 | 232 | outputs = [] 233 | boundary_channels = [] 234 | tmp_out = None 235 | ll, boundary_channel = self._modules['m0'](x, tmp_out) 236 | ll = self._modules['top_m_0'](ll) 237 | ll = F.relu(self._modules['bn_end0'] 238 | (self._modules['conv_last0'](ll)), True) 239 | 240 | # Predict heatmaps 241 | tmp_out = self._modules['l0'](ll) 242 | if self.end_relu: 243 | tmp_out = F.relu(tmp_out) # HACK: Added relu 244 | outputs.append(tmp_out) 245 | boundary_channels.append(boundary_channel) 246 | return outputs, boundary_channels 247 | 248 | @torch.no_grad() 249 | def get_heatmap(self, x, b_preprocess=True): 250 | ''' outputs 0-1 normalized heatmap ''' 251 | x = F.interpolate(x, size=256, mode='bilinear') 252 | x_01 = x*0.5 + 0.5 253 | outputs, _ = self(x_01) 254 | heatmaps = outputs[-1][:, :-1, :, :] 255 | scale_factor = x.size(2) // heatmaps.size(2) 256 | if b_preprocess: 257 | heatmaps = F.interpolate(heatmaps, scale_factor=scale_factor, 258 | mode='bilinear', align_corners=True) 259 | heatmaps = preprocess(heatmaps) 260 | return heatmaps 261 | 262 | @torch.no_grad() 263 | def get_landmark(self, x): 264 | ''' outputs landmarks of x.shape ''' 265 | heatmaps = self.get_heatmap(x, b_preprocess=False) 266 | landmarks = [] 267 | for i in range(x.size(0)): 268 | pred_landmarks = get_preds_fromhm(heatmaps[i].cpu().unsqueeze(0)) 269 | landmarks.append(pred_landmarks) 270 | scale_factor = x.size(2) // heatmaps.size(2) 271 | landmarks = torch.cat(landmarks) * scale_factor 272 | return landmarks 273 | 274 | 275 | # ========================== # 276 | # Align related functions # 277 | # ========================== # 278 | 279 | 280 | def tensor2numpy255(tensor): 281 | """Converts torch tensor to numpy array.""" 282 | return ((tensor.permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5) * 255).astype('uint8') 283 | 284 | 285 | def np2tensor(image): 286 | """Converts numpy array to torch tensor.""" 287 | return torch.FloatTensor(image).permute(2, 0, 1) / 255 * 2 - 1 288 | 289 | 290 | class FaceAligner(): 291 | def __init__(self, fname_wing, fname_celeba_mean, output_size): 292 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 293 | self.fan = FAN(fname_pretrained=fname_wing).to(self.device).eval() 294 | scale = output_size // 256 295 | self.CELEB_REF = np.float32(np.load(fname_celeba_mean)['mean']) * scale 296 | self.xaxis_ref = landmarks2xaxis(self.CELEB_REF) 297 | self.output_size = output_size 298 | 299 | def align(self, imgs, output_size=256): 300 | ''' imgs = torch.CUDATensor of BCHW ''' 301 | 
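        # Alignment pipeline: predict 98 facial landmarks with the FAN network,
        # mirror-pad each image so the warp has context to sample from, build a
        # translation/rotation/scale matrix mapping the detected landmarks onto the
        # CelebA mean landmarks, warp with cv2.warpPerspective, then crop to
        # output_size and convert back to a [-1, 1] tensor.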
imgs = imgs.to(self.device) 302 | landmarkss = self.fan.get_landmark(imgs).cpu().numpy() 303 | for i, (img, landmarks) in enumerate(zip(imgs, landmarkss)): 304 | img_np = tensor2numpy255(img) 305 | img_np, landmarks = pad_mirror(img_np, landmarks) 306 | transform = self.landmarks2mat(landmarks) 307 | rows, cols, _ = img_np.shape 308 | rows = max(rows, self.output_size) 309 | cols = max(cols, self.output_size) 310 | aligned = cv2.warpPerspective(img_np, transform, (cols, rows), flags=cv2.INTER_LANCZOS4) 311 | imgs[i] = np2tensor(aligned[:self.output_size, :self.output_size, :]) 312 | return imgs 313 | 314 | def landmarks2mat(self, landmarks): 315 | T_origin = points2T(landmarks, 'from') 316 | xaxis_src = landmarks2xaxis(landmarks) 317 | R = vecs2R(xaxis_src, self.xaxis_ref) 318 | S = landmarks2S(landmarks, self.CELEB_REF) 319 | T_ref = points2T(self.CELEB_REF, 'to') 320 | matrix = np.dot(T_ref, np.dot(S, np.dot(R, T_origin))) 321 | return matrix 322 | 323 | 324 | def points2T(point, direction): 325 | point_mean = point.mean(axis=0) 326 | T = np.eye(3) 327 | coef = -1 if direction == 'from' else 1 328 | T[:2, 2] = coef * point_mean 329 | return T 330 | 331 | 332 | def landmarks2eyes(landmarks): 333 | idx_left = np.array(list(range(60, 67+1)) + [96]) 334 | idx_right = np.array(list(range(68, 75+1)) + [97]) 335 | left = landmarks[idx_left] 336 | right = landmarks[idx_right] 337 | return left.mean(axis=0), right.mean(axis=0) 338 | 339 | 340 | def landmarks2mouthends(landmarks): 341 | left = landmarks[76] 342 | right = landmarks[82] 343 | return left, right 344 | 345 | 346 | def rotate90(vec): 347 | x, y = vec 348 | return np.array([y, -x]) 349 | 350 | 351 | def landmarks2xaxis(landmarks): 352 | eye_left, eye_right = landmarks2eyes(landmarks) 353 | mouth_left, mouth_right = landmarks2mouthends(landmarks) 354 | xp = eye_right - eye_left # x' in pggan 355 | eye_center = (eye_left + eye_right) * 0.5 356 | mouth_center = (mouth_left + mouth_right) * 0.5 357 | yp = eye_center - mouth_center 358 | xaxis = xp - rotate90(yp) 359 | return xaxis / np.linalg.norm(xaxis) 360 | 361 | 362 | def vecs2R(vec_x, vec_y): 363 | vec_x = vec_x / np.linalg.norm(vec_x) 364 | vec_y = vec_y / np.linalg.norm(vec_y) 365 | c = np.dot(vec_x, vec_y) 366 | s = np.sqrt(1 - c * c) * np.sign(np.cross(vec_x, vec_y)) 367 | R = np.array(((c, -s, 0), (s, c, 0), (0, 0, 1))) 368 | return R 369 | 370 | 371 | def landmarks2S(x, y): 372 | x_mean = x.mean(axis=0).squeeze() 373 | y_mean = y.mean(axis=0).squeeze() 374 | # vectors = mean -> each point 375 | x_vectors = x - x_mean 376 | y_vectors = y - y_mean 377 | 378 | x_norms = np.linalg.norm(x_vectors, axis=1) 379 | y_norms = np.linalg.norm(y_vectors, axis=1) 380 | 381 | indices = [96, 97, 76, 82] # indices for eyes, lips 382 | scale = (y_norms / x_norms)[indices].mean() 383 | 384 | S = np.eye(3) 385 | S[0, 0] = S[1, 1] = scale 386 | return S 387 | 388 | 389 | def pad_mirror(img, landmarks): 390 | H, W, _ = img.shape 391 | img = np.pad(img, ((H//2, H//2), (W//2, W//2), (0, 0)), 'reflect') 392 | small_blurred = gaussian(cv2.resize(img, (W, H)), H//100, multichannel=True) 393 | blurred = cv2.resize(small_blurred, (W * 2, H * 2)) * 255 394 | 395 | H, W, _ = img.shape 396 | coords = np.meshgrid(np.arange(H), np.arange(W), indexing="ij") 397 | weight_y = np.clip(coords[0] / (H//4), 0, 1) 398 | weight_x = np.clip(coords[1] / (H//4), 0, 1) 399 | weight_y = np.minimum(weight_y, np.flip(weight_y, axis=0)) 400 | weight_x = np.minimum(weight_x, np.flip(weight_x, axis=1)) 401 | weight = 
np.expand_dims(np.minimum(weight_y, weight_x), 2)**4 402 | img = img * weight + blurred * (1 - weight) 403 | landmarks += np.array([W//4, H//4]) 404 | return img, landmarks 405 | 406 | 407 | def align_faces(args, input_dir, output_dir): 408 | import os 409 | from torchvision import transforms 410 | from PIL import Image 411 | from core.utils import save_image 412 | 413 | aligner = FaceAligner(args.wing_path, args.lm_path, args.img_size) 414 | transform = transforms.Compose([ 415 | transforms.Resize((args.img_size, args.img_size)), 416 | transforms.ToTensor(), 417 | transforms.Normalize(mean=[0.5, 0.5, 0.5], 418 | std=[0.5, 0.5, 0.5]), 419 | ]) 420 | 421 | fnames = os.listdir(input_dir) 422 | os.makedirs(output_dir, exist_ok=True) 423 | fnames.sort() 424 | for fname in fnames: 425 | image = Image.open(os.path.join(input_dir, fname)).convert('RGB') 426 | x = transform(image).unsqueeze(0) 427 | x_aligned = aligner.align(x) 428 | save_image(x_aligned, 1, filename=os.path.join(output_dir, fname)) 429 | print('Saved the aligned image to %s...' % fname) 430 | 431 | 432 | # ========================== # 433 | # Mask related functions # 434 | # ========================== # 435 | 436 | 437 | def normalize(x, eps=1e-6): 438 | """Apply min-max normalization.""" 439 | x = x.contiguous() 440 | N, C, H, W = x.size() 441 | x_ = x.view(N*C, -1) 442 | max_val = torch.max(x_, dim=1, keepdim=True)[0] 443 | min_val = torch.min(x_, dim=1, keepdim=True)[0] 444 | x_ = (x_ - min_val) / (max_val - min_val + eps) 445 | out = x_.view(N, C, H, W) 446 | return out 447 | 448 | 449 | def truncate(x, thres=0.1): 450 | """Remove small values in heatmaps.""" 451 | return torch.where(x < thres, torch.zeros_like(x), x) 452 | 453 | 454 | def resize(x, p=2): 455 | """Resize heatmaps.""" 456 | return x**p 457 | 458 | 459 | def shift(x, N): 460 | """Shift N pixels up or down.""" 461 | up = N >= 0 462 | N = abs(N) 463 | _, _, H, W = x.size() 464 | head = torch.arange(N) 465 | tail = torch.arange(H-N) 466 | 467 | if up: 468 | head = torch.arange(H-N)+N 469 | tail = torch.arange(N) 470 | else: 471 | head = torch.arange(N) + (H-N) 472 | tail = torch.arange(H-N) 473 | 474 | # permutation indices 475 | perm = torch.cat([head, tail]).to(x.device) 476 | out = x[:, :, perm, :] 477 | return out 478 | 479 | 480 | IDXPAIR = namedtuple('IDXPAIR', 'start end') 481 | index_map = Munch(chin=IDXPAIR(0 + 8, 33 - 8), 482 | eyebrows=IDXPAIR(33, 51), 483 | eyebrowsedges=IDXPAIR(33, 46), 484 | nose=IDXPAIR(51, 55), 485 | nostrils=IDXPAIR(55, 60), 486 | eyes=IDXPAIR(60, 76), 487 | lipedges=IDXPAIR(76, 82), 488 | lipupper=IDXPAIR(77, 82), 489 | liplower=IDXPAIR(83, 88), 490 | lipinner=IDXPAIR(88, 96)) 491 | OPPAIR = namedtuple('OPPAIR', 'shift resize') 492 | 493 | 494 | def preprocess(x): 495 | """Preprocess 98-dimensional heatmaps.""" 496 | N, C, H, W = x.size() 497 | x = truncate(x) 498 | x = normalize(x) 499 | 500 | sw = H // 256 501 | operations = Munch(chin=OPPAIR(0, 3), 502 | eyebrows=OPPAIR(-7*sw, 2), 503 | nostrils=OPPAIR(8*sw, 4), 504 | lipupper=OPPAIR(-8*sw, 4), 505 | liplower=OPPAIR(8*sw, 4), 506 | lipinner=OPPAIR(-2*sw, 3)) 507 | 508 | for part, ops in operations.items(): 509 | start, end = index_map[part] 510 | x[:, start:end] = resize(shift(x[:, start:end], ops.shift), ops.resize) 511 | 512 | zero_out = torch.cat([torch.arange(0, index_map.chin.start), 513 | torch.arange(index_map.chin.end, 33), 514 | torch.LongTensor([index_map.eyebrowsedges.start, 515 | index_map.eyebrowsedges.end, 516 | index_map.lipedges.start, 517 | 
index_map.lipedges.end])]) 518 | x[:, zero_out] = 0 519 | 520 | start, end = index_map.nose 521 | x[:, start+1:end] = shift(x[:, start+1:end], 4*sw) 522 | x[:, start:end] = resize(x[:, start:end], 1) 523 | 524 | start, end = index_map.eyes 525 | x[:, start:end] = resize(x[:, start:end], 1) 526 | x[:, start:end] = resize(shift(x[:, start:end], -8), 3) + \ 527 | shift(x[:, start:end], -24) 528 | 529 | # Second-level mask 530 | x2 = deepcopy(x) 531 | x2[:, index_map.chin.start:index_map.chin.end] = 0 # start:end was 0:33 532 | x2[:, index_map.lipedges.start:index_map.lipinner.end] = 0 # start:end was 76:96 533 | x2[:, index_map.eyebrows.start:index_map.eyebrows.end] = 0 # start:end was 33:51 534 | 535 | x = torch.sum(x, dim=1, keepdim=True) # (N, 1, H, W) 536 | x2 = torch.sum(x2, dim=1, keepdim=True) # mask without faceline and mouth 537 | 538 | x[x != x] = 0 # set nan to zero 539 | x2[x != x] = 0 # set nan to zero 540 | return x.clamp_(0, 1), x2.clamp_(0, 1) -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | """ 2 | StarGAN v2 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | FILE=$1 12 | 13 | if [ $FILE == "pretrained-network-celeba-hq" ]; then 14 | URL=https://www.dropbox.com/s/96fmei6c93o8b8t/100000_nets_ema.ckpt?dl=0 15 | mkdir -p ./expr/checkpoints/celeba_hq 16 | OUT_FILE=./expr/checkpoints/celeba_hq/100000_nets_ema.ckpt 17 | wget -N $URL -O $OUT_FILE 18 | 19 | elif [ $FILE == "pretrained-network-afhq" ]; then 20 | URL=https://www.dropbox.com/s/etwm810v25h42sn/100000_nets_ema.ckpt?dl=0 21 | mkdir -p ./expr/checkpoints/afhq 22 | OUT_FILE=./expr/checkpoints/afhq/100000_nets_ema.ckpt 23 | wget -N $URL -O $OUT_FILE 24 | 25 | elif [ $FILE == "wing" ]; then 26 | URL=https://www.dropbox.com/s/tjxpypwpt38926e/wing.ckpt?dl=0 27 | mkdir -p ./expr/checkpoints/ 28 | OUT_FILE=./expr/checkpoints/wing.ckpt 29 | wget -N $URL -O $OUT_FILE 30 | URL=https://www.dropbox.com/s/91fth49gyb7xksk/celeba_lm_mean.npz?dl=0 31 | OUT_FILE=./expr/checkpoints/celeba_lm_mean.npz 32 | wget -N $URL -O $OUT_FILE 33 | 34 | elif [ $FILE == "celeba-hq-dataset" ]; then 35 | URL=https://www.dropbox.com/s/f7pvjij2xlpff59/celeba_hq.zip?dl=0 36 | ZIP_FILE=./data/celeba_hq.zip 37 | mkdir -p ./data 38 | wget -N $URL -O $ZIP_FILE 39 | unzip $ZIP_FILE -d ./data 40 | rm $ZIP_FILE 41 | 42 | elif [ $FILE == "afhq-dataset" ]; then 43 | URL=https://www.dropbox.com/s/t9l9o3vsx2jai3z/afhq.zip?dl=0 44 | ZIP_FILE=./data/afhq.zip 45 | mkdir -p ./data 46 | wget -N $URL -O $ZIP_FILE 47 | unzip $ZIP_FILE -d ./data 48 | rm $ZIP_FILE 49 | 50 | elif [ $FILE == "afhq-v2-dataset" ]; then 51 | #URL=https://www.dropbox.com/s/scckftx13grwmiv/afhq_v2.zip?dl=0 52 | URL=https://www.dropbox.com/s/vkzjokiwof5h8w6/afhq_v2.zip?dl=0 53 | ZIP_FILE=./data/afhq_v2.zip 54 | mkdir -p ./data 55 | wget -N $URL -O $ZIP_FILE 56 | unzip $ZIP_FILE -d ./data 57 | rm $ZIP_FILE 58 | 59 | else 60 | echo "Available arguments are pretrained-network-celeba-hq, pretrained-network-afhq, celeba-hq-dataset, and afhq-dataset." 
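    # Note: "wing" and "afhq-v2-dataset" are also accepted; they are handled in the branches above but not listed in this message.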
61 | exit 1 62 | 63 | fi 64 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | StarGAN v2 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | import os 12 | import argparse 13 | 14 | from munch import Munch 15 | from torch.backends import cudnn 16 | import torch 17 | 18 | from core.data_loader import get_train_loader 19 | from core.data_loader import get_test_loader 20 | from core.solver import Solver 21 | 22 | 23 | def str2bool(v): 24 | return v.lower() in ('true') 25 | 26 | 27 | def subdirs(dname): 28 | return [d for d in os.listdir(dname) 29 | if os.path.isdir(os.path.join(dname, d))] 30 | 31 | 32 | def main(args): 33 | print(args) 34 | cudnn.benchmark = True 35 | torch.manual_seed(args.seed) 36 | 37 | solver = Solver(args) 38 | 39 | if args.mode == 'train': 40 | assert len(subdirs(args.train_img_dir)) == args.num_domains 41 | assert len(subdirs(args.val_img_dir)) == args.num_domains 42 | loaders = Munch(src=get_train_loader(root=args.train_img_dir, 43 | which='source', 44 | img_size=args.img_size, 45 | batch_size=args.batch_size, 46 | prob=args.randcrop_prob, 47 | num_workers=args.num_workers), 48 | ref=get_train_loader(root=args.train_img_dir, 49 | which='reference', 50 | img_size=args.img_size, 51 | batch_size=args.batch_size, 52 | prob=args.randcrop_prob, 53 | num_workers=args.num_workers), 54 | val=get_test_loader(root=args.val_img_dir, 55 | img_size=args.img_size, 56 | batch_size=args.val_batch_size, 57 | shuffle=True, 58 | num_workers=args.num_workers)) 59 | solver.train(loaders) 60 | elif args.mode == 'sample': 61 | assert len(subdirs(args.src_dir)) == args.num_domains 62 | assert len(subdirs(args.ref_dir)) == args.num_domains 63 | loaders = Munch(src=get_test_loader(root=args.src_dir, 64 | img_size=args.img_size, 65 | batch_size=args.val_batch_size, 66 | shuffle=False, 67 | num_workers=args.num_workers), 68 | ref=get_test_loader(root=args.ref_dir, 69 | img_size=args.img_size, 70 | batch_size=args.val_batch_size, 71 | shuffle=False, 72 | num_workers=args.num_workers)) 73 | solver.sample(loaders) 74 | elif args.mode == 'eval': 75 | solver.evaluate() 76 | elif args.mode == 'align': 77 | from core.wing import align_faces 78 | align_faces(args, args.inp_dir, args.out_dir) 79 | else: 80 | raise NotImplementedError 81 | 82 | 83 | if __name__ == '__main__': 84 | parser = argparse.ArgumentParser() 85 | 86 | # model arguments 87 | parser.add_argument('--img_size', type=int, default=256, 88 | help='Image resolution') 89 | parser.add_argument('--num_domains', type=int, default=2, 90 | help='Number of domains') 91 | parser.add_argument('--latent_dim', type=int, default=16, 92 | help='Latent vector dimension') 93 | parser.add_argument('--hidden_dim', type=int, default=512, 94 | help='Hidden dimension of mapping network') 95 | parser.add_argument('--style_dim', type=int, default=64, 96 | help='Style code dimension') 97 | 98 | # weight for objective functions 99 | parser.add_argument('--lambda_reg', type=float, default=1, 100 | help='Weight for R1 regularization') 101 | parser.add_argument('--lambda_cyc', type=float, default=1, 102 | help='Weight for cyclic consistency 
loss') 103 | parser.add_argument('--lambda_sty', type=float, default=1, 104 | help='Weight for style reconstruction loss') 105 | parser.add_argument('--lambda_ds', type=float, default=1, 106 | help='Weight for diversity sensitive loss') 107 | parser.add_argument('--ds_iter', type=int, default=100000, 108 | help='Number of iterations to optimize diversity sensitive loss') 109 | parser.add_argument('--w_hpf', type=float, default=1, 110 | help='Weight for high-pass filtering') 111 | 112 | # training arguments 113 | parser.add_argument('--randcrop_prob', type=float, default=0.5, 114 | help='Probability of using random-resized cropping') 115 | parser.add_argument('--total_iters', type=int, default=100000, 116 | help='Number of total iterations') 117 | parser.add_argument('--resume_iter', type=int, default=0, 118 | help='Iterations to resume training/testing') 119 | parser.add_argument('--batch_size', type=int, default=8, 120 | help='Batch size for training') 121 | parser.add_argument('--val_batch_size', type=int, default=32, 122 | help='Batch size for validation') 123 | parser.add_argument('--lr', type=float, default=1e-4, 124 | help='Learning rate for D, E and G') 125 | parser.add_argument('--f_lr', type=float, default=1e-6, 126 | help='Learning rate for F') 127 | parser.add_argument('--beta1', type=float, default=0.0, 128 | help='Decay rate for 1st moment of Adam') 129 | parser.add_argument('--beta2', type=float, default=0.99, 130 | help='Decay rate for 2nd moment of Adam') 131 | parser.add_argument('--weight_decay', type=float, default=1e-4, 132 | help='Weight decay for optimizer') 133 | parser.add_argument('--num_outs_per_domain', type=int, default=10, 134 | help='Number of generated images per domain during sampling') 135 | 136 | # misc 137 | parser.add_argument('--mode', type=str, required=True, 138 | choices=['train', 'sample', 'eval', 'align'], 139 | help='Running mode: train, sample, eval, or align') 140 | parser.add_argument('--num_workers', type=int, default=4, 141 | help='Number of workers used in DataLoader') 142 | parser.add_argument('--seed', type=int, default=777, 143 | help='Seed for random number generator') 144 | 145 | # directory for training 146 | parser.add_argument('--train_img_dir', type=str, default='data/celeba_hq/train', 147 | help='Directory containing training images') 148 | parser.add_argument('--val_img_dir', type=str, default='data/celeba_hq/val', 149 | help='Directory containing validation images') 150 | parser.add_argument('--sample_dir', type=str, default='expr/samples', 151 | help='Directory for saving generated images') 152 | parser.add_argument('--checkpoint_dir', type=str, default='expr/checkpoints', 153 | help='Directory for saving network checkpoints') 154 | 155 | # directory for calculating metrics 156 | parser.add_argument('--eval_dir', type=str, default='expr/eval', 157 | help='Directory for saving metrics, i.e., FID and LPIPS') 158 | 159 | # directory for testing 160 | parser.add_argument('--result_dir', type=str, default='expr/results', 161 | help='Directory for saving generated images and videos') 162 | parser.add_argument('--src_dir', type=str, default='assets/representative/celeba_hq/src', 163 | help='Directory containing input source images') 164 | parser.add_argument('--ref_dir', type=str, default='assets/representative/celeba_hq/ref', 165 | help='Directory containing input reference images') 166 | parser.add_argument('--inp_dir', type=str, default='assets/representative/custom/female', 167 | help='input directory when aligning faces') 168 |
parser.add_argument('--out_dir', type=str, default='assets/representative/celeba_hq/src/female', 169 | help='output directory when aligning faces') 170 | 171 | # face alignment 172 | parser.add_argument('--wing_path', type=str, default='expr/checkpoints/wing.ckpt') 173 | parser.add_argument('--lm_path', type=str, default='expr/checkpoints/celeba_lm_mean.npz') 174 | 175 | # step size 176 | parser.add_argument('--print_every', type=int, default=10) 177 | parser.add_argument('--sample_every', type=int, default=5000) 178 | parser.add_argument('--save_every', type=int, default=10000) 179 | parser.add_argument('--eval_every', type=int, default=50000) 180 | 181 | args = parser.parse_args() 182 | main(args) 183 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/metrics/__init__.py -------------------------------------------------------------------------------- /metrics/eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | StarGAN v2 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | import os 12 | import shutil 13 | from collections import OrderedDict 14 | from tqdm import tqdm 15 | 16 | import numpy as np 17 | import torch 18 | 19 | from metrics.fid import calculate_fid_given_paths 20 | from metrics.lpips import calculate_lpips_given_images 21 | from core.data_loader import get_eval_loader 22 | from core import utils 23 | 24 | 25 | @torch.no_grad() 26 | def calculate_metrics(nets, args, step, mode): 27 | print('Calculating evaluation metrics...') 28 | assert mode in ['latent', 'reference'] 29 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 30 | 31 | domains = os.listdir(args.val_img_dir) 32 | domains.sort() 33 | num_domains = len(domains) 34 | print('Number of domains: %d' % num_domains) 35 | 36 | lpips_dict = OrderedDict() 37 | for trg_idx, trg_domain in enumerate(domains): 38 | src_domains = [x for x in domains if x != trg_domain] 39 | 40 | if mode == 'reference': 41 | path_ref = os.path.join(args.val_img_dir, trg_domain) 42 | loader_ref = get_eval_loader(root=path_ref, 43 | img_size=args.img_size, 44 | batch_size=args.val_batch_size, 45 | imagenet_normalize=False, 46 | drop_last=True) 47 | 48 | for src_idx, src_domain in enumerate(src_domains): 49 | path_src = os.path.join(args.val_img_dir, src_domain) 50 | loader_src = get_eval_loader(root=path_src, 51 | img_size=args.img_size, 52 | batch_size=args.val_batch_size, 53 | imagenet_normalize=False) 54 | 55 | task = '%s2%s' % (src_domain, trg_domain) 56 | path_fake = os.path.join(args.eval_dir, task) 57 | shutil.rmtree(path_fake, ignore_errors=True) 58 | os.makedirs(path_fake) 59 | 60 | lpips_values = [] 61 | print('Generating images and calculating LPIPS for %s...' 
% task) 62 | for i, x_src in enumerate(tqdm(loader_src, total=len(loader_src))): 63 | N = x_src.size(0) 64 | x_src = x_src.to(device) 65 | y_trg = torch.tensor([trg_idx] * N).to(device) 66 | masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None 67 | 68 | # generate 10 outputs from the same input 69 | group_of_images = [] 70 | for j in range(args.num_outs_per_domain): 71 | if mode == 'latent': 72 | z_trg = torch.randn(N, args.latent_dim).to(device) 73 | s_trg = nets.mapping_network(z_trg, y_trg) 74 | else: 75 | try: 76 | x_ref = next(iter_ref).to(device) 77 | except: 78 | iter_ref = iter(loader_ref) 79 | x_ref = next(iter_ref).to(device) 80 | 81 | if x_ref.size(0) > N: 82 | x_ref = x_ref[:N] 83 | s_trg = nets.style_encoder(x_ref, y_trg) 84 | 85 | x_fake = nets.generator(x_src, s_trg, masks=masks) 86 | group_of_images.append(x_fake) 87 | 88 | # save generated images to calculate FID later 89 | for k in range(N): 90 | filename = os.path.join( 91 | path_fake, 92 | '%.4i_%.2i.png' % (i*args.val_batch_size+(k+1), j+1)) 93 | utils.save_image(x_fake[k], ncol=1, filename=filename) 94 | 95 | lpips_value = calculate_lpips_given_images(group_of_images) 96 | lpips_values.append(lpips_value) 97 | 98 | # calculate LPIPS for each task (e.g. cat2dog, dog2cat) 99 | lpips_mean = np.array(lpips_values).mean() 100 | lpips_dict['LPIPS_%s/%s' % (mode, task)] = lpips_mean 101 | 102 | # delete dataloaders 103 | del loader_src 104 | if mode == 'reference': 105 | del loader_ref 106 | del iter_ref 107 | 108 | # calculate the average LPIPS for all tasks 109 | lpips_mean = 0 110 | for _, value in lpips_dict.items(): 111 | lpips_mean += value / len(lpips_dict) 112 | lpips_dict['LPIPS_%s/mean' % mode] = lpips_mean 113 | 114 | # report LPIPS values 115 | filename = os.path.join(args.eval_dir, 'LPIPS_%.5i_%s.json' % (step, mode)) 116 | utils.save_json(lpips_dict, filename) 117 | 118 | # calculate and report fid values 119 | calculate_fid_for_all_tasks(args, domains, step=step, mode=mode) 120 | 121 | 122 | def calculate_fid_for_all_tasks(args, domains, step, mode): 123 | print('Calculating FID for all tasks...') 124 | fid_values = OrderedDict() 125 | for trg_domain in domains: 126 | src_domains = [x for x in domains if x != trg_domain] 127 | 128 | for src_domain in src_domains: 129 | task = '%s2%s' % (src_domain, trg_domain) 130 | path_real = os.path.join(args.train_img_dir, trg_domain) 131 | path_fake = os.path.join(args.eval_dir, task) 132 | print('Calculating FID for %s...' % task) 133 | fid_value = calculate_fid_given_paths( 134 | paths=[path_real, path_fake], 135 | img_size=args.img_size, 136 | batch_size=args.val_batch_size) 137 | fid_values['FID_%s/%s' % (mode, task)] = fid_value 138 | 139 | # calculate the average FID for all tasks 140 | fid_mean = 0 141 | for _, value in fid_values.items(): 142 | fid_mean += value / len(fid_values) 143 | fid_values['FID_%s/mean' % mode] = fid_mean 144 | 145 | # report FID values 146 | filename = os.path.join(args.eval_dir, 'FID_%.5i_%s.json' % (step, mode)) 147 | utils.save_json(fid_values, filename) 148 | -------------------------------------------------------------------------------- /metrics/fid.py: -------------------------------------------------------------------------------- 1 | """ 2 | StarGAN v2 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. 
To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | import os 12 | import argparse 13 | 14 | import torch 15 | import torch.nn as nn 16 | import numpy as np 17 | from torchvision import models 18 | from scipy import linalg 19 | from core.data_loader import get_eval_loader 20 | 21 | try: 22 | from tqdm import tqdm 23 | except ImportError: 24 | def tqdm(x): return x 25 | 26 | 27 | class InceptionV3(nn.Module): 28 | def __init__(self): 29 | super().__init__() 30 | inception = models.inception_v3(pretrained=True) 31 | self.block1 = nn.Sequential( 32 | inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, 33 | inception.Conv2d_2b_3x3, 34 | nn.MaxPool2d(kernel_size=3, stride=2)) 35 | self.block2 = nn.Sequential( 36 | inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, 37 | nn.MaxPool2d(kernel_size=3, stride=2)) 38 | self.block3 = nn.Sequential( 39 | inception.Mixed_5b, inception.Mixed_5c, 40 | inception.Mixed_5d, inception.Mixed_6a, 41 | inception.Mixed_6b, inception.Mixed_6c, 42 | inception.Mixed_6d, inception.Mixed_6e) 43 | self.block4 = nn.Sequential( 44 | inception.Mixed_7a, inception.Mixed_7b, 45 | inception.Mixed_7c, 46 | nn.AdaptiveAvgPool2d(output_size=(1, 1))) 47 | 48 | def forward(self, x): 49 | x = self.block1(x) 50 | x = self.block2(x) 51 | x = self.block3(x) 52 | x = self.block4(x) 53 | return x.view(x.size(0), -1) 54 | 55 | 56 | def frechet_distance(mu, cov, mu2, cov2): 57 | cc, _ = linalg.sqrtm(np.dot(cov, cov2), disp=False) 58 | dist = np.sum((mu -mu2)**2) + np.trace(cov + cov2 - 2*cc) 59 | return np.real(dist) 60 | 61 | 62 | @torch.no_grad() 63 | def calculate_fid_given_paths(paths, img_size=256, batch_size=50): 64 | print('Calculating FID given paths %s and %s...' % (paths[0], paths[1])) 65 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 66 | inception = InceptionV3().eval().to(device) 67 | loaders = [get_eval_loader(path, img_size, batch_size) for path in paths] 68 | 69 | mu, cov = [], [] 70 | for loader in loaders: 71 | actvs = [] 72 | for x in tqdm(loader, total=len(loader)): 73 | actv = inception(x.to(device)) 74 | actvs.append(actv) 75 | actvs = torch.cat(actvs, dim=0).cpu().detach().numpy() 76 | mu.append(np.mean(actvs, axis=0)) 77 | cov.append(np.cov(actvs, rowvar=False)) 78 | fid_value = frechet_distance(mu[0], cov[0], mu[1], cov[1]) 79 | return fid_value 80 | 81 | 82 | if __name__ == '__main__': 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('--paths', type=str, nargs=2, help='paths to real and fake images') 85 | parser.add_argument('--img_size', type=int, default=256, help='image resolution') 86 | parser.add_argument('--batch_size', type=int, default=64, help='batch size to use') 87 | args = parser.parse_args() 88 | fid_value = calculate_fid_given_paths(args.paths, args.img_size, args.batch_size) 89 | print('FID: ', fid_value) 90 | 91 | # python -m metrics.fid --paths PATH_REAL PATH_FAKE -------------------------------------------------------------------------------- /metrics/lpips.py: -------------------------------------------------------------------------------- 1 | """ 2 | StarGAN v2 3 | Copyright (c) 2020-present NAVER Corp. 4 | 5 | This work is licensed under the Creative Commons Attribution-NonCommercial 6 | 4.0 International License. 
To view a copy of this license, visit 7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 9 | """ 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torchvision import models 14 | 15 | 16 | def normalize(x, eps=1e-10): 17 | return x * torch.rsqrt(torch.sum(x**2, dim=1, keepdim=True) + eps) 18 | 19 | 20 | class AlexNet(nn.Module): 21 | def __init__(self): 22 | super().__init__() 23 | self.layers = models.alexnet(pretrained=True).features 24 | self.channels = [] 25 | for layer in self.layers: 26 | if isinstance(layer, nn.Conv2d): 27 | self.channels.append(layer.out_channels) 28 | 29 | def forward(self, x): 30 | fmaps = [] 31 | for layer in self.layers: 32 | x = layer(x) 33 | if isinstance(layer, nn.ReLU): 34 | fmaps.append(x) 35 | return fmaps 36 | 37 | 38 | class Conv1x1(nn.Module): 39 | def __init__(self, in_channels, out_channels=1): 40 | super().__init__() 41 | self.main = nn.Sequential( 42 | nn.Dropout(0.5), 43 | nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)) 44 | 45 | def forward(self, x): 46 | return self.main(x) 47 | 48 | 49 | class LPIPS(nn.Module): 50 | def __init__(self): 51 | super().__init__() 52 | self.alexnet = AlexNet() 53 | self.lpips_weights = nn.ModuleList() 54 | for channels in self.alexnet.channels: 55 | self.lpips_weights.append(Conv1x1(channels, 1)) 56 | self._load_lpips_weights() 57 | # imagenet normalization for range [-1, 1] 58 | self.mu = torch.tensor([-0.03, -0.088, -0.188]).view(1, 3, 1, 1).cuda() 59 | self.sigma = torch.tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1).cuda() 60 | 61 | def _load_lpips_weights(self): 62 | own_state_dict = self.state_dict() 63 | if torch.cuda.is_available(): 64 | state_dict = torch.load('metrics/lpips_weights.ckpt') 65 | else: 66 | state_dict = torch.load('metrics/lpips_weights.ckpt', 67 | map_location=torch.device('cpu')) 68 | for name, param in state_dict.items(): 69 | if name in own_state_dict: 70 | own_state_dict[name].copy_(param) 71 | 72 | def forward(self, x, y): 73 | x = (x - self.mu) / self.sigma 74 | y = (y - self.mu) / self.sigma 75 | x_fmaps = self.alexnet(x) 76 | y_fmaps = self.alexnet(y) 77 | lpips_value = 0 78 | for x_fmap, y_fmap, conv1x1 in zip(x_fmaps, y_fmaps, self.lpips_weights): 79 | x_fmap = normalize(x_fmap) 80 | y_fmap = normalize(y_fmap) 81 | lpips_value += torch.mean(conv1x1((x_fmap - y_fmap)**2)) 82 | return lpips_value 83 | 84 | 85 | @torch.no_grad() 86 | def calculate_lpips_given_images(group_of_images): 87 | # group_of_images = [torch.randn(N, C, H, W) for _ in range(10)] 88 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 89 | lpips = LPIPS().eval().to(device) 90 | lpips_values = [] 91 | num_rand_outputs = len(group_of_images) 92 | 93 | # calculate the average of pairwise distances among all random outputs 94 | for i in range(num_rand_outputs-1): 95 | for j in range(i+1, num_rand_outputs): 96 | lpips_values.append(lpips(group_of_images[i], group_of_images[j])) 97 | lpips_value = torch.mean(torch.stack(lpips_values, dim=0)) 98 | return lpips_value.item() -------------------------------------------------------------------------------- /metrics/lpips_weights.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/metrics/lpips_weights.ckpt --------------------------------------------------------------------------------
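A minimal sketch of calling the two metric entry points above directly, rather than through the solver. It assumes the repository root as the working directory (so metrics/lpips_weights.ckpt resolves), a CUDA device (the LPIPS module pins its normalization constants to the GPU with .cuda()), and that the two directory paths, placeholders here, each contain images; torchvision will fetch the pretrained Inception and AlexNet weights on first use.

import torch

from metrics.fid import calculate_fid_given_paths
from metrics.lpips import calculate_lpips_given_images

# FID between a folder of real images and a folder of generated images
# (equivalently: python -m metrics.fid --paths PATH_REAL PATH_FAKE)
fid = calculate_fid_given_paths(
    paths=['data/celeba_hq/val/female', 'expr/eval/male2female'],  # placeholder paths
    img_size=256, batch_size=50)

# LPIPS diversity over several outputs produced from the same inputs;
# random NCHW tensors stand in for generated images here (real use passes
# images scaled to [-1, 1], as produced by the generator)
group_of_images = [torch.randn(4, 3, 256, 256).cuda() for _ in range(10)]
lpips_value = calculate_lpips_given_images(group_of_images)

print('FID: %.2f, LPIPS: %.3f' % (fid, lpips_value))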