├── .gitignore
├── LICENSE
├── NOTICE
├── README.md
├── assets
│   ├── afhq_dataset.jpg
│   ├── afhq_interpolation.gif
│   ├── afhqv2_teaser2.jpg
│   ├── celebahq_interpolation.gif
│   ├── representative
│   │   ├── afhq
│   │   │   ├── ref
│   │   │   │   ├── cat
│   │   │   │   │   ├── flickr_cat_000495.jpg
│   │   │   │   │   ├── flickr_cat_000557.jpg
│   │   │   │   │   ├── pixabay_cat_000355.jpg
│   │   │   │   │   ├── pixabay_cat_000491.jpg
│   │   │   │   │   ├── pixabay_cat_000535.jpg
│   │   │   │   │   ├── pixabay_cat_000623.jpg
│   │   │   │   │   ├── pixabay_cat_000730.jpg
│   │   │   │   │   ├── pixabay_cat_001479.jpg
│   │   │   │   │   ├── pixabay_cat_001699.jpg
│   │   │   │   │   └── pixabay_cat_003046.jpg
│   │   │   │   ├── dog
│   │   │   │   │   ├── flickr_dog_001072.jpg
│   │   │   │   │   ├── pixabay_dog_000121.jpg
│   │   │   │   │   ├── pixabay_dog_000322.jpg
│   │   │   │   │   ├── pixabay_dog_000357.jpg
│   │   │   │   │   ├── pixabay_dog_000409.jpg
│   │   │   │   │   ├── pixabay_dog_000799.jpg
│   │   │   │   │   ├── pixabay_dog_000890.jpg
│   │   │   │   │   └── pixabay_dog_001082.jpg
│   │   │   │   └── wild
│   │   │   │       ├── flickr_wild_000731.jpg
│   │   │   │       ├── flickr_wild_001223.jpg
│   │   │   │       ├── flickr_wild_002020.jpg
│   │   │   │       ├── flickr_wild_002092.jpg
│   │   │   │       ├── flickr_wild_002933.jpg
│   │   │   │       ├── flickr_wild_003137.jpg
│   │   │   │       ├── flickr_wild_003355.jpg
│   │   │   │       ├── flickr_wild_003796.jpg
│   │   │   │       ├── flickr_wild_003969.jpg
│   │   │   │       └── pixabay_wild_000637.jpg
│   │   │   └── src
│   │   │       ├── cat
│   │   │       │   ├── flickr_cat_000253.jpg
│   │   │       │   ├── pixabay_cat_000181.jpg
│   │   │       │   ├── pixabay_cat_000241.jpg
│   │   │       │   ├── pixabay_cat_000276.jpg
│   │   │       │   └── pixabay_cat_004826.jpg
│   │   │       ├── dog
│   │   │       │   ├── flickr_dog_000094.jpg
│   │   │       │   ├── pixabay_dog_000321.jpg
│   │   │       │   ├── pixabay_dog_000322.jpg
│   │   │       │   ├── pixabay_dog_001082.jpg
│   │   │       │   └── pixabay_dog_002066.jpg
│   │   │       └── wild
│   │   │           ├── flickr_wild_000432.jpg
│   │   │           ├── flickr_wild_000814.jpg
│   │   │           ├── flickr_wild_002036.jpg
│   │   │           ├── flickr_wild_002159.jpg
│   │   │           └── pixabay_wild_000558.jpg
│   │   ├── celeba_hq
│   │   │   ├── ref
│   │   │   │   ├── female
│   │   │   │   │   ├── 015248.jpg
│   │   │   │   │   ├── 030321.jpg
│   │   │   │   │   ├── 031796.jpg
│   │   │   │   │   ├── 036619.jpg
│   │   │   │   │   ├── 042373.jpg
│   │   │   │   │   ├── 048197.jpg
│   │   │   │   │   ├── 052599.jpg
│   │   │   │   │   ├── 058150.jpg
│   │   │   │   │   ├── 058225.jpg
│   │   │   │   │   ├── 058881.jpg
│   │   │   │   │   ├── 063109.jpg
│   │   │   │   │   ├── 064119.jpg
│   │   │   │   │   ├── 064307.jpg
│   │   │   │   │   ├── 074075.jpg
│   │   │   │   │   ├── 074934.jpg
│   │   │   │   │   ├── 076551.jpg
│   │   │   │   │   ├── 081680.jpg
│   │   │   │   │   ├── 081871.jpg
│   │   │   │   │   ├── 084913.jpg
│   │   │   │   │   ├── 086986.jpg
│   │   │   │   │   ├── 113393.jpg
│   │   │   │   │   ├── 135626.jpg
│   │   │   │   │   ├── 140613.jpg
│   │   │   │   │   ├── 142595.jpg
│   │   │   │   │   └── 195650.jpg
│   │   │   │   └── male
│   │   │   │       ├── 012712.jpg
│   │   │   │       ├── 020167.jpg
│   │   │   │       ├── 021612.jpg
│   │   │   │       ├── 036367.jpg
│   │   │   │       ├── 037023.jpg
│   │   │   │       ├── 038919.jpg
│   │   │   │       ├── 047763.jpg
│   │   │   │       ├── 060259.jpg
│   │   │   │       ├── 067791.jpg
│   │   │   │       ├── 077921.jpg
│   │   │   │       ├── 083510.jpg
│   │   │   │       ├── 094805.jpg
│   │   │   │       ├── 116032.jpg
│   │   │   │       ├── 118017.jpg
│   │   │   │       ├── 137590.jpg
│   │   │   │       ├── 145842.jpg
│   │   │   │       ├── 153793.jpg
│   │   │   │       ├── 156498.jpg
│   │   │   │       ├── 164930.jpg
│   │   │   │       ├── 189498.jpg
│   │   │   │       └── 191084.jpg
│   │   │   └── src
│   │   │       ├── female
│   │   │       │   ├── 039913.jpg
│   │   │       │   ├── 051340.jpg
│   │   │       │   ├── 069067.jpg
│   │   │       │   ├── 091623.jpg
│   │   │       │   └── 172559.jpg
│   │   │       └── male
│   │   │           ├── 005735.jpg
│   │   │           ├── 006930.jpg
│   │   │           ├── 016387.jpg
│   │   │           ├── 191300.jpg
│   │   │           └── 196930.jpg
│   │   └── custom
│   │       ├── female
│   │       │   └── custom_female.jpg
│   │       └── male
│   │           └── custom_male.jpg
│   ├── teaser.jpg
│   └── youtube_video.jpg
├── core
│   ├── __init__.py
│   ├── checkpoint.py
│   ├── data_loader.py
│   ├── model.py
│   ├── solver.py
│   ├── utils.py
│   └── wing.py
├── download.sh
├── main.py
└── metrics
    ├── __init__.py
    ├── eval.py
    ├── fid.py
    ├── lpips.py
    └── lpips_weights.ckpt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020-present, NAVER Corp.
2 | All rights reserved.
3 |
4 |
5 | Attribution-NonCommercial 4.0 International
6 |
7 | =======================================================================
8 |
9 | Creative Commons Corporation ("Creative Commons") is not a law firm and
10 | does not provide legal services or legal advice. Distribution of
11 | Creative Commons public licenses does not create a lawyer-client or
12 | other relationship. Creative Commons makes its licenses and related
13 | information available on an "as-is" basis. Creative Commons gives no
14 | warranties regarding its licenses, any material licensed under their
15 | terms and conditions, or any related information. Creative Commons
16 | disclaims all liability for damages resulting from their use to the
17 | fullest extent possible.
18 |
19 | Using Creative Commons Public Licenses
20 |
21 | Creative Commons public licenses provide a standard set of terms and
22 | conditions that creators and other rights holders may use to share
23 | original works of authorship and other material subject to copyright
24 | and certain other rights specified in the public license below. The
25 | following considerations are for informational purposes only, are not
26 | exhaustive, and do not form part of our licenses.
27 |
28 | Considerations for licensors: Our public licenses are
29 | intended for use by those authorized to give the public
30 | permission to use material in ways otherwise restricted by
31 | copyright and certain other rights. Our licenses are
32 | irrevocable. Licensors should read and understand the terms
33 | and conditions of the license they choose before applying it.
34 | Licensors should also secure all rights necessary before
35 | applying our licenses so that the public can reuse the
36 | material as expected. Licensors should clearly mark any
37 | material not subject to the license. This includes other CC-
38 | licensed material, or material used under an exception or
39 | limitation to copyright. More considerations for licensors:
40 | wiki.creativecommons.org/Considerations_for_licensors
41 |
42 | Considerations for the public: By using one of our public
43 | licenses, a licensor grants the public permission to use the
44 | licensed material under specified terms and conditions. If
45 | the licensor's permission is not necessary for any reason--for
46 | example, because of any applicable exception or limitation to
47 | copyright--then that use is not regulated by the license. Our
48 | licenses grant only permissions under copyright and certain
49 | other rights that a licensor has authority to grant. Use of
50 | the licensed material may still be restricted for other
51 | reasons, including because others have copyright or other
52 | rights in the material. A licensor may make special requests,
53 | such as asking that all changes be marked or described.
54 | Although not required by our licenses, you are encouraged to
55 | respect those requests where reasonable. More_considerations
56 | for the public:
57 | wiki.creativecommons.org/Considerations_for_licensees
58 |
59 | =======================================================================
60 |
61 | Creative Commons Attribution-NonCommercial 4.0 International Public
62 | License
63 |
64 | By exercising the Licensed Rights (defined below), You accept and agree
65 | to be bound by the terms and conditions of this Creative Commons
66 | Attribution-NonCommercial 4.0 International Public License ("Public
67 | License"). To the extent this Public License may be interpreted as a
68 | contract, You are granted the Licensed Rights in consideration of Your
69 | acceptance of these terms and conditions, and the Licensor grants You
70 | such rights in consideration of benefits the Licensor receives from
71 | making the Licensed Material available under these terms and
72 | conditions.
73 |
74 |
75 | Section 1 -- Definitions.
76 |
77 | a. Adapted Material means material subject to Copyright and Similar
78 | Rights that is derived from or based upon the Licensed Material
79 | and in which the Licensed Material is translated, altered,
80 | arranged, transformed, or otherwise modified in a manner requiring
81 | permission under the Copyright and Similar Rights held by the
82 | Licensor. For purposes of this Public License, where the Licensed
83 | Material is a musical work, performance, or sound recording,
84 | Adapted Material is always produced where the Licensed Material is
85 | synched in timed relation with a moving image.
86 |
87 | b. Adapter's License means the license You apply to Your Copyright
88 | and Similar Rights in Your contributions to Adapted Material in
89 | accordance with the terms and conditions of this Public License.
90 |
91 | c. Copyright and Similar Rights means copyright and/or similar rights
92 | closely related to copyright including, without limitation,
93 | performance, broadcast, sound recording, and Sui Generis Database
94 | Rights, without regard to how the rights are labeled or
95 | categorized. For purposes of this Public License, the rights
96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
97 | Rights.
98 | d. Effective Technological Measures means those measures that, in the
99 | absence of proper authority, may not be circumvented under laws
100 | fulfilling obligations under Article 11 of the WIPO Copyright
101 | Treaty adopted on December 20, 1996, and/or similar international
102 | agreements.
103 |
104 | e. Exceptions and Limitations means fair use, fair dealing, and/or
105 | any other exception or limitation to Copyright and Similar Rights
106 | that applies to Your use of the Licensed Material.
107 |
108 | f. Licensed Material means the artistic or literary work, database,
109 | or other material to which the Licensor applied this Public
110 | License.
111 |
112 | g. Licensed Rights means the rights granted to You subject to the
113 | terms and conditions of this Public License, which are limited to
114 | all Copyright and Similar Rights that apply to Your use of the
115 | Licensed Material and that the Licensor has authority to license.
116 |
117 | h. Licensor means the individual(s) or entity(ies) granting rights
118 | under this Public License.
119 |
120 | i. NonCommercial means not primarily intended for or directed towards
121 | commercial advantage or monetary compensation. For purposes of
122 | this Public License, the exchange of the Licensed Material for
123 | other material subject to Copyright and Similar Rights by digital
124 | file-sharing or similar means is NonCommercial provided there is
125 | no payment of monetary compensation in connection with the
126 | exchange.
127 |
128 | j. Share means to provide material to the public by any means or
129 | process that requires permission under the Licensed Rights, such
130 | as reproduction, public display, public performance, distribution,
131 | dissemination, communication, or importation, and to make material
132 | available to the public including in ways that members of the
133 | public may access the material from a place and at a time
134 | individually chosen by them.
135 |
136 | k. Sui Generis Database Rights means rights other than copyright
137 | resulting from Directive 96/9/EC of the European Parliament and of
138 | the Council of 11 March 1996 on the legal protection of databases,
139 | as amended and/or succeeded, as well as other essentially
140 | equivalent rights anywhere in the world.
141 |
142 | l. You means the individual or entity exercising the Licensed Rights
143 | under this Public License. Your has a corresponding meaning.
144 |
145 |
146 | Section 2 -- Scope.
147 |
148 | a. License grant.
149 |
150 | 1. Subject to the terms and conditions of this Public License,
151 | the Licensor hereby grants You a worldwide, royalty-free,
152 | non-sublicensable, non-exclusive, irrevocable license to
153 | exercise the Licensed Rights in the Licensed Material to:
154 |
155 | a. reproduce and Share the Licensed Material, in whole or
156 | in part, for NonCommercial purposes only; and
157 |
158 | b. produce, reproduce, and Share Adapted Material for
159 | NonCommercial purposes only.
160 |
161 | 2. Exceptions and Limitations. For the avoidance of doubt, where
162 | Exceptions and Limitations apply to Your use, this Public
163 | License does not apply, and You do not need to comply with
164 | its terms and conditions.
165 |
166 | 3. Term. The term of this Public License is specified in Section
167 | 6(a).
168 |
169 | 4. Media and formats; technical modifications allowed. The
170 | Licensor authorizes You to exercise the Licensed Rights in
171 | all media and formats whether now known or hereafter created,
172 | and to make technical modifications necessary to do so. The
173 | Licensor waives and/or agrees not to assert any right or
174 | authority to forbid You from making technical modifications
175 | necessary to exercise the Licensed Rights, including
176 | technical modifications necessary to circumvent Effective
177 | Technological Measures. For purposes of this Public License,
178 | simply making modifications authorized by this Section 2(a)
179 | (4) never produces Adapted Material.
180 |
181 | 5. Downstream recipients.
182 |
183 | a. Offer from the Licensor -- Licensed Material. Every
184 | recipient of the Licensed Material automatically
185 | receives an offer from the Licensor to exercise the
186 | Licensed Rights under the terms and conditions of this
187 | Public License.
188 |
189 | b. No downstream restrictions. You may not offer or impose
190 | any additional or different terms or conditions on, or
191 | apply any Effective Technological Measures to, the
192 | Licensed Material if doing so restricts exercise of the
193 | Licensed Rights by any recipient of the Licensed
194 | Material.
195 |
196 | 6. No endorsement. Nothing in this Public License constitutes or
197 | may be construed as permission to assert or imply that You
198 | are, or that Your use of the Licensed Material is, connected
199 | with, or sponsored, endorsed, or granted official status by,
200 | the Licensor or others designated to receive attribution as
201 | provided in Section 3(a)(1)(A)(i).
202 |
203 | b. Other rights.
204 |
205 | 1. Moral rights, such as the right of integrity, are not
206 | licensed under this Public License, nor are publicity,
207 | privacy, and/or other similar personality rights; however, to
208 | the extent possible, the Licensor waives and/or agrees not to
209 | assert any such rights held by the Licensor to the limited
210 | extent necessary to allow You to exercise the Licensed
211 | Rights, but not otherwise.
212 |
213 | 2. Patent and trademark rights are not licensed under this
214 | Public License.
215 |
216 | 3. To the extent possible, the Licensor waives any right to
217 | collect royalties from You for the exercise of the Licensed
218 | Rights, whether directly or through a collecting society
219 | under any voluntary or waivable statutory or compulsory
220 | licensing scheme. In all other cases the Licensor expressly
221 | reserves any right to collect such royalties, including when
222 | the Licensed Material is used other than for NonCommercial
223 | purposes.
224 |
225 |
226 | Section 3 -- License Conditions.
227 |
228 | Your exercise of the Licensed Rights is expressly made subject to the
229 | following conditions.
230 |
231 | a. Attribution.
232 |
233 | 1. If You Share the Licensed Material (including in modified
234 | form), You must:
235 |
236 | a. retain the following if it is supplied by the Licensor
237 | with the Licensed Material:
238 |
239 | i. identification of the creator(s) of the Licensed
240 | Material and any others designated to receive
241 | attribution, in any reasonable manner requested by
242 | the Licensor (including by pseudonym if
243 | designated);
244 |
245 | ii. a copyright notice;
246 |
247 | iii. a notice that refers to this Public License;
248 |
249 | iv. a notice that refers to the disclaimer of
250 | warranties;
251 |
252 | v. a URI or hyperlink to the Licensed Material to the
253 | extent reasonably practicable;
254 |
255 | b. indicate if You modified the Licensed Material and
256 | retain an indication of any previous modifications; and
257 |
258 | c. indicate the Licensed Material is licensed under this
259 | Public License, and include the text of, or the URI or
260 | hyperlink to, this Public License.
261 |
262 | 2. You may satisfy the conditions in Section 3(a)(1) in any
263 | reasonable manner based on the medium, means, and context in
264 | which You Share the Licensed Material. For example, it may be
265 | reasonable to satisfy the conditions by providing a URI or
266 | hyperlink to a resource that includes the required
267 | information.
268 |
269 | 3. If requested by the Licensor, You must remove any of the
270 | information required by Section 3(a)(1)(A) to the extent
271 | reasonably practicable.
272 |
273 | 4. If You Share Adapted Material You produce, the Adapter's
274 | License You apply must not prevent recipients of the Adapted
275 | Material from complying with this Public License.
276 |
277 |
278 | Section 4 -- Sui Generis Database Rights.
279 |
280 | Where the Licensed Rights include Sui Generis Database Rights that
281 | apply to Your use of the Licensed Material:
282 |
283 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
284 | to extract, reuse, reproduce, and Share all or a substantial
285 | portion of the contents of the database for NonCommercial purposes
286 | only;
287 |
288 | b. if You include all or a substantial portion of the database
289 | contents in a database in which You have Sui Generis Database
290 | Rights, then the database in which You have Sui Generis Database
291 | Rights (but not its individual contents) is Adapted Material; and
292 |
293 | c. You must comply with the conditions in Section 3(a) if You Share
294 | all or a substantial portion of the contents of the database.
295 |
296 | For the avoidance of doubt, this Section 4 supplements and does not
297 | replace Your obligations under this Public License where the Licensed
298 | Rights include other Copyright and Similar Rights.
299 |
300 |
301 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
302 |
303 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
304 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
305 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
306 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
307 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
308 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
309 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
310 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
311 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
312 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
313 |
314 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
315 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
316 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
317 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
318 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
319 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
320 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
321 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
322 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
323 |
324 | c. The disclaimer of warranties and limitation of liability provided
325 | above shall be interpreted in a manner that, to the extent
326 | possible, most closely approximates an absolute disclaimer and
327 | waiver of all liability.
328 |
329 |
330 | Section 6 -- Term and Termination.
331 |
332 | a. This Public License applies for the term of the Copyright and
333 | Similar Rights licensed here. However, if You fail to comply with
334 | this Public License, then Your rights under this Public License
335 | terminate automatically.
336 |
337 | b. Where Your right to use the Licensed Material has terminated under
338 | Section 6(a), it reinstates:
339 |
340 | 1. automatically as of the date the violation is cured, provided
341 | it is cured within 30 days of Your discovery of the
342 | violation; or
343 |
344 | 2. upon express reinstatement by the Licensor.
345 |
346 | For the avoidance of doubt, this Section 6(b) does not affect any
347 | right the Licensor may have to seek remedies for Your violations
348 | of this Public License.
349 |
350 | c. For the avoidance of doubt, the Licensor may also offer the
351 | Licensed Material under separate terms or conditions or stop
352 | distributing the Licensed Material at any time; however, doing so
353 | will not terminate this Public License.
354 |
355 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
356 | License.
357 |
358 |
359 | Section 7 -- Other Terms and Conditions.
360 |
361 | a. The Licensor shall not be bound by any additional or different
362 | terms or conditions communicated by You unless expressly agreed.
363 |
364 | b. Any arrangements, understandings, or agreements regarding the
365 | Licensed Material not stated herein are separate from and
366 | independent of the terms and conditions of this Public License.
367 |
368 |
369 | Section 8 -- Interpretation.
370 |
371 | a. For the avoidance of doubt, this Public License does not, and
372 | shall not be interpreted to, reduce, limit, restrict, or impose
373 | conditions on any use of the Licensed Material that could lawfully
374 | be made without permission under this Public License.
375 |
376 | b. To the extent possible, if any provision of this Public License is
377 | deemed unenforceable, it shall be automatically reformed to the
378 | minimum extent necessary to make it enforceable. If the provision
379 | cannot be reformed, it shall be severed from this Public License
380 | without affecting the enforceability of the remaining terms and
381 | conditions.
382 |
383 | c. No term or condition of this Public License will be waived and no
384 | failure to comply consented to unless expressly agreed to by the
385 | Licensor.
386 |
387 | d. Nothing in this Public License constitutes or may be interpreted
388 | as a limitation upon, or waiver of, any privileges and immunities
389 | that apply to the Licensor or You, including from the legal
390 | processes of any jurisdiction or authority.
391 |
392 | =======================================================================
393 |
394 | Creative Commons is not a party to its public
395 | licenses. Notwithstanding, Creative Commons may elect to apply one of
396 | its public licenses to material it publishes and in those instances
397 | will be considered the "Licensor." The text of the Creative Commons
398 | public licenses is dedicated to the public domain under the CC0 Public
399 | Domain Dedication. Except for the limited purpose of indicating that
400 | material is shared under a Creative Commons public license or as
401 | otherwise permitted by the Creative Commons policies published at
402 | creativecommons.org/policies, Creative Commons does not authorize the
403 | use of the trademark "Creative Commons" or any other trademark or logo
404 | of Creative Commons without its prior written consent including,
405 | without limitation, in connection with any unauthorized modifications
406 | to any of its public licenses or any other arrangements,
407 | understandings, or agreements concerning use of licensed material. For
408 | the avoidance of doubt, this paragraph does not form part of the
409 | public licenses.
410 |
411 | Creative Commons may be contacted at creativecommons.org.
412 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | StarGAN v2
2 |
3 | Copyright (c) 2020-present NAVER Corp.
4 | All rights reserved.
5 |
6 | This work is licensed under the Creative Commons Attribution-NonCommercial
7 | 4.0 International License. To view a copy of this license, visit
8 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
9 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
10 |
11 | --------------------------------------------------------------------------------------
12 |
13 | This project contains subcomponents with separate copyright notices and license terms.
14 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
15 |
16 | =====
17 |
18 | 1adrianb/face-alignment
19 | https://github.com/1adrianb/face-alignment
20 |
21 |
22 | BSD 3-Clause License
23 |
24 | Copyright (c) 2017, Adrian Bulat
25 | All rights reserved.
26 |
27 | Redistribution and use in source and binary forms, with or without
28 | modification, are permitted provided that the following conditions are met:
29 |
30 | * Redistributions of source code must retain the above copyright notice, this
31 | list of conditions and the following disclaimer.
32 |
33 | * Redistributions in binary form must reproduce the above copyright notice,
34 | this list of conditions and the following disclaimer in the documentation
35 | and/or other materials provided with the distribution.
36 |
37 | * Neither the name of the copyright holder nor the names of its
38 | contributors may be used to endorse or promote products derived from
39 | this software without specific prior written permission.
40 |
41 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
42 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
43 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
44 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
45 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
46 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
47 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
48 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
49 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
50 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
51 |
52 | =====
53 |
54 | protossw512/AdaptiveWingLoss
55 | https://github.com/protossw512/AdaptiveWingLoss
56 |
57 |
58 | [ICCV 2019] Adaptive Wing Loss for Robust Face Alignment via Heatmap Regression - Official Implementation
59 |
60 |
61 | Licensed under the Apache License, Version 2.0 (the "License");
62 | you may not use this file except in compliance with the License.
63 | You may obtain a copy of the License at
64 |
65 | http://www.apache.org/licenses/LICENSE-2.0
66 |
67 | Unless required by applicable law or agreed to in writing, software
68 | distributed under the License is distributed on an "AS IS" BASIS,
69 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70 | See the License for the specific language governing permissions and
71 | limitations under the License.
72 |
73 | ---
74 |
75 | author = {Wang, Xinyao and Bo, Liefeng and Fuxin, Li},
76 | title = {Adaptive Wing Loss for Robust Face Alignment via Heatmap Regression},
77 | booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
78 | month = {October},
79 | year = {2019}
80 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## StarGAN v2 - Official PyTorch Implementation
3 |
4 | ![StarGAN v2 teaser](assets/teaser.jpg)
5 |
6 | > **StarGAN v2: Diverse Image Synthesis for Multiple Domains**
7 | > [Yunjey Choi](https://github.com/yunjey)\*, [Youngjung Uh](https://github.com/youngjung)\*, [Jaejun Yoo](http://jaejunyoo.blogspot.com/search/label/kr)\*, [Jung-Woo Ha](https://www.facebook.com/jungwoo.ha.921)
8 | > In CVPR 2020. (* indicates equal contribution)
9 |
10 | > Paper: https://arxiv.org/abs/1912.01865
11 | > Video: https://youtu.be/0EVh5Ki4dIY
12 |
13 | > **Abstract:** *A good image-to-image translation model should learn a mapping between different visual domains while satisfying the following properties: 1) diversity of generated images and 2) scalability over multiple domains. Existing methods address either of the issues, having limited diversity or multiple models for all domains. We propose StarGAN v2, a single framework that tackles both and shows significantly improved results over the baselines. Experiments on CelebA-HQ and a new animal faces dataset (AFHQ) validate our superiority in terms of visual quality, diversity, and scalability. To better assess image-to-image translation models, we release AFHQ, high-quality animal faces with large inter- and intra-domain variations. The code, pre-trained models, and dataset are available at clovaai/stargan-v2.*
14 |
15 | ## Teaser video
16 | Click the figure to watch the teaser video.
17 |
18 | [![Teaser video](assets/youtube_video.jpg)](https://youtu.be/0EVh5Ki4dIY)
19 |
20 | ## TensorFlow implementation
21 | The TensorFlow implementation of StarGAN v2 by our team member Junho can be found at [clovaai/stargan-v2-tensorflow](https://github.com/clovaai/stargan-v2-tensorflow).
22 |
23 | ## Software installation
24 | Clone this repository:
25 |
26 | ```bash
27 | git clone https://github.com/clovaai/stargan-v2.git
28 | cd stargan-v2/
29 | ```
30 |
31 | Install the dependencies:
32 | ```bash
33 | conda create -n stargan-v2 python=3.6.7
34 | conda activate stargan-v2
35 | conda install -y pytorch=1.4.0 torchvision=0.5.0 cudatoolkit=10.0 -c pytorch
36 | conda install x264=='1!152.20180717' ffmpeg=4.0.2 -c conda-forge
37 | pip install opencv-python==4.1.2.30 ffmpeg-python==0.2.0 scikit-image==0.16.2
38 | pip install pillow==7.0.0 scipy==1.2.1 tqdm==4.43.0 munch==2.5.0
39 | ```
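
To verify the environment, you can optionally check that PyTorch is installed and can see a GPU (a minimal sanity check, not part of the original instructions):

```bash
# Optional sanity check: print the installed PyTorch version and whether CUDA is available.
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```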
40 |
41 | ## Datasets and pre-trained networks
42 | We provide a script to download datasets used in StarGAN v2 and the corresponding pre-trained networks. The datasets and network checkpoints will be downloaded and stored in the `data` and `expr/checkpoints` directories, respectively.
43 |
44 | CelebA-HQ. To download the [CelebA-HQ](https://drive.google.com/drive/folders/0B4qLcYyJmiz0TXY1NG02bzZVRGs) dataset and the pre-trained network, run the following commands:
45 | ```bash
46 | bash download.sh celeba-hq-dataset
47 | bash download.sh pretrained-network-celeba-hq
48 | bash download.sh wing
49 | ```
50 |
51 | AFHQ. To download the [AFHQ](https://github.com/clovaai/stargan-v2/blob/master/README.md#animal-faces-hq-dataset-afhq) dataset and the pre-trained network, run the following commands:
52 | ```bash
53 | bash download.sh afhq-dataset
54 | bash download.sh pretrained-network-afhq
55 | ```
56 |
57 |
58 | ## Generating interpolation videos
59 | After downloading the pre-trained networks, you can synthesize output images reflecting diverse styles (e.g., hairstyle) of reference images. The following commands will save generated images and interpolation videos to the `expr/results` directory.
60 |
61 |
62 | CelebA-HQ. To generate images and interpolation videos, run the following command:
63 | ```bash
64 | python main.py --mode sample --num_domains 2 --resume_iter 100000 --w_hpf 1 \
65 | --checkpoint_dir expr/checkpoints/celeba_hq \
66 | --result_dir expr/results/celeba_hq \
67 | --src_dir assets/representative/celeba_hq/src \
68 | --ref_dir assets/representative/celeba_hq/ref
69 | ```
70 |
71 | To transform a custom image, first crop the image manually so that the proportion of the face within the whole image is similar to that of CelebA-HQ. Then, run the following command for additional fine rotation and cropping. All custom images in the `inp_dir` directory will be aligned and stored in the `out_dir` directory.
72 |
73 | ```bash
74 | python main.py --mode align \
75 | --inp_dir assets/representative/custom/female \
76 | --out_dir assets/representative/celeba_hq/src/female
77 | ```
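
The same alignment command can be applied to the provided male custom image (directories taken from this repository's assets folder):

```bash
# Align the male custom image as well; the output goes to the male source directory.
python main.py --mode align \
               --inp_dir assets/representative/custom/male \
               --out_dir assets/representative/celeba_hq/src/male
```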
78 |
79 |
80 | ![CelebA-HQ interpolation](assets/celebahq_interpolation.gif)
81 |
82 |
83 | AFHQ. To generate images and interpolation videos, run the following command:
84 | ```bash
85 | python main.py --mode sample --num_domains 3 --resume_iter 100000 --w_hpf 0 \
86 | --checkpoint_dir expr/checkpoints/afhq \
87 | --result_dir expr/results/afhq \
88 | --src_dir assets/representative/afhq/src \
89 | --ref_dir assets/representative/afhq/ref
90 | ```
91 |
92 | ![AFHQ interpolation](assets/afhq_interpolation.gif)
93 |
94 | ## Evaluation metrics
95 | To evaluate StarGAN v2 using [Fréchet Inception Distance (FID)](https://arxiv.org/abs/1706.08500) and [Learned Perceptual Image Patch Similarity (LPIPS)](https://arxiv.org/abs/1801.03924), run the following commands:
96 |
97 |
98 | ```bash
99 | # celeba-hq
100 | python main.py --mode eval --num_domains 2 --w_hpf 1 \
101 | --resume_iter 100000 \
102 | --train_img_dir data/celeba_hq/train \
103 | --val_img_dir data/celeba_hq/val \
104 | --checkpoint_dir expr/checkpoints/celeba_hq \
105 | --eval_dir expr/eval/celeba_hq
106 |
107 | # afhq
108 | python main.py --mode eval --num_domains 3 --w_hpf 0 \
109 | --resume_iter 100000 \
110 | --train_img_dir data/afhq/train \
111 | --val_img_dir data/afhq/val \
112 | --checkpoint_dir expr/checkpoints/afhq \
113 | --eval_dir expr/eval/afhq
114 | ```
115 |
116 | Note that the evaluation metrics are calculated using random latent vectors or reference images, both of which are selected by the [seed number](https://github.com/clovaai/stargan-v2/blob/master/main.py#L35). In the paper, we reported the average of values from 10 measurements using different seed numbers. The following table shows the calculated values for both latent-guided and reference-guided synthesis.
117 |
118 | | Dataset | FID (latent) | LPIPS (latent) | FID (reference) | LPIPS (reference) | Elapsed time |
119 | | :---------- | :------------: | :------------: | :-------------: | :---------------: | :----------: |
120 | | `celeba-hq` | 13.73 ± 0.06 | 0.4515 ± 0.0006 | 23.84 ± 0.03 | 0.3880 ± 0.0001 | 49min 51s |
121 | | `afhq` | 16.18 ± 0.15 | 0.4501 ± 0.0007 | 19.78 ± 0.01 | 0.4315 ± 0.0002 | 64min 49s |
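
To reproduce the multi-seed averaging described above, a minimal sketch is shown below. It assumes `main.py` exposes a `--seed` argument (as the seed-number link above suggests); the flag name and output layout may differ in your checkout.

```bash
# Hedged sketch: run the CelebA-HQ evaluation with ten different seeds,
# writing each run to its own eval directory, then average the reported
# FID/LPIPS values manually.
for seed in $(seq 0 9); do
    python main.py --mode eval --num_domains 2 --w_hpf 1 \
                   --resume_iter 100000 \
                   --seed $seed \
                   --train_img_dir data/celeba_hq/train \
                   --val_img_dir data/celeba_hq/val \
                   --checkpoint_dir expr/checkpoints/celeba_hq \
                   --eval_dir expr/eval/celeba_hq_seed${seed}
done
```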
122 |
123 |
124 |
125 | ## Training networks
126 | To train StarGAN v2 from scratch, run the following commands. Generated images and network checkpoints will be stored in the `expr/samples` and `expr/checkpoints` directories, respectively. Training takes about three days on a single Tesla V100 GPU. Please see [here](https://github.com/clovaai/stargan-v2/blob/master/main.py#L86-L179) for training arguments and a description of them.
127 |
128 | ```bash
129 | # celeba-hq
130 | python main.py --mode train --num_domains 2 --w_hpf 1 \
131 | --lambda_reg 1 --lambda_sty 1 --lambda_ds 1 --lambda_cyc 1 \
132 | --train_img_dir data/celeba_hq/train \
133 | --val_img_dir data/celeba_hq/val
134 |
135 | # afhq
136 | python main.py --mode train --num_domains 3 --w_hpf 0 \
137 | --lambda_reg 1 --lambda_sty 1 --lambda_ds 2 --lambda_cyc 1 \
138 | --train_img_dir data/afhq/train \
139 | --val_img_dir data/afhq/val
140 | ```
141 |
142 | ## Animal Faces-HQ dataset (AFHQ)
143 |
144 | ![AFHQ dataset examples](assets/afhq_dataset.jpg)
145 |
146 | We release a new dataset of animal faces, Animal Faces-HQ (AFHQ), consisting of 15,000 high-quality images at 512×512 resolution. The figure above shows example images of the AFHQ dataset. The dataset includes three domains (cat, dog, and wildlife), each providing about 5,000 images. With three domains and diverse images of various breeds in each domain, AFHQ poses a challenging image-to-image translation problem. For each domain, we select 500 images as a test set and provide all remaining images as a training set. To download the dataset, run the following command:
147 |
148 | ```bash
149 | bash download.sh afhq-dataset
150 | ```
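
After the download completes, you can optionally confirm the per-domain split sizes (a minimal sketch, assuming `download.sh` extracts the dataset to `data/afhq`, the path used by the commands above):

```bash
# Hedged sketch: count images per split and domain; each val split should
# contain roughly 500 images.
for split in train val; do
    for domain in cat dog wild; do
        echo "$split/$domain: $(find data/afhq/$split/$domain -type f | wc -l) images"
    done
done
```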
151 |
152 |
153 | **[Update: 2021.07.01]** We have rebuilt the original AFHQ dataset using high-quality resize filtering (i.e., Lanczos resampling). Please see the [clean FID paper](https://arxiv.org/abs/2104.11222), which draws attention to the shortcomings of the downsampling implementations in common software libraries. We thank the [Alias-Free GAN](https://nvlabs.github.io/alias-free-gan/) authors for their suggestion and contribution to the updated AFHQ dataset. If you use the updated dataset, we recommend citing not only our paper but also theirs.
154 |
155 | The differences from the original dataset are as follows:
156 | * We resize the images using Lanczos resampling instead of nearest neighbor downsampling.
157 | * About 2% of the original images have been removed, so the updated dataset now contains 15,803 images, whereas the original had 16,130.
158 | * Images are saved in PNG format to avoid compression artifacts. This makes the files larger than the originals, but we believe the quality gain is worth it.
159 |
160 |
161 | To download the updated dataset, run the following command:
162 |
163 | ```bash
164 | bash download.sh afhq-v2-dataset
165 | ```
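
To check that the updated dataset downloaded completely, you can count its images (a minimal sketch; the extraction path `data/afhq_v2` is an assumption and may differ depending on `download.sh`):

```bash
# Hedged sketch: the updated AFHQ (v2) release should contain 15,803 PNG images in total.
find data/afhq_v2 -type f -name '*.png' | wc -l
```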
166 |
167 | ![AFHQ v2 teaser](assets/afhqv2_teaser2.jpg)
168 |
169 |
170 |
171 | ## License
172 | The source code, pre-trained models, and dataset are available under the [Creative Commons BY-NC 4.0](https://github.com/clovaai/stargan-v2/blob/master/LICENSE) license by NAVER Corporation. You can **use, copy, transform, and build upon** the material for **non-commercial purposes** as long as you give **appropriate credit** by citing our paper and indicate if changes were made.
173 |
174 | For business inquiries, please contact clova-jobs@navercorp.com.
175 | For technical and other inquiries, please contact yunjey.choi@navercorp.com.
176 |
177 |
178 | ## Citation
179 | If you find this work useful for your research, please cite our paper:
180 |
181 | ```
182 | @inproceedings{choi2020starganv2,
183 | title={StarGAN v2: Diverse Image Synthesis for Multiple Domains},
184 | author={Yunjey Choi and Youngjung Uh and Jaejun Yoo and Jung-Woo Ha},
185 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
186 | year={2020}
187 | }
188 | ```
189 |
190 | ## Acknowledgements
191 | We would like to thank the full-time and visiting Clova AI Research (now NAVER AI Lab) members for their valuable feedback and an early review: especially Seongjoon Oh, Junsuk Choe, Muhammad Ferjad Naeem, and Kyungjune Baek. We also thank Alias-Free GAN authors for their contribution to the updated AFHQ dataset.
192 |
--------------------------------------------------------------------------------
/assets/afhq_dataset.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/afhq_dataset.jpg
--------------------------------------------------------------------------------
/assets/afhq_interpolation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/afhq_interpolation.gif
--------------------------------------------------------------------------------
/assets/afhqv2_teaser2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/afhqv2_teaser2.jpg
--------------------------------------------------------------------------------
/assets/celebahq_interpolation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/celebahq_interpolation.gif
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/cat/flickr_cat_000495.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/flickr_cat_000495.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/cat/flickr_cat_000557.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/flickr_cat_000557.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/cat/pixabay_cat_000355.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_000355.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/cat/pixabay_cat_000491.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_000491.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/cat/pixabay_cat_000535.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_000535.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/cat/pixabay_cat_000623.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_000623.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/cat/pixabay_cat_000730.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_000730.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/cat/pixabay_cat_001479.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_001479.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/cat/pixabay_cat_001699.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_001699.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/cat/pixabay_cat_003046.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/cat/pixabay_cat_003046.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/dog/flickr_dog_001072.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/flickr_dog_001072.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/dog/pixabay_dog_000121.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/pixabay_dog_000121.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/dog/pixabay_dog_000322.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/pixabay_dog_000322.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/dog/pixabay_dog_000357.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/pixabay_dog_000357.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/dog/pixabay_dog_000409.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/pixabay_dog_000409.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/dog/pixabay_dog_000799.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/pixabay_dog_000799.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/dog/pixabay_dog_000890.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/pixabay_dog_000890.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/dog/pixabay_dog_001082.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/dog/pixabay_dog_001082.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/wild/flickr_wild_000731.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_000731.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/wild/flickr_wild_001223.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_001223.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/wild/flickr_wild_002020.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_002020.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/wild/flickr_wild_002092.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_002092.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/wild/flickr_wild_002933.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_002933.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/wild/flickr_wild_003137.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_003137.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/wild/flickr_wild_003355.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_003355.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/wild/flickr_wild_003796.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_003796.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/wild/flickr_wild_003969.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/flickr_wild_003969.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/ref/wild/pixabay_wild_000637.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/ref/wild/pixabay_wild_000637.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/src/cat/flickr_cat_000253.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/cat/flickr_cat_000253.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/src/cat/pixabay_cat_000181.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/cat/pixabay_cat_000181.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/src/cat/pixabay_cat_000241.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/cat/pixabay_cat_000241.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/src/cat/pixabay_cat_000276.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/cat/pixabay_cat_000276.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/src/cat/pixabay_cat_004826.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/cat/pixabay_cat_004826.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/src/dog/flickr_dog_000094.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/dog/flickr_dog_000094.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/src/dog/pixabay_dog_000321.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/dog/pixabay_dog_000321.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/src/dog/pixabay_dog_000322.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/dog/pixabay_dog_000322.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/src/dog/pixabay_dog_001082.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/dog/pixabay_dog_001082.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/src/dog/pixabay_dog_002066.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/dog/pixabay_dog_002066.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/src/wild/flickr_wild_000432.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/wild/flickr_wild_000432.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/src/wild/flickr_wild_000814.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/wild/flickr_wild_000814.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/src/wild/flickr_wild_002036.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/wild/flickr_wild_002036.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/src/wild/flickr_wild_002159.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/wild/flickr_wild_002159.jpg
--------------------------------------------------------------------------------
/assets/representative/afhq/src/wild/pixabay_wild_000558.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/afhq/src/wild/pixabay_wild_000558.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/015248.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/015248.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/030321.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/030321.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/031796.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/031796.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/036619.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/036619.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/042373.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/042373.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/048197.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/048197.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/052599.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/052599.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/058150.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/058150.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/058225.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/058225.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/058881.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/058881.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/063109.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/063109.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/064119.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/064119.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/064307.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/064307.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/074075.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/074075.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/074934.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/074934.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/076551.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/076551.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/081680.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/081680.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/081871.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/081871.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/084913.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/084913.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/086986.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/086986.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/113393.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/113393.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/135626.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/135626.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/140613.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/140613.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/142595.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/142595.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/female/195650.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/female/195650.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/012712.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/012712.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/020167.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/020167.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/021612.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/021612.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/036367.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/036367.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/037023.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/037023.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/038919.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/038919.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/047763.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/047763.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/060259.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/060259.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/067791.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/067791.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/077921.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/077921.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/083510.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/083510.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/094805.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/094805.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/116032.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/116032.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/118017.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/118017.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/137590.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/137590.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/145842.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/145842.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/153793.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/153793.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/156498.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/156498.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/164930.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/164930.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/189498.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/189498.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/ref/male/191084.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/ref/male/191084.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/src/female/039913.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/female/039913.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/src/female/051340.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/female/051340.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/src/female/069067.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/female/069067.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/src/female/091623.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/female/091623.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/src/female/172559.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/female/172559.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/src/male/005735.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/male/005735.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/src/male/006930.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/male/006930.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/src/male/016387.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/male/016387.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/src/male/191300.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/male/191300.jpg
--------------------------------------------------------------------------------
/assets/representative/celeba_hq/src/male/196930.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/celeba_hq/src/male/196930.jpg
--------------------------------------------------------------------------------
/assets/representative/custom/female/custom_female.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/custom/female/custom_female.jpg
--------------------------------------------------------------------------------
/assets/representative/custom/male/custom_male.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/representative/custom/male/custom_male.jpg
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/teaser.jpg
--------------------------------------------------------------------------------
/assets/youtube_video.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/assets/youtube_video.jpg
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/core/__init__.py
--------------------------------------------------------------------------------
/core/checkpoint.py:
--------------------------------------------------------------------------------
1 | """
2 | StarGAN v2
3 | Copyright (c) 2020-present NAVER Corp.
4 |
5 | This work is licensed under the Creative Commons Attribution-NonCommercial
6 | 4.0 International License. To view a copy of this license, visit
7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
9 | """
10 |
11 | import os
12 | import torch
13 |
14 |
15 | class CheckpointIO(object):
16 | def __init__(self, fname_template, data_parallel=False, **kwargs):
17 | os.makedirs(os.path.dirname(fname_template), exist_ok=True)
18 | self.fname_template = fname_template
19 | self.module_dict = kwargs
20 | self.data_parallel = data_parallel
21 |
22 | def register(self, **kwargs):
23 | self.module_dict.update(kwargs)
24 |
25 | def save(self, step):
26 | fname = self.fname_template.format(step)
27 | print('Saving checkpoint into %s...' % fname)
28 | outdict = {}
29 | for name, module in self.module_dict.items():
30 | if self.data_parallel:
31 | outdict[name] = module.module.state_dict()
32 | else:
33 | outdict[name] = module.state_dict()
34 |
35 | torch.save(outdict, fname)
36 |
37 | def load(self, step):
38 | fname = self.fname_template.format(step)
39 | assert os.path.exists(fname), fname + ' does not exist!'
40 | print('Loading checkpoint from %s...' % fname)
41 | if torch.cuda.is_available():
42 | module_dict = torch.load(fname)
43 | else:
44 | module_dict = torch.load(fname, map_location=torch.device('cpu'))
45 |
46 | for name, module in self.module_dict.items():
47 | if self.data_parallel:
48 | module.module.load_state_dict(module_dict[name])
49 | else:
50 | module.load_state_dict(module_dict[name])
51 |
--------------------------------------------------------------------------------
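
For orientation, a minimal usage sketch of the CheckpointIO class above. The checkpoint template, the registered module, and the step value are illustrative placeholders, not values taken from this repository.

# Illustrative sketch only -- the path template, module, and step are made up.
import torch
import torch.nn as nn
from core.checkpoint import CheckpointIO

net = nn.Linear(4, 2)                                        # any nn.Module can be registered
ckptio = CheckpointIO('expr/checkpoints/{:06d}_nets.ckpt',   # '{:06d}' is filled with the step
                      data_parallel=False,
                      net=net)                               # stored under the key 'net'

ckptio.save(100000)   # writes expr/checkpoints/100000_nets.ckpt
ckptio.load(100000)   # restores net's state_dict from the same file

With data_parallel=True, CheckpointIO saves and loads module.module.state_dict() instead, matching how build_model later wraps every network in nn.DataParallel.
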
/core/data_loader.py:
--------------------------------------------------------------------------------
1 | """
2 | StarGAN v2
3 | Copyright (c) 2020-present NAVER Corp.
4 |
5 | This work is licensed under the Creative Commons Attribution-NonCommercial
6 | 4.0 International License. To view a copy of this license, visit
7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
9 | """
10 |
11 | from pathlib import Path
12 | from itertools import chain
13 | import os
14 | import random
15 |
16 | from munch import Munch
17 | from PIL import Image
18 | import numpy as np
19 |
20 | import torch
21 | from torch.utils import data
22 | from torch.utils.data.sampler import WeightedRandomSampler
23 | from torchvision import transforms
24 | from torchvision.datasets import ImageFolder
25 |
26 |
27 | def listdir(dname):
28 | fnames = list(chain(*[list(Path(dname).rglob('*.' + ext))
29 | for ext in ['png', 'jpg', 'jpeg', 'JPG']]))
30 | return fnames
31 |
32 |
33 | class DefaultDataset(data.Dataset):
34 | def __init__(self, root, transform=None):
35 | self.samples = listdir(root)
36 | self.samples.sort()
37 | self.transform = transform
38 | self.targets = None
39 |
40 | def __getitem__(self, index):
41 | fname = self.samples[index]
42 | img = Image.open(fname).convert('RGB')
43 | if self.transform is not None:
44 | img = self.transform(img)
45 | return img
46 |
47 | def __len__(self):
48 | return len(self.samples)
49 |
50 |
51 | class ReferenceDataset(data.Dataset):
52 | def __init__(self, root, transform=None):
53 | self.samples, self.targets = self._make_dataset(root)
54 | self.transform = transform
55 |
56 | def _make_dataset(self, root):
57 | domains = os.listdir(root)
58 | fnames, fnames2, labels = [], [], []
59 | for idx, domain in enumerate(sorted(domains)):
60 | class_dir = os.path.join(root, domain)
61 | cls_fnames = listdir(class_dir)
62 | fnames += cls_fnames
63 | fnames2 += random.sample(cls_fnames, len(cls_fnames))
64 | labels += [idx] * len(cls_fnames)
65 | return list(zip(fnames, fnames2)), labels
66 |
67 | def __getitem__(self, index):
68 | fname, fname2 = self.samples[index]
69 | label = self.targets[index]
70 | img = Image.open(fname).convert('RGB')
71 | img2 = Image.open(fname2).convert('RGB')
72 | if self.transform is not None:
73 | img = self.transform(img)
74 | img2 = self.transform(img2)
75 | return img, img2, label
76 |
77 | def __len__(self):
78 | return len(self.targets)
79 |
80 |
81 | def _make_balanced_sampler(labels):
82 | class_counts = np.bincount(labels)
83 | class_weights = 1. / class_counts
84 | weights = class_weights[labels]
85 | return WeightedRandomSampler(weights, len(weights))
86 |
87 |
88 | def get_train_loader(root, which='source', img_size=256,
89 | batch_size=8, prob=0.5, num_workers=4):
90 | print('Preparing DataLoader to fetch %s images '
91 | 'during the training phase...' % which)
92 |
93 | crop = transforms.RandomResizedCrop(
94 | img_size, scale=[0.8, 1.0], ratio=[0.9, 1.1])
95 | rand_crop = transforms.Lambda(
96 | lambda x: crop(x) if random.random() < prob else x)
97 |
98 | transform = transforms.Compose([
99 | rand_crop,
100 | transforms.Resize([img_size, img_size]),
101 | transforms.RandomHorizontalFlip(),
102 | transforms.ToTensor(),
103 | transforms.Normalize(mean=[0.5, 0.5, 0.5],
104 | std=[0.5, 0.5, 0.5]),
105 | ])
106 |
107 | if which == 'source':
108 | dataset = ImageFolder(root, transform)
109 | elif which == 'reference':
110 | dataset = ReferenceDataset(root, transform)
111 | else:
112 | raise NotImplementedError
113 |
114 | sampler = _make_balanced_sampler(dataset.targets)
115 | return data.DataLoader(dataset=dataset,
116 | batch_size=batch_size,
117 | sampler=sampler,
118 | num_workers=num_workers,
119 | pin_memory=True,
120 | drop_last=True)
121 |
122 |
123 | def get_eval_loader(root, img_size=256, batch_size=32,
124 | imagenet_normalize=True, shuffle=True,
125 | num_workers=4, drop_last=False):
126 | print('Preparing DataLoader for the evaluation phase...')
127 | if imagenet_normalize:
128 | height, width = 299, 299
129 | mean = [0.485, 0.456, 0.406]
130 | std = [0.229, 0.224, 0.225]
131 | else:
132 | height, width = img_size, img_size
133 | mean = [0.5, 0.5, 0.5]
134 | std = [0.5, 0.5, 0.5]
135 |
136 | transform = transforms.Compose([
137 | transforms.Resize([img_size, img_size]),
138 | transforms.Resize([height, width]),
139 | transforms.ToTensor(),
140 | transforms.Normalize(mean=mean, std=std)
141 | ])
142 |
143 | dataset = DefaultDataset(root, transform=transform)
144 | return data.DataLoader(dataset=dataset,
145 | batch_size=batch_size,
146 | shuffle=shuffle,
147 | num_workers=num_workers,
148 | pin_memory=True,
149 | drop_last=drop_last)
150 |
151 |
152 | def get_test_loader(root, img_size=256, batch_size=32,
153 | shuffle=True, num_workers=4):
154 | print('Preparing DataLoader for the generation phase...')
155 | transform = transforms.Compose([
156 | transforms.Resize([img_size, img_size]),
157 | transforms.ToTensor(),
158 | transforms.Normalize(mean=[0.5, 0.5, 0.5],
159 | std=[0.5, 0.5, 0.5]),
160 | ])
161 |
162 | dataset = ImageFolder(root, transform)
163 | return data.DataLoader(dataset=dataset,
164 | batch_size=batch_size,
165 | shuffle=shuffle,
166 | num_workers=num_workers,
167 | pin_memory=True)
168 |
169 |
170 | class InputFetcher:
171 | def __init__(self, loader, loader_ref=None, latent_dim=16, mode=''):
172 | self.loader = loader
173 | self.loader_ref = loader_ref
174 | self.latent_dim = latent_dim
175 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
176 | self.mode = mode
177 |
178 | def _fetch_inputs(self):
179 | try:
180 | x, y = next(self.iter)
181 | except (AttributeError, StopIteration):
182 | self.iter = iter(self.loader)
183 | x, y = next(self.iter)
184 | return x, y
185 |
186 | def _fetch_refs(self):
187 | try:
188 | x, x2, y = next(self.iter_ref)
189 | except (AttributeError, StopIteration):
190 | self.iter_ref = iter(self.loader_ref)
191 | x, x2, y = next(self.iter_ref)
192 | return x, x2, y
193 |
194 | def __next__(self):
195 | x, y = self._fetch_inputs()
196 | if self.mode == 'train':
197 | x_ref, x_ref2, y_ref = self._fetch_refs()
198 | z_trg = torch.randn(x.size(0), self.latent_dim)
199 | z_trg2 = torch.randn(x.size(0), self.latent_dim)
200 | inputs = Munch(x_src=x, y_src=y, y_ref=y_ref,
201 | x_ref=x_ref, x_ref2=x_ref2,
202 | z_trg=z_trg, z_trg2=z_trg2)
203 | elif self.mode == 'val':
204 | x_ref, y_ref = self._fetch_inputs()
205 | inputs = Munch(x_src=x, y_src=y,
206 | x_ref=x_ref, y_ref=y_ref)
207 | elif self.mode == 'test':
208 | inputs = Munch(x=x, y=y)
209 | else:
210 | raise NotImplementedError
211 |
212 | return Munch({k: v.to(self.device)
213 | for k, v in inputs.items()})
--------------------------------------------------------------------------------
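
A rough sketch of how the loaders and InputFetcher above fit together during training. The dataset path is a placeholder and assumes an ImageFolder-style layout with one sub-directory per domain; the hyper-parameters are illustrative.

# Illustrative sketch only -- 'data/celeba_hq/train' is a placeholder path.
from core.data_loader import get_train_loader, InputFetcher

loader_src = get_train_loader('data/celeba_hq/train', which='source',
                              img_size=256, batch_size=8)
loader_ref = get_train_loader('data/celeba_hq/train', which='reference',
                              img_size=256, batch_size=8)

fetcher = InputFetcher(loader_src, loader_ref, latent_dim=16, mode='train')
inputs = next(fetcher)      # Munch with x_src, y_src, x_ref, x_ref2, y_ref, z_trg, z_trg2
print(inputs.x_src.shape)   # e.g. torch.Size([8, 3, 256, 256])
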
/core/model.py:
--------------------------------------------------------------------------------
1 | """
2 | StarGAN v2
3 | Copyright (c) 2020-present NAVER Corp.
4 |
5 | This work is licensed under the Creative Commons Attribution-NonCommercial
6 | 4.0 International License. To view a copy of this license, visit
7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
9 | """
10 |
11 | import copy
12 | import math
13 |
14 | from munch import Munch
15 | import numpy as np
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 |
20 | from core.wing import FAN
21 |
22 |
23 | class ResBlk(nn.Module):
24 | def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
25 | normalize=False, downsample=False):
26 | super().__init__()
27 | self.actv = actv
28 | self.normalize = normalize
29 | self.downsample = downsample
30 | self.learned_sc = dim_in != dim_out
31 | self._build_weights(dim_in, dim_out)
32 |
33 | def _build_weights(self, dim_in, dim_out):
34 | self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
35 | self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
36 | if self.normalize:
37 | self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
38 | self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
39 | if self.learned_sc:
40 | self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
41 |
42 | def _shortcut(self, x):
43 | if self.learned_sc:
44 | x = self.conv1x1(x)
45 | if self.downsample:
46 | x = F.avg_pool2d(x, 2)
47 | return x
48 |
49 | def _residual(self, x):
50 | if self.normalize:
51 | x = self.norm1(x)
52 | x = self.actv(x)
53 | x = self.conv1(x)
54 | if self.downsample:
55 | x = F.avg_pool2d(x, 2)
56 | if self.normalize:
57 | x = self.norm2(x)
58 | x = self.actv(x)
59 | x = self.conv2(x)
60 | return x
61 |
62 | def forward(self, x):
63 | x = self._shortcut(x) + self._residual(x)
64 | return x / math.sqrt(2) # unit variance
65 |
66 |
67 | class AdaIN(nn.Module):
68 | def __init__(self, style_dim, num_features):
69 | super().__init__()
70 | self.norm = nn.InstanceNorm2d(num_features, affine=False)
71 | self.fc = nn.Linear(style_dim, num_features*2)
72 |
73 | def forward(self, x, s):
74 | h = self.fc(s)
75 | h = h.view(h.size(0), h.size(1), 1, 1)
76 | gamma, beta = torch.chunk(h, chunks=2, dim=1)
77 | return (1 + gamma) * self.norm(x) + beta
78 |
79 |
80 | class AdainResBlk(nn.Module):
81 | def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=0,
82 | actv=nn.LeakyReLU(0.2), upsample=False):
83 | super().__init__()
84 | self.w_hpf = w_hpf
85 | self.actv = actv
86 | self.upsample = upsample
87 | self.learned_sc = dim_in != dim_out
88 | self._build_weights(dim_in, dim_out, style_dim)
89 |
90 | def _build_weights(self, dim_in, dim_out, style_dim=64):
91 | self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
92 | self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
93 | self.norm1 = AdaIN(style_dim, dim_in)
94 | self.norm2 = AdaIN(style_dim, dim_out)
95 | if self.learned_sc:
96 | self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
97 |
98 | def _shortcut(self, x):
99 | if self.upsample:
100 | x = F.interpolate(x, scale_factor=2, mode='nearest')
101 | if self.learned_sc:
102 | x = self.conv1x1(x)
103 | return x
104 |
105 | def _residual(self, x, s):
106 | x = self.norm1(x, s)
107 | x = self.actv(x)
108 | if self.upsample:
109 | x = F.interpolate(x, scale_factor=2, mode='nearest')
110 | x = self.conv1(x)
111 | x = self.norm2(x, s)
112 | x = self.actv(x)
113 | x = self.conv2(x)
114 | return x
115 |
116 | def forward(self, x, s):
117 | out = self._residual(x, s)
118 | if self.w_hpf == 0:
119 | out = (out + self._shortcut(x)) / math.sqrt(2)
120 | return out
121 |
122 |
123 | class HighPass(nn.Module):
124 | def __init__(self, w_hpf, device):
125 | super(HighPass, self).__init__()
126 | self.register_buffer('filter',
127 | torch.tensor([[-1, -1, -1],
128 | [-1, 8., -1],
129 | [-1, -1, -1]]) / w_hpf)
130 |
131 | def forward(self, x):
132 | filter = self.filter.unsqueeze(0).unsqueeze(1).repeat(x.size(1), 1, 1, 1)
133 | return F.conv2d(x, filter, padding=1, groups=x.size(1))
134 |
135 |
136 | class Generator(nn.Module):
137 | def __init__(self, img_size=256, style_dim=64, max_conv_dim=512, w_hpf=1):
138 | super().__init__()
139 | dim_in = 2**14 // img_size
140 | self.img_size = img_size
141 | self.from_rgb = nn.Conv2d(3, dim_in, 3, 1, 1)
142 | self.encode = nn.ModuleList()
143 | self.decode = nn.ModuleList()
144 | self.to_rgb = nn.Sequential(
145 | nn.InstanceNorm2d(dim_in, affine=True),
146 | nn.LeakyReLU(0.2),
147 | nn.Conv2d(dim_in, 3, 1, 1, 0))
148 |
149 | # down/up-sampling blocks
150 | repeat_num = int(np.log2(img_size)) - 4
151 | if w_hpf > 0:
152 | repeat_num += 1
153 | for _ in range(repeat_num):
154 | dim_out = min(dim_in*2, max_conv_dim)
155 | self.encode.append(
156 | ResBlk(dim_in, dim_out, normalize=True, downsample=True))
157 | self.decode.insert(
158 | 0, AdainResBlk(dim_out, dim_in, style_dim,
159 | w_hpf=w_hpf, upsample=True)) # stack-like
160 | dim_in = dim_out
161 |
162 | # bottleneck blocks
163 | for _ in range(2):
164 | self.encode.append(
165 | ResBlk(dim_out, dim_out, normalize=True))
166 | self.decode.insert(
167 | 0, AdainResBlk(dim_out, dim_out, style_dim, w_hpf=w_hpf))
168 |
169 | if w_hpf > 0:
170 | device = torch.device(
171 | 'cuda' if torch.cuda.is_available() else 'cpu')
172 | self.hpf = HighPass(w_hpf, device)
173 |
174 | def forward(self, x, s, masks=None):
175 | x = self.from_rgb(x)
176 | cache = {}
177 | for block in self.encode:
178 | if (masks is not None) and (x.size(2) in [32, 64, 128]):
179 | cache[x.size(2)] = x
180 | x = block(x)
181 | for block in self.decode:
182 | x = block(x, s)
183 | if (masks is not None) and (x.size(2) in [32, 64, 128]):
184 | mask = masks[0] if x.size(2) in [32] else masks[1]
185 | mask = F.interpolate(mask, size=x.size(2), mode='bilinear')
186 | x = x + self.hpf(mask * cache[x.size(2)])
187 | return self.to_rgb(x)
188 |
189 |
190 | class MappingNetwork(nn.Module):
191 | def __init__(self, latent_dim=16, style_dim=64, num_domains=2):
192 | super().__init__()
193 | layers = []
194 | layers += [nn.Linear(latent_dim, 512)]
195 | layers += [nn.ReLU()]
196 | for _ in range(3):
197 | layers += [nn.Linear(512, 512)]
198 | layers += [nn.ReLU()]
199 | self.shared = nn.Sequential(*layers)
200 |
201 | self.unshared = nn.ModuleList()
202 | for _ in range(num_domains):
203 | self.unshared += [nn.Sequential(nn.Linear(512, 512),
204 | nn.ReLU(),
205 | nn.Linear(512, 512),
206 | nn.ReLU(),
207 | nn.Linear(512, 512),
208 | nn.ReLU(),
209 | nn.Linear(512, style_dim))]
210 |
211 | def forward(self, z, y):
212 | h = self.shared(z)
213 | out = []
214 | for layer in self.unshared:
215 | out += [layer(h)]
216 | out = torch.stack(out, dim=1) # (batch, num_domains, style_dim)
217 | idx = torch.LongTensor(range(y.size(0))).to(y.device)
218 | s = out[idx, y] # (batch, style_dim)
219 | return s
220 |
221 |
222 | class StyleEncoder(nn.Module):
223 | def __init__(self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512):
224 | super().__init__()
225 | dim_in = 2**14 // img_size
226 | blocks = []
227 | blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
228 |
229 | repeat_num = int(np.log2(img_size)) - 2
230 | for _ in range(repeat_num):
231 | dim_out = min(dim_in*2, max_conv_dim)
232 | blocks += [ResBlk(dim_in, dim_out, downsample=True)]
233 | dim_in = dim_out
234 |
235 | blocks += [nn.LeakyReLU(0.2)]
236 | blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]
237 | blocks += [nn.LeakyReLU(0.2)]
238 | self.shared = nn.Sequential(*blocks)
239 |
240 | self.unshared = nn.ModuleList()
241 | for _ in range(num_domains):
242 | self.unshared += [nn.Linear(dim_out, style_dim)]
243 |
244 | def forward(self, x, y):
245 | h = self.shared(x)
246 | h = h.view(h.size(0), -1)
247 | out = []
248 | for layer in self.unshared:
249 | out += [layer(h)]
250 | out = torch.stack(out, dim=1) # (batch, num_domains, style_dim)
251 | idx = torch.LongTensor(range(y.size(0))).to(y.device)
252 | s = out[idx, y] # (batch, style_dim)
253 | return s
254 |
255 |
256 | class Discriminator(nn.Module):
257 | def __init__(self, img_size=256, num_domains=2, max_conv_dim=512):
258 | super().__init__()
259 | dim_in = 2**14 // img_size
260 | blocks = []
261 | blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
262 |
263 | repeat_num = int(np.log2(img_size)) - 2
264 | for _ in range(repeat_num):
265 | dim_out = min(dim_in*2, max_conv_dim)
266 | blocks += [ResBlk(dim_in, dim_out, downsample=True)]
267 | dim_in = dim_out
268 |
269 | blocks += [nn.LeakyReLU(0.2)]
270 | blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]
271 | blocks += [nn.LeakyReLU(0.2)]
272 | blocks += [nn.Conv2d(dim_out, num_domains, 1, 1, 0)]
273 | self.main = nn.Sequential(*blocks)
274 |
275 | def forward(self, x, y):
276 | out = self.main(x)
277 | out = out.view(out.size(0), -1) # (batch, num_domains)
278 | idx = torch.LongTensor(range(y.size(0))).to(y.device)
279 | out = out[idx, y] # (batch)
280 | return out
281 |
282 |
283 | def build_model(args):
284 | generator = nn.DataParallel(Generator(args.img_size, args.style_dim, w_hpf=args.w_hpf))
285 | mapping_network = nn.DataParallel(MappingNetwork(args.latent_dim, args.style_dim, args.num_domains))
286 | style_encoder = nn.DataParallel(StyleEncoder(args.img_size, args.style_dim, args.num_domains))
287 | discriminator = nn.DataParallel(Discriminator(args.img_size, args.num_domains))
288 | generator_ema = copy.deepcopy(generator)
289 | mapping_network_ema = copy.deepcopy(mapping_network)
290 | style_encoder_ema = copy.deepcopy(style_encoder)
291 |
292 | nets = Munch(generator=generator,
293 | mapping_network=mapping_network,
294 | style_encoder=style_encoder,
295 | discriminator=discriminator)
296 | nets_ema = Munch(generator=generator_ema,
297 | mapping_network=mapping_network_ema,
298 | style_encoder=style_encoder_ema)
299 |
300 | if args.w_hpf > 0:
301 | fan = nn.DataParallel(FAN(fname_pretrained=args.wing_path).eval())
302 | fan.get_heatmap = fan.module.get_heatmap
303 | nets.fan = fan
304 | nets_ema.fan = fan
305 |
306 | return nets, nets_ema
307 |
--------------------------------------------------------------------------------
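
A minimal sketch of a latent-guided forward pass through the networks defined above, instantiated directly (without nn.DataParallel) and with w_hpf=0 so that no pretrained FAN checkpoint is needed. Batch size and dimensions are illustrative.

# Illustrative sketch only -- shapes and hyper-parameters are examples.
import torch
from core.model import Generator, MappingNetwork, StyleEncoder, Discriminator

img_size, style_dim, latent_dim, num_domains = 256, 64, 16, 2
G = Generator(img_size, style_dim, w_hpf=0)
M = MappingNetwork(latent_dim, style_dim, num_domains)
E = StyleEncoder(img_size, style_dim, num_domains)
D = Discriminator(img_size, num_domains)

x = torch.randn(4, 3, img_size, img_size)   # source images
y = torch.randint(num_domains, (4,))        # target domain labels
z = torch.randn(4, latent_dim)              # latent codes

s = M(z, y)               # (4, style_dim) style codes for the target domains
x_fake = G(x, s)          # (4, 3, 256, 256) translated images
logits = D(x_fake, y)     # (4,) real/fake logit per sample for its target domain
s_rec = E(x_fake, y)      # (4, style_dim) styles re-estimated from the outputs
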
/core/solver.py:
--------------------------------------------------------------------------------
1 | """
2 | StarGAN v2
3 | Copyright (c) 2020-present NAVER Corp.
4 |
5 | This work is licensed under the Creative Commons Attribution-NonCommercial
6 | 4.0 International License. To view a copy of this license, visit
7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
9 | """
10 |
11 | import os
12 | from os.path import join as ospj
13 | import time
14 | import datetime
15 | from munch import Munch
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 | from core.model import build_model
22 | from core.checkpoint import CheckpointIO
23 | from core.data_loader import InputFetcher
24 | import core.utils as utils
25 | from metrics.eval import calculate_metrics
26 |
27 |
28 | class Solver(nn.Module):
29 | def __init__(self, args):
30 | super().__init__()
31 | self.args = args
32 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33 |
34 | self.nets, self.nets_ema = build_model(args)
35 | # the setattr calls below make the networks children of Solver, e.g., so that self.to(self.device) also moves them
36 | for name, module in self.nets.items():
37 | utils.print_network(module, name)
38 | setattr(self, name, module)
39 | for name, module in self.nets_ema.items():
40 | setattr(self, name + '_ema', module)
41 |
42 | if args.mode == 'train':
43 | self.optims = Munch()
44 | for net in self.nets.keys():
45 | if net == 'fan':
46 | continue
47 | self.optims[net] = torch.optim.Adam(
48 | params=self.nets[net].parameters(),
49 | lr=args.f_lr if net == 'mapping_network' else args.lr,
50 | betas=[args.beta1, args.beta2],
51 | weight_decay=args.weight_decay)
52 |
53 | self.ckptios = [
54 | CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets.ckpt'), data_parallel=True, **self.nets),
55 | CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), data_parallel=True, **self.nets_ema),
56 | CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_optims.ckpt'), **self.optims)]
57 | else:
58 | self.ckptios = [CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), data_parallel=True, **self.nets_ema)]
59 |
60 | self.to(self.device)
61 | for name, network in self.named_children():
62 | # Do not initialize the FAN parameters
63 | if ('ema' not in name) and ('fan' not in name):
64 | print('Initializing %s...' % name)
65 | network.apply(utils.he_init)
66 |
67 | def _save_checkpoint(self, step):
68 | for ckptio in self.ckptios:
69 | ckptio.save(step)
70 |
71 | def _load_checkpoint(self, step):
72 | for ckptio in self.ckptios:
73 | ckptio.load(step)
74 |
75 | def _reset_grad(self):
76 | for optim in self.optims.values():
77 | optim.zero_grad()
78 |
79 | def train(self, loaders):
80 | args = self.args
81 | nets = self.nets
82 | nets_ema = self.nets_ema
83 | optims = self.optims
84 |
85 | # fetch random validation images for debugging
86 | fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim, 'train')
87 | fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val')
88 | inputs_val = next(fetcher_val)
89 |
90 | # resume training if necessary
91 | if args.resume_iter > 0:
92 | self._load_checkpoint(args.resume_iter)
93 |
94 | # remember the initial value of ds weight
95 | initial_lambda_ds = args.lambda_ds
96 |
97 | print('Start training...')
98 | start_time = time.time()
99 | for i in range(args.resume_iter, args.total_iters):
100 | # fetch images and labels
101 | inputs = next(fetcher)
102 | x_real, y_org = inputs.x_src, inputs.y_src
103 | x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref
104 | z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2
105 |
106 | masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None
107 |
108 | # train the discriminator
109 | d_loss, d_losses_latent = compute_d_loss(
110 | nets, args, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
111 | self._reset_grad()
112 | d_loss.backward()
113 | optims.discriminator.step()
114 |
115 | d_loss, d_losses_ref = compute_d_loss(
116 | nets, args, x_real, y_org, y_trg, x_ref=x_ref, masks=masks)
117 | self._reset_grad()
118 | d_loss.backward()
119 | optims.discriminator.step()
120 |
121 | # train the generator
122 | g_loss, g_losses_latent = compute_g_loss(
123 | nets, args, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], masks=masks)
124 | self._reset_grad()
125 | g_loss.backward()
126 | optims.generator.step()
127 | optims.mapping_network.step()
128 | optims.style_encoder.step()
129 |
130 | g_loss, g_losses_ref = compute_g_loss(
131 | nets, args, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks)
132 | self._reset_grad()
133 | g_loss.backward()
134 | optims.generator.step()
135 |
136 | # compute moving average of network parameters
137 | moving_average(nets.generator, nets_ema.generator, beta=0.999)
138 | moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999)
139 | moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999)
140 |
141 | # decay weight for diversity sensitive loss
142 | if args.lambda_ds > 0:
143 | args.lambda_ds -= (initial_lambda_ds / args.ds_iter)
144 |
145 | # print out log info
146 | if (i+1) % args.print_every == 0:
147 | elapsed = time.time() - start_time
148 | elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
149 | log = "Elapsed time [%s], Iteration [%i/%i], " % (elapsed, i+1, args.total_iters)
150 | all_losses = dict()
151 | for loss, prefix in zip([d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref],
152 | ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
153 | for key, value in loss.items():
154 | all_losses[prefix + key] = value
155 | all_losses['G/lambda_ds'] = args.lambda_ds
156 | log += ' '.join(['%s: [%.4f]' % (key, value) for key, value in all_losses.items()])
157 | print(log)
158 |
159 | # generate images for debugging
160 | if (i+1) % args.sample_every == 0:
161 | os.makedirs(args.sample_dir, exist_ok=True)
162 | utils.debug_image(nets_ema, args, inputs=inputs_val, step=i+1)
163 |
164 | # save model checkpoints
165 | if (i+1) % args.save_every == 0:
166 | self._save_checkpoint(step=i+1)
167 |
168 | # compute FID and LPIPS if necessary
169 | if (i+1) % args.eval_every == 0:
170 | calculate_metrics(nets_ema, args, i+1, mode='latent')
171 | calculate_metrics(nets_ema, args, i+1, mode='reference')
172 |
173 | @torch.no_grad()
174 | def sample(self, loaders):
175 | args = self.args
176 | nets_ema = self.nets_ema
177 | os.makedirs(args.result_dir, exist_ok=True)
178 | self._load_checkpoint(args.resume_iter)
179 |
180 | src = next(InputFetcher(loaders.src, None, args.latent_dim, 'test'))
181 | ref = next(InputFetcher(loaders.ref, None, args.latent_dim, 'test'))
182 |
183 | fname = ospj(args.result_dir, 'reference.jpg')
184 | print('Working on {}...'.format(fname))
185 | utils.translate_using_reference(nets_ema, args, src.x, ref.x, ref.y, fname)
186 |
187 | fname = ospj(args.result_dir, 'video_ref.mp4')
188 | print('Working on {}...'.format(fname))
189 | utils.video_ref(nets_ema, args, src.x, ref.x, ref.y, fname)
190 |
191 | @torch.no_grad()
192 | def evaluate(self):
193 | args = self.args
194 | nets_ema = self.nets_ema
195 | resume_iter = args.resume_iter
196 | self._load_checkpoint(args.resume_iter)
197 | calculate_metrics(nets_ema, args, step=resume_iter, mode='latent')
198 | calculate_metrics(nets_ema, args, step=resume_iter, mode='reference')
199 |
200 |
201 | def compute_d_loss(nets, args, x_real, y_org, y_trg, z_trg=None, x_ref=None, masks=None):
202 | assert (z_trg is None) != (x_ref is None)
203 | # with real images
204 | x_real.requires_grad_()
205 | out = nets.discriminator(x_real, y_org)
206 | loss_real = adv_loss(out, 1)
207 | loss_reg = r1_reg(out, x_real)
208 |
209 | # with fake images
210 | with torch.no_grad():
211 | if z_trg is not None:
212 | s_trg = nets.mapping_network(z_trg, y_trg)
213 | else: # x_ref is not None
214 | s_trg = nets.style_encoder(x_ref, y_trg)
215 |
216 | x_fake = nets.generator(x_real, s_trg, masks=masks)
217 | out = nets.discriminator(x_fake, y_trg)
218 | loss_fake = adv_loss(out, 0)
219 |
220 | loss = loss_real + loss_fake + args.lambda_reg * loss_reg
221 | return loss, Munch(real=loss_real.item(),
222 | fake=loss_fake.item(),
223 | reg=loss_reg.item())
224 |
225 |
226 | def compute_g_loss(nets, args, x_real, y_org, y_trg, z_trgs=None, x_refs=None, masks=None):
227 | assert (z_trgs is None) != (x_refs is None)
228 | if z_trgs is not None:
229 | z_trg, z_trg2 = z_trgs
230 | if x_refs is not None:
231 | x_ref, x_ref2 = x_refs
232 |
233 | # adversarial loss
234 | if z_trgs is not None:
235 | s_trg = nets.mapping_network(z_trg, y_trg)
236 | else:
237 | s_trg = nets.style_encoder(x_ref, y_trg)
238 |
239 | x_fake = nets.generator(x_real, s_trg, masks=masks)
240 | out = nets.discriminator(x_fake, y_trg)
241 | loss_adv = adv_loss(out, 1)
242 |
243 | # style reconstruction loss
244 | s_pred = nets.style_encoder(x_fake, y_trg)
245 | loss_sty = torch.mean(torch.abs(s_pred - s_trg))
246 |
247 | # diversity sensitive loss
248 | if z_trgs is not None:
249 | s_trg2 = nets.mapping_network(z_trg2, y_trg)
250 | else:
251 | s_trg2 = nets.style_encoder(x_ref2, y_trg)
252 | x_fake2 = nets.generator(x_real, s_trg2, masks=masks)
253 | x_fake2 = x_fake2.detach()
254 | loss_ds = torch.mean(torch.abs(x_fake - x_fake2))
255 |
256 | # cycle-consistency loss
257 | masks = nets.fan.get_heatmap(x_fake) if args.w_hpf > 0 else None
258 | s_org = nets.style_encoder(x_real, y_org)
259 | x_rec = nets.generator(x_fake, s_org, masks=masks)
260 | loss_cyc = torch.mean(torch.abs(x_rec - x_real))
261 |
262 | loss = loss_adv + args.lambda_sty * loss_sty \
263 | - args.lambda_ds * loss_ds + args.lambda_cyc * loss_cyc
264 | return loss, Munch(adv=loss_adv.item(),
265 | sty=loss_sty.item(),
266 | ds=loss_ds.item(),
267 | cyc=loss_cyc.item())
268 |
269 |
270 | def moving_average(model, model_test, beta=0.999):
271 | for param, param_test in zip(model.parameters(), model_test.parameters()):
272 | param_test.data = torch.lerp(param.data, param_test.data, beta)
273 |
274 |
275 | def adv_loss(logits, target):
276 | assert target in [1, 0]
277 | targets = torch.full_like(logits, fill_value=target)
278 | loss = F.binary_cross_entropy_with_logits(logits, targets)
279 | return loss
280 |
281 |
282 | def r1_reg(d_out, x_in):
283 | # zero-centered gradient penalty for real images
284 | batch_size = x_in.size(0)
285 | grad_dout = torch.autograd.grad(
286 | outputs=d_out.sum(), inputs=x_in,
287 | create_graph=True, retain_graph=True, only_inputs=True
288 | )[0]
289 | grad_dout2 = grad_dout.pow(2)
290 | assert(grad_dout2.size() == x_in.size())
291 | reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0)
292 | return reg
--------------------------------------------------------------------------------
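
A small sketch exercising the standalone helpers defined above in core/solver.py (adv_loss, r1_reg, moving_average) on toy tensors. The one-layer "discriminator" is a stand-in for illustration, and importing core.solver assumes the repository's dependencies are installed.

# Illustrative sketch only -- the toy conv layer stands in for the real discriminator.
import torch
import torch.nn as nn
from core.solver import adv_loss, r1_reg, moving_average

d = nn.Conv2d(3, 1, 3, 1, 1)                       # toy discriminator head
x = torch.randn(2, 3, 8, 8, requires_grad=True)    # "real" images
out = d(x).sum(dim=[1, 2, 3])                      # one logit per sample

loss_real = adv_loss(out, 1)   # BCE-with-logits towards the "real" label
loss_reg = r1_reg(out, x)      # zero-centered gradient penalty of out w.r.t. x

# With beta=0.999, the EMA copy moves 0.1% of the way towards the online weights.
net, net_ema = nn.Linear(4, 4), nn.Linear(4, 4)
moving_average(net, net_ema, beta=0.999)
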
/core/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | StarGAN v2
3 | Copyright (c) 2020-present NAVER Corp.
4 |
5 | This work is licensed under the Creative Commons Attribution-NonCommercial
6 | 4.0 International License. To view a copy of this license, visit
7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
9 | """
10 |
11 | import os
12 | from os.path import join as ospj
13 | import json
14 | import glob
15 | from shutil import copyfile
16 |
17 | from tqdm import tqdm
18 | import ffmpeg
19 |
20 | import numpy as np
21 | import torch
22 | import torch.nn as nn
23 | import torch.nn.functional as F
24 | import torchvision
25 | import torchvision.utils as vutils
26 |
27 |
28 | def save_json(json_file, filename):
29 | with open(filename, 'w') as f:
30 | json.dump(json_file, f, indent=4, sort_keys=False)
31 |
32 |
33 | def print_network(network, name):
34 | num_params = 0
35 | for p in network.parameters():
36 | num_params += p.numel()
37 | # print(network)
38 | print("Number of parameters of %s: %i" % (name, num_params))
39 |
40 |
41 | def he_init(module):
42 | if isinstance(module, nn.Conv2d):
43 | nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
44 | if module.bias is not None:
45 | nn.init.constant_(module.bias, 0)
46 | if isinstance(module, nn.Linear):
47 | nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
48 | if module.bias is not None:
49 | nn.init.constant_(module.bias, 0)
50 |
51 |
52 | def denormalize(x):
53 | out = (x + 1) / 2
54 | return out.clamp_(0, 1)
55 |
56 |
57 | def save_image(x, ncol, filename):
58 | x = denormalize(x)
59 | vutils.save_image(x.cpu(), filename, nrow=ncol, padding=0)
60 |
61 |
62 | @torch.no_grad()
63 | def translate_and_reconstruct(nets, args, x_src, y_src, x_ref, y_ref, filename):
64 | N, C, H, W = x_src.size()
65 | s_ref = nets.style_encoder(x_ref, y_ref)
66 | masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None
67 | x_fake = nets.generator(x_src, s_ref, masks=masks)
68 | s_src = nets.style_encoder(x_src, y_src)
69 | masks = nets.fan.get_heatmap(x_fake) if args.w_hpf > 0 else None
70 | x_rec = nets.generator(x_fake, s_src, masks=masks)
71 | x_concat = [x_src, x_ref, x_fake, x_rec]
72 | x_concat = torch.cat(x_concat, dim=0)
73 | save_image(x_concat, N, filename)
74 | del x_concat
75 |
76 |
77 | @torch.no_grad()
78 | def translate_using_latent(nets, args, x_src, y_trg_list, z_trg_list, psi, filename):
79 | N, C, H, W = x_src.size()
80 | latent_dim = z_trg_list[0].size(1)
81 | x_concat = [x_src]
82 | masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None
83 |
84 | for i, y_trg in enumerate(y_trg_list):
85 | z_many = torch.randn(10000, latent_dim).to(x_src.device)
86 | y_many = torch.LongTensor(10000).to(x_src.device).fill_(y_trg[0])
87 | s_many = nets.mapping_network(z_many, y_many)
88 | s_avg = torch.mean(s_many, dim=0, keepdim=True)
89 | s_avg = s_avg.repeat(N, 1)
90 |
91 | for z_trg in z_trg_list:
92 | s_trg = nets.mapping_network(z_trg, y_trg)
93 | s_trg = torch.lerp(s_avg, s_trg, psi)
94 | x_fake = nets.generator(x_src, s_trg, masks=masks)
95 | x_concat += [x_fake]
96 |
97 | x_concat = torch.cat(x_concat, dim=0)
98 | save_image(x_concat, N, filename)
99 |
100 |
101 | @torch.no_grad()
102 | def translate_using_reference(nets, args, x_src, x_ref, y_ref, filename):
103 | N, C, H, W = x_src.size()
104 | wb = torch.ones(1, C, H, W).to(x_src.device)
105 | x_src_with_wb = torch.cat([wb, x_src], dim=0)
106 |
107 | masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None
108 | s_ref = nets.style_encoder(x_ref, y_ref)
109 | s_ref_list = s_ref.unsqueeze(1).repeat(1, N, 1)
110 | x_concat = [x_src_with_wb]
111 | for i, s_ref in enumerate(s_ref_list):
112 | x_fake = nets.generator(x_src, s_ref, masks=masks)
113 | x_fake_with_ref = torch.cat([x_ref[i:i+1], x_fake], dim=0)
114 | x_concat += [x_fake_with_ref]
115 |
116 | x_concat = torch.cat(x_concat, dim=0)
117 | save_image(x_concat, N+1, filename)
118 | del x_concat
119 |
120 |
121 | @torch.no_grad()
122 | def debug_image(nets, args, inputs, step):
123 | x_src, y_src = inputs.x_src, inputs.y_src
124 | x_ref, y_ref = inputs.x_ref, inputs.y_ref
125 |
126 | device = inputs.x_src.device
127 | N = inputs.x_src.size(0)
128 |
129 | # translate and reconstruct (reference-guided)
130 | filename = ospj(args.sample_dir, '%06d_cycle_consistency.jpg' % (step))
131 | translate_and_reconstruct(nets, args, x_src, y_src, x_ref, y_ref, filename)
132 |
133 | # latent-guided image synthesis
134 | y_trg_list = [torch.tensor(y).repeat(N).to(device)
135 | for y in range(min(args.num_domains, 5))]
136 | z_trg_list = torch.randn(args.num_outs_per_domain, 1, args.latent_dim).repeat(1, N, 1).to(device)
137 | for psi in [0.5, 0.7, 1.0]:
138 | filename = ospj(args.sample_dir, '%06d_latent_psi_%.1f.jpg' % (step, psi))
139 | translate_using_latent(nets, args, x_src, y_trg_list, z_trg_list, psi, filename)
140 |
141 | # reference-guided image synthesis
142 | filename = ospj(args.sample_dir, '%06d_reference.jpg' % (step))
143 | translate_using_reference(nets, args, x_src, x_ref, y_ref, filename)
144 |
145 |
146 | # ======================= #
147 | # Video-related functions #
148 | # ======================= #
149 |
150 |
151 | def sigmoid(x, w=1):
152 | return 1. / (1 + np.exp(-w * x))
153 |
154 |
155 | def get_alphas(start=-5, end=5, step=0.5, len_tail=10):
156 | return [0] + [sigmoid(alpha) for alpha in np.arange(start, end, step)] + [1] * len_tail
157 |
158 |
159 | def interpolate(nets, args, x_src, s_prev, s_next):
160 | ''' returns T x C x H x W '''
161 | B = x_src.size(0)
162 | frames = []
163 | masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None
164 | alphas = get_alphas()
165 |
166 | for alpha in alphas:
167 | s_ref = torch.lerp(s_prev, s_next, alpha)
168 | x_fake = nets.generator(x_src, s_ref, masks=masks)
169 | entries = torch.cat([x_src.cpu(), x_fake.cpu()], dim=2)
170 | frame = torchvision.utils.make_grid(entries, nrow=B, padding=0, pad_value=-1).unsqueeze(0)
171 | frames.append(frame)
172 | frames = torch.cat(frames)
173 | return frames
174 |
175 |
176 | def slide(entries, margin=32):
177 | """Returns a sliding reference window.
178 | Args:
179 | entries: a list containing two reference images, x_prev and x_next,
180 |                  both of which have shape (1, 3, 256, 256)
181 | Returns:
182 | canvas: output slide of shape (num_frames, 3, 256*2, 256+margin)
183 | """
184 | _, C, H, W = entries[0].shape
185 | alphas = get_alphas()
186 | T = len(alphas) # number of frames
187 |
188 | canvas = - torch.ones((T, C, H*2, W + margin))
189 | merged = torch.cat(entries, dim=2) # (1, 3, 512, 256)
190 | for t, alpha in enumerate(alphas):
191 | top = int(H * (1 - alpha)) # top, bottom for canvas
192 | bottom = H * 2
193 | m_top = 0 # top, bottom for merged
194 | m_bottom = 2 * H - top
195 | canvas[t, :, top:bottom, :W] = merged[:, :, m_top:m_bottom, :]
196 | return canvas
197 |
198 |
199 | @torch.no_grad()
200 | def video_ref(nets, args, x_src, x_ref, y_ref, fname):
201 | video = []
202 | s_ref = nets.style_encoder(x_ref, y_ref)
203 | s_prev = None
204 | for data_next in tqdm(zip(x_ref, y_ref, s_ref), 'video_ref', len(x_ref)):
205 | x_next, y_next, s_next = [d.unsqueeze(0) for d in data_next]
206 | if s_prev is None:
207 | x_prev, y_prev, s_prev = x_next, y_next, s_next
208 | continue
209 | if y_prev != y_next:
210 | x_prev, y_prev, s_prev = x_next, y_next, s_next
211 | continue
212 |
213 | interpolated = interpolate(nets, args, x_src, s_prev, s_next)
214 | entries = [x_prev, x_next]
215 |         slided = slide(entries)    # (T, C, 256*2, 256+margin)
216 | frames = torch.cat([slided, interpolated], dim=3).cpu() # (T, C, 256*2, 256*(batch+1))
217 | video.append(frames)
218 | x_prev, y_prev, s_prev = x_next, y_next, s_next
219 |
220 |     # append the last frame 10 times
221 | for _ in range(10):
222 | video.append(frames[-1:])
223 | video = tensor2ndarray255(torch.cat(video))
224 | save_video(fname, video)
225 |
226 |
227 | @torch.no_grad()
228 | def video_latent(nets, args, x_src, y_list, z_list, psi, fname):
229 | latent_dim = z_list[0].size(1)
230 | s_list = []
231 | for i, y_trg in enumerate(y_list):
232 | z_many = torch.randn(10000, latent_dim).to(x_src.device)
233 | y_many = torch.LongTensor(10000).to(x_src.device).fill_(y_trg[0])
234 | s_many = nets.mapping_network(z_many, y_many)
235 | s_avg = torch.mean(s_many, dim=0, keepdim=True)
236 | s_avg = s_avg.repeat(x_src.size(0), 1)
237 |
238 | for z_trg in z_list:
239 | s_trg = nets.mapping_network(z_trg, y_trg)
240 | s_trg = torch.lerp(s_avg, s_trg, psi)
241 | s_list.append(s_trg)
242 |
243 | s_prev = None
244 | video = []
245 | # fetch reference images
246 | for idx_ref, s_next in enumerate(tqdm(s_list, 'video_latent', len(s_list))):
247 | if s_prev is None:
248 | s_prev = s_next
249 | continue
250 | if idx_ref % len(z_list) == 0:
251 | s_prev = s_next
252 | continue
253 | frames = interpolate(nets, args, x_src, s_prev, s_next).cpu()
254 | video.append(frames)
255 | s_prev = s_next
256 | for _ in range(10):
257 | video.append(frames[-1:])
258 | video = tensor2ndarray255(torch.cat(video))
259 | save_video(fname, video)
260 |
261 |
262 | def save_video(fname, images, output_fps=30, vcodec='libx264', filters=''):
263 | assert isinstance(images, np.ndarray), "images should be np.array: NHWC"
264 | num_frames, height, width, channels = images.shape
265 | stream = ffmpeg.input('pipe:', format='rawvideo',
266 | pix_fmt='rgb24', s='{}x{}'.format(width, height))
267 | stream = ffmpeg.filter(stream, 'setpts', '2*PTS') # 2*PTS is for slower playback
268 | stream = ffmpeg.output(stream, fname, pix_fmt='yuv420p', vcodec=vcodec, r=output_fps)
269 | stream = ffmpeg.overwrite_output(stream)
270 | process = ffmpeg.run_async(stream, pipe_stdin=True)
271 | for frame in tqdm(images, desc='writing video to %s' % fname):
272 | process.stdin.write(frame.astype(np.uint8).tobytes())
273 | process.stdin.close()
274 | process.wait()
275 |
276 |
277 | def tensor2ndarray255(images):
278 | images = torch.clamp(images * 0.5 + 0.5, 0, 1)
279 | return images.cpu().numpy().transpose(0, 2, 3, 1) * 255
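280 | 
281 | 
282 | # Illustrative usage sketch (not part of the original file): video_ref() and
283 | # video_latent() above convert interpolated frames with tensor2ndarray255() and
284 | # stream them to ffmpeg via save_video(). A minimal end-to-end call, assuming
285 | # the inputs to interpolate() are already prepared:
286 | #
287 | #   frames = interpolate(nets, args, x_src, s_prev, s_next)  # (T, C, H, W) in [-1, 1]
288 | #   save_video('interpolation.mp4', tensor2ndarray255(frames))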
--------------------------------------------------------------------------------
/core/wing.py:
--------------------------------------------------------------------------------
1 | """
2 | StarGAN v2
3 | Copyright (c) 2020-present NAVER Corp.
4 |
5 | This work is licensed under the Creative Commons Attribution-NonCommercial
6 | 4.0 International License. To view a copy of this license, visit
7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
9 |
10 | Lines (19 to 80) were adapted from https://github.com/1adrianb/face-alignment
11 | Lines (83 to 235) were adapted from https://github.com/protossw512/AdaptiveWingLoss
12 | """
13 |
14 | from collections import namedtuple
15 | from copy import deepcopy
16 | from functools import partial
17 |
18 | from munch import Munch
19 | import numpy as np
20 | import cv2
21 | from skimage.filters import gaussian
22 | import torch
23 | import torch.nn as nn
24 | import torch.nn.functional as F
25 |
26 |
27 | def get_preds_fromhm(hm):
28 | max, idx = torch.max(
29 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
30 | idx += 1
31 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
32 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
33 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
34 |
35 | for i in range(preds.size(0)):
36 | for j in range(preds.size(1)):
37 | hm_ = hm[i, j, :]
38 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
39 | if pX > 0 and pX < 63 and pY > 0 and pY < 63:
40 | diff = torch.FloatTensor(
41 | [hm_[pY, pX + 1] - hm_[pY, pX - 1],
42 | hm_[pY + 1, pX] - hm_[pY - 1, pX]])
43 | preds[i, j].add_(diff.sign_().mul_(.25))
44 |
45 | preds.add_(-0.5)
46 | return preds
47 |
48 |
49 | class HourGlass(nn.Module):
50 | def __init__(self, num_modules, depth, num_features, first_one=False):
51 | super(HourGlass, self).__init__()
52 | self.num_modules = num_modules
53 | self.depth = depth
54 | self.features = num_features
55 | self.coordconv = CoordConvTh(64, 64, True, True, 256, first_one,
56 | out_channels=256,
57 | kernel_size=1, stride=1, padding=0)
58 | self._generate_network(self.depth)
59 |
60 | def _generate_network(self, level):
61 | self.add_module('b1_' + str(level), ConvBlock(256, 256))
62 | self.add_module('b2_' + str(level), ConvBlock(256, 256))
63 | if level > 1:
64 | self._generate_network(level - 1)
65 | else:
66 | self.add_module('b2_plus_' + str(level), ConvBlock(256, 256))
67 | self.add_module('b3_' + str(level), ConvBlock(256, 256))
68 |
69 | def _forward(self, level, inp):
70 | up1 = inp
71 | up1 = self._modules['b1_' + str(level)](up1)
72 | low1 = F.avg_pool2d(inp, 2, stride=2)
73 | low1 = self._modules['b2_' + str(level)](low1)
74 |
75 | if level > 1:
76 | low2 = self._forward(level - 1, low1)
77 | else:
78 | low2 = low1
79 | low2 = self._modules['b2_plus_' + str(level)](low2)
80 | low3 = low2
81 | low3 = self._modules['b3_' + str(level)](low3)
82 | up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
83 |
84 | return up1 + up2
85 |
86 | def forward(self, x, heatmap):
87 | x, last_channel = self.coordconv(x, heatmap)
88 | return self._forward(self.depth, x), last_channel
89 |
90 |
91 | class AddCoordsTh(nn.Module):
92 | def __init__(self, height=64, width=64, with_r=False, with_boundary=False):
93 | super(AddCoordsTh, self).__init__()
94 | self.with_r = with_r
95 | self.with_boundary = with_boundary
96 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
97 |
98 | with torch.no_grad():
99 | x_coords = torch.arange(height).unsqueeze(1).expand(height, width).float()
100 | y_coords = torch.arange(width).unsqueeze(0).expand(height, width).float()
101 | x_coords = (x_coords / (height - 1)) * 2 - 1
102 | y_coords = (y_coords / (width - 1)) * 2 - 1
103 | coords = torch.stack([x_coords, y_coords], dim=0) # (2, height, width)
104 |
105 | if self.with_r:
106 | rr = torch.sqrt(torch.pow(x_coords, 2) + torch.pow(y_coords, 2)) # (height, width)
107 | rr = (rr / torch.max(rr)).unsqueeze(0)
108 | coords = torch.cat([coords, rr], dim=0)
109 |
110 | self.coords = coords.unsqueeze(0).to(device) # (1, 2 or 3, height, width)
111 | self.x_coords = x_coords.to(device)
112 | self.y_coords = y_coords.to(device)
113 |
114 | def forward(self, x, heatmap=None):
115 | """
116 | x: (batch, c, x_dim, y_dim)
117 | """
118 | coords = self.coords.repeat(x.size(0), 1, 1, 1)
119 |
120 | if self.with_boundary and heatmap is not None:
121 | boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0)
122 | zero_tensor = torch.zeros_like(self.x_coords)
123 | xx_boundary_channel = torch.where(boundary_channel > 0.05, self.x_coords, zero_tensor).to(zero_tensor.device)
124 | yy_boundary_channel = torch.where(boundary_channel > 0.05, self.y_coords, zero_tensor).to(zero_tensor.device)
125 | coords = torch.cat([coords, xx_boundary_channel, yy_boundary_channel], dim=1)
126 |
127 | x_and_coords = torch.cat([x, coords], dim=1)
128 | return x_and_coords
129 |
130 |
131 | class CoordConvTh(nn.Module):
132 | """CoordConv layer as in the paper."""
133 | def __init__(self, height, width, with_r, with_boundary,
134 | in_channels, first_one=False, *args, **kwargs):
135 | super(CoordConvTh, self).__init__()
136 | self.addcoords = AddCoordsTh(height, width, with_r, with_boundary)
137 | in_channels += 2
138 | if with_r:
139 | in_channels += 1
140 | if with_boundary and not first_one:
141 | in_channels += 2
142 | self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs)
143 |
144 | def forward(self, input_tensor, heatmap=None):
145 | ret = self.addcoords(input_tensor, heatmap)
146 | last_channel = ret[:, -2:, :, :]
147 | ret = self.conv(ret)
148 | return ret, last_channel
149 |
150 |
151 | class ConvBlock(nn.Module):
152 | def __init__(self, in_planes, out_planes):
153 | super(ConvBlock, self).__init__()
154 | self.bn1 = nn.BatchNorm2d(in_planes)
155 | conv3x3 = partial(nn.Conv2d, kernel_size=3, stride=1, padding=1, bias=False, dilation=1)
156 | self.conv1 = conv3x3(in_planes, int(out_planes / 2))
157 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
158 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
159 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
160 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
161 |
162 | self.downsample = None
163 | if in_planes != out_planes:
164 | self.downsample = nn.Sequential(nn.BatchNorm2d(in_planes),
165 | nn.ReLU(True),
166 | nn.Conv2d(in_planes, out_planes, 1, 1, bias=False))
167 |
168 | def forward(self, x):
169 | residual = x
170 |
171 | out1 = self.bn1(x)
172 | out1 = F.relu(out1, True)
173 | out1 = self.conv1(out1)
174 |
175 | out2 = self.bn2(out1)
176 | out2 = F.relu(out2, True)
177 | out2 = self.conv2(out2)
178 |
179 | out3 = self.bn3(out2)
180 | out3 = F.relu(out3, True)
181 | out3 = self.conv3(out3)
182 |
183 | out3 = torch.cat((out1, out2, out3), 1)
184 | if self.downsample is not None:
185 | residual = self.downsample(residual)
186 | out3 += residual
187 | return out3
188 |
189 |
190 | class FAN(nn.Module):
191 | def __init__(self, num_modules=1, end_relu=False, num_landmarks=98, fname_pretrained=None):
192 | super(FAN, self).__init__()
193 | self.num_modules = num_modules
194 | self.end_relu = end_relu
195 |
196 | # Base part
197 | self.conv1 = CoordConvTh(256, 256, True, False,
198 | in_channels=3, out_channels=64,
199 | kernel_size=7, stride=2, padding=3)
200 | self.bn1 = nn.BatchNorm2d(64)
201 | self.conv2 = ConvBlock(64, 128)
202 | self.conv3 = ConvBlock(128, 128)
203 | self.conv4 = ConvBlock(128, 256)
204 |
205 | # Stacking part
206 | self.add_module('m0', HourGlass(1, 4, 256, first_one=True))
207 | self.add_module('top_m_0', ConvBlock(256, 256))
208 | self.add_module('conv_last0', nn.Conv2d(256, 256, 1, 1, 0))
209 | self.add_module('bn_end0', nn.BatchNorm2d(256))
210 | self.add_module('l0', nn.Conv2d(256, num_landmarks+1, 1, 1, 0))
211 |
212 | if fname_pretrained is not None:
213 | self.load_pretrained_weights(fname_pretrained)
214 |
215 | def load_pretrained_weights(self, fname):
216 | if torch.cuda.is_available():
217 | checkpoint = torch.load(fname)
218 | else:
219 | checkpoint = torch.load(fname, map_location=torch.device('cpu'))
220 | model_weights = self.state_dict()
221 | model_weights.update({k: v for k, v in checkpoint['state_dict'].items()
222 | if k in model_weights})
223 | self.load_state_dict(model_weights)
224 |
225 | def forward(self, x):
226 | x, _ = self.conv1(x)
227 | x = F.relu(self.bn1(x), True)
228 | x = F.avg_pool2d(self.conv2(x), 2, stride=2)
229 | x = self.conv3(x)
230 | x = self.conv4(x)
231 |
232 | outputs = []
233 | boundary_channels = []
234 | tmp_out = None
235 | ll, boundary_channel = self._modules['m0'](x, tmp_out)
236 | ll = self._modules['top_m_0'](ll)
237 | ll = F.relu(self._modules['bn_end0']
238 | (self._modules['conv_last0'](ll)), True)
239 |
240 | # Predict heatmaps
241 | tmp_out = self._modules['l0'](ll)
242 | if self.end_relu:
243 | tmp_out = F.relu(tmp_out) # HACK: Added relu
244 | outputs.append(tmp_out)
245 | boundary_channels.append(boundary_channel)
246 | return outputs, boundary_channels
247 |
248 | @torch.no_grad()
249 | def get_heatmap(self, x, b_preprocess=True):
250 |         ''' outputs 0-1 normalized heatmaps; returns a (mask, mask2) pair when b_preprocess=True '''
251 | x = F.interpolate(x, size=256, mode='bilinear')
252 | x_01 = x*0.5 + 0.5
253 | outputs, _ = self(x_01)
254 | heatmaps = outputs[-1][:, :-1, :, :]
255 | scale_factor = x.size(2) // heatmaps.size(2)
256 | if b_preprocess:
257 | heatmaps = F.interpolate(heatmaps, scale_factor=scale_factor,
258 | mode='bilinear', align_corners=True)
259 | heatmaps = preprocess(heatmaps)
260 | return heatmaps
261 |
262 | @torch.no_grad()
263 | def get_landmark(self, x):
264 |         ''' outputs landmark coordinates scaled to the resolution of x '''
265 | heatmaps = self.get_heatmap(x, b_preprocess=False)
266 | landmarks = []
267 | for i in range(x.size(0)):
268 | pred_landmarks = get_preds_fromhm(heatmaps[i].cpu().unsqueeze(0))
269 | landmarks.append(pred_landmarks)
270 | scale_factor = x.size(2) // heatmaps.size(2)
271 | landmarks = torch.cat(landmarks) * scale_factor
272 | return landmarks
273 |
274 |
275 | # ========================== #
276 | # Align related functions #
277 | # ========================== #
278 |
279 |
280 | def tensor2numpy255(tensor):
281 | """Converts torch tensor to numpy array."""
282 | return ((tensor.permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5) * 255).astype('uint8')
283 |
284 |
285 | def np2tensor(image):
286 | """Converts numpy array to torch tensor."""
287 | return torch.FloatTensor(image).permute(2, 0, 1) / 255 * 2 - 1
288 |
289 |
290 | class FaceAligner():
291 | def __init__(self, fname_wing, fname_celeba_mean, output_size):
292 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
293 | self.fan = FAN(fname_pretrained=fname_wing).to(self.device).eval()
294 | scale = output_size // 256
295 | self.CELEB_REF = np.float32(np.load(fname_celeba_mean)['mean']) * scale
296 | self.xaxis_ref = landmarks2xaxis(self.CELEB_REF)
297 | self.output_size = output_size
298 |
299 | def align(self, imgs, output_size=256):
300 |         ''' imgs: a float image tensor of shape (B, C, H, W); moved to self.device internally '''
301 | imgs = imgs.to(self.device)
302 | landmarkss = self.fan.get_landmark(imgs).cpu().numpy()
303 | for i, (img, landmarks) in enumerate(zip(imgs, landmarkss)):
304 | img_np = tensor2numpy255(img)
305 | img_np, landmarks = pad_mirror(img_np, landmarks)
306 | transform = self.landmarks2mat(landmarks)
307 | rows, cols, _ = img_np.shape
308 | rows = max(rows, self.output_size)
309 | cols = max(cols, self.output_size)
310 | aligned = cv2.warpPerspective(img_np, transform, (cols, rows), flags=cv2.INTER_LANCZOS4)
311 | imgs[i] = np2tensor(aligned[:self.output_size, :self.output_size, :])
312 | return imgs
313 |
314 | def landmarks2mat(self, landmarks):
315 | T_origin = points2T(landmarks, 'from')
316 | xaxis_src = landmarks2xaxis(landmarks)
317 | R = vecs2R(xaxis_src, self.xaxis_ref)
318 | S = landmarks2S(landmarks, self.CELEB_REF)
319 | T_ref = points2T(self.CELEB_REF, 'to')
320 | matrix = np.dot(T_ref, np.dot(S, np.dot(R, T_origin)))
321 | return matrix
322 |
323 |
324 | def points2T(point, direction):
325 | point_mean = point.mean(axis=0)
326 | T = np.eye(3)
327 | coef = -1 if direction == 'from' else 1
328 | T[:2, 2] = coef * point_mean
329 | return T
330 |
331 |
332 | def landmarks2eyes(landmarks):
333 | idx_left = np.array(list(range(60, 67+1)) + [96])
334 | idx_right = np.array(list(range(68, 75+1)) + [97])
335 | left = landmarks[idx_left]
336 | right = landmarks[idx_right]
337 | return left.mean(axis=0), right.mean(axis=0)
338 |
339 |
340 | def landmarks2mouthends(landmarks):
341 | left = landmarks[76]
342 | right = landmarks[82]
343 | return left, right
344 |
345 |
346 | def rotate90(vec):
347 | x, y = vec
348 | return np.array([y, -x])
349 |
350 |
351 | def landmarks2xaxis(landmarks):
352 | eye_left, eye_right = landmarks2eyes(landmarks)
353 | mouth_left, mouth_right = landmarks2mouthends(landmarks)
354 | xp = eye_right - eye_left # x' in pggan
355 | eye_center = (eye_left + eye_right) * 0.5
356 | mouth_center = (mouth_left + mouth_right) * 0.5
357 | yp = eye_center - mouth_center
358 | xaxis = xp - rotate90(yp)
359 | return xaxis / np.linalg.norm(xaxis)
360 |
361 |
362 | def vecs2R(vec_x, vec_y):
363 | vec_x = vec_x / np.linalg.norm(vec_x)
364 | vec_y = vec_y / np.linalg.norm(vec_y)
365 | c = np.dot(vec_x, vec_y)
366 | s = np.sqrt(1 - c * c) * np.sign(np.cross(vec_x, vec_y))
367 | R = np.array(((c, -s, 0), (s, c, 0), (0, 0, 1)))
368 | return R
369 |
370 |
371 | def landmarks2S(x, y):
372 | x_mean = x.mean(axis=0).squeeze()
373 | y_mean = y.mean(axis=0).squeeze()
374 | # vectors = mean -> each point
375 | x_vectors = x - x_mean
376 | y_vectors = y - y_mean
377 |
378 | x_norms = np.linalg.norm(x_vectors, axis=1)
379 | y_norms = np.linalg.norm(y_vectors, axis=1)
380 |
381 | indices = [96, 97, 76, 82] # indices for eyes, lips
382 | scale = (y_norms / x_norms)[indices].mean()
383 |
384 | S = np.eye(3)
385 | S[0, 0] = S[1, 1] = scale
386 | return S
387 |
388 |
389 | def pad_mirror(img, landmarks):
390 | H, W, _ = img.shape
391 | img = np.pad(img, ((H//2, H//2), (W//2, W//2), (0, 0)), 'reflect')
392 | small_blurred = gaussian(cv2.resize(img, (W, H)), H//100, multichannel=True)
393 | blurred = cv2.resize(small_blurred, (W * 2, H * 2)) * 255
394 |
395 | H, W, _ = img.shape
396 | coords = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")
397 | weight_y = np.clip(coords[0] / (H//4), 0, 1)
398 | weight_x = np.clip(coords[1] / (H//4), 0, 1)
399 | weight_y = np.minimum(weight_y, np.flip(weight_y, axis=0))
400 | weight_x = np.minimum(weight_x, np.flip(weight_x, axis=1))
401 | weight = np.expand_dims(np.minimum(weight_y, weight_x), 2)**4
402 | img = img * weight + blurred * (1 - weight)
403 | landmarks += np.array([W//4, H//4])
404 | return img, landmarks
405 |
406 |
407 | def align_faces(args, input_dir, output_dir):
408 | import os
409 | from torchvision import transforms
410 | from PIL import Image
411 | from core.utils import save_image
412 |
413 | aligner = FaceAligner(args.wing_path, args.lm_path, args.img_size)
414 | transform = transforms.Compose([
415 | transforms.Resize((args.img_size, args.img_size)),
416 | transforms.ToTensor(),
417 | transforms.Normalize(mean=[0.5, 0.5, 0.5],
418 | std=[0.5, 0.5, 0.5]),
419 | ])
420 |
421 | fnames = os.listdir(input_dir)
422 | os.makedirs(output_dir, exist_ok=True)
423 | fnames.sort()
424 | for fname in fnames:
425 | image = Image.open(os.path.join(input_dir, fname)).convert('RGB')
426 | x = transform(image).unsqueeze(0)
427 | x_aligned = aligner.align(x)
428 | save_image(x_aligned, 1, filename=os.path.join(output_dir, fname))
429 | print('Saved the aligned image to %s...' % fname)
430 |
431 |
432 | # ========================== #
433 | # Mask related functions #
434 | # ========================== #
435 |
436 |
437 | def normalize(x, eps=1e-6):
438 | """Apply min-max normalization."""
439 | x = x.contiguous()
440 | N, C, H, W = x.size()
441 | x_ = x.view(N*C, -1)
442 | max_val = torch.max(x_, dim=1, keepdim=True)[0]
443 | min_val = torch.min(x_, dim=1, keepdim=True)[0]
444 | x_ = (x_ - min_val) / (max_val - min_val + eps)
445 | out = x_.view(N, C, H, W)
446 | return out
447 |
448 |
449 | def truncate(x, thres=0.1):
450 | """Remove small values in heatmaps."""
451 | return torch.where(x < thres, torch.zeros_like(x), x)
452 |
453 |
454 | def resize(x, p=2):
455 |     """Sharpen heatmaps by raising their values to the power p."""
456 | return x**p
457 |
458 |
459 | def shift(x, N):
460 | """Shift N pixels up or down."""
461 | up = N >= 0
462 | N = abs(N)
463 | _, _, H, W = x.size()
464 | head = torch.arange(N)
465 | tail = torch.arange(H-N)
466 |
467 | if up:
468 | head = torch.arange(H-N)+N
469 | tail = torch.arange(N)
470 | else:
471 | head = torch.arange(N) + (H-N)
472 | tail = torch.arange(H-N)
473 |
474 | # permutation indices
475 | perm = torch.cat([head, tail]).to(x.device)
476 | out = x[:, :, perm, :]
477 | return out
478 |
479 |
480 | IDXPAIR = namedtuple('IDXPAIR', 'start end')
481 | index_map = Munch(chin=IDXPAIR(0 + 8, 33 - 8),
482 | eyebrows=IDXPAIR(33, 51),
483 | eyebrowsedges=IDXPAIR(33, 46),
484 | nose=IDXPAIR(51, 55),
485 | nostrils=IDXPAIR(55, 60),
486 | eyes=IDXPAIR(60, 76),
487 | lipedges=IDXPAIR(76, 82),
488 | lipupper=IDXPAIR(77, 82),
489 | liplower=IDXPAIR(83, 88),
490 | lipinner=IDXPAIR(88, 96))
491 | OPPAIR = namedtuple('OPPAIR', 'shift resize')
492 |
493 |
494 | def preprocess(x):
495 | """Preprocess 98-dimensional heatmaps."""
496 | N, C, H, W = x.size()
497 | x = truncate(x)
498 | x = normalize(x)
499 |
500 | sw = H // 256
501 | operations = Munch(chin=OPPAIR(0, 3),
502 | eyebrows=OPPAIR(-7*sw, 2),
503 | nostrils=OPPAIR(8*sw, 4),
504 | lipupper=OPPAIR(-8*sw, 4),
505 | liplower=OPPAIR(8*sw, 4),
506 | lipinner=OPPAIR(-2*sw, 3))
507 |
508 | for part, ops in operations.items():
509 | start, end = index_map[part]
510 | x[:, start:end] = resize(shift(x[:, start:end], ops.shift), ops.resize)
511 |
512 | zero_out = torch.cat([torch.arange(0, index_map.chin.start),
513 | torch.arange(index_map.chin.end, 33),
514 | torch.LongTensor([index_map.eyebrowsedges.start,
515 | index_map.eyebrowsedges.end,
516 | index_map.lipedges.start,
517 | index_map.lipedges.end])])
518 | x[:, zero_out] = 0
519 |
520 | start, end = index_map.nose
521 | x[:, start+1:end] = shift(x[:, start+1:end], 4*sw)
522 | x[:, start:end] = resize(x[:, start:end], 1)
523 |
524 | start, end = index_map.eyes
525 | x[:, start:end] = resize(x[:, start:end], 1)
526 | x[:, start:end] = resize(shift(x[:, start:end], -8), 3) + \
527 | shift(x[:, start:end], -24)
528 |
529 | # Second-level mask
530 | x2 = deepcopy(x)
531 | x2[:, index_map.chin.start:index_map.chin.end] = 0 # start:end was 0:33
532 | x2[:, index_map.lipedges.start:index_map.lipinner.end] = 0 # start:end was 76:96
533 | x2[:, index_map.eyebrows.start:index_map.eyebrows.end] = 0 # start:end was 33:51
534 |
535 | x = torch.sum(x, dim=1, keepdim=True) # (N, 1, H, W)
536 | x2 = torch.sum(x2, dim=1, keepdim=True) # mask without faceline and mouth
537 |
538 | x[x != x] = 0 # set nan to zero
539 | x2[x != x] = 0 # set nan to zero
540 | return x.clamp_(0, 1), x2.clamp_(0, 1)
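541 | 
542 | 
543 | # Illustrative usage sketch (not part of the original file): with the pretrained
544 | # checkpoint from download.sh, FAN.get_heatmap() takes images in [-1, 1] and, with
545 | # b_preprocess=True, returns the pair of (N, 1, H, W) masks in [0, 1] produced by
546 | # preprocess() above (the full mask, and a mask without the face line and mouth):
547 | #
548 | #   fan = FAN(fname_pretrained='expr/checkpoints/wing.ckpt').eval()
549 | #   masks = fan.get_heatmap(x)  # x: (N, 3, H, W) images in [-1, 1]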
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # StarGAN v2
4 | # Copyright (c) 2020-present NAVER Corp.
5 | #
6 | # This work is licensed under the Creative Commons Attribution-NonCommercial
7 | # 4.0 International License. To view a copy of this license, visit
8 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
9 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
10 |
11 | FILE=$1
12 |
13 | if [ $FILE == "pretrained-network-celeba-hq" ]; then
14 | URL=https://www.dropbox.com/s/96fmei6c93o8b8t/100000_nets_ema.ckpt?dl=0
15 | mkdir -p ./expr/checkpoints/celeba_hq
16 | OUT_FILE=./expr/checkpoints/celeba_hq/100000_nets_ema.ckpt
17 | wget -N $URL -O $OUT_FILE
18 |
19 | elif [ $FILE == "pretrained-network-afhq" ]; then
20 | URL=https://www.dropbox.com/s/etwm810v25h42sn/100000_nets_ema.ckpt?dl=0
21 | mkdir -p ./expr/checkpoints/afhq
22 | OUT_FILE=./expr/checkpoints/afhq/100000_nets_ema.ckpt
23 | wget -N $URL -O $OUT_FILE
24 |
25 | elif [ $FILE == "wing" ]; then
26 | URL=https://www.dropbox.com/s/tjxpypwpt38926e/wing.ckpt?dl=0
27 | mkdir -p ./expr/checkpoints/
28 | OUT_FILE=./expr/checkpoints/wing.ckpt
29 | wget -N $URL -O $OUT_FILE
30 | URL=https://www.dropbox.com/s/91fth49gyb7xksk/celeba_lm_mean.npz?dl=0
31 | OUT_FILE=./expr/checkpoints/celeba_lm_mean.npz
32 | wget -N $URL -O $OUT_FILE
33 |
34 | elif [ $FILE == "celeba-hq-dataset" ]; then
35 | URL=https://www.dropbox.com/s/f7pvjij2xlpff59/celeba_hq.zip?dl=0
36 | ZIP_FILE=./data/celeba_hq.zip
37 | mkdir -p ./data
38 | wget -N $URL -O $ZIP_FILE
39 | unzip $ZIP_FILE -d ./data
40 | rm $ZIP_FILE
41 |
42 | elif [ $FILE == "afhq-dataset" ]; then
43 | URL=https://www.dropbox.com/s/t9l9o3vsx2jai3z/afhq.zip?dl=0
44 | ZIP_FILE=./data/afhq.zip
45 | mkdir -p ./data
46 | wget -N $URL -O $ZIP_FILE
47 | unzip $ZIP_FILE -d ./data
48 | rm $ZIP_FILE
49 |
50 | elif [ $FILE == "afhq-v2-dataset" ]; then
51 | #URL=https://www.dropbox.com/s/scckftx13grwmiv/afhq_v2.zip?dl=0
52 | URL=https://www.dropbox.com/s/vkzjokiwof5h8w6/afhq_v2.zip?dl=0
53 | ZIP_FILE=./data/afhq_v2.zip
54 | mkdir -p ./data
55 | wget -N $URL -O $ZIP_FILE
56 | unzip $ZIP_FILE -d ./data
57 | rm $ZIP_FILE
58 |
59 | else
60 |     echo "Available arguments are pretrained-network-celeba-hq, pretrained-network-afhq, wing, celeba-hq-dataset, afhq-dataset, and afhq-v2-dataset."
61 | exit 1
62 |
63 | fi
64 |
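65 | # Example usage (illustrative; run from the repository root):
66 | #   bash download.sh pretrained-network-celeba-hq
67 | #   bash download.sh wing
68 | #   bash download.sh celeba-hq-dataset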
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | StarGAN v2
3 | Copyright (c) 2020-present NAVER Corp.
4 |
5 | This work is licensed under the Creative Commons Attribution-NonCommercial
6 | 4.0 International License. To view a copy of this license, visit
7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
9 | """
10 |
11 | import os
12 | import argparse
13 |
14 | from munch import Munch
15 | from torch.backends import cudnn
16 | import torch
17 |
18 | from core.data_loader import get_train_loader
19 | from core.data_loader import get_test_loader
20 | from core.solver import Solver
21 |
22 |
23 | def str2bool(v):
24 |     return v.lower() in ('true',)
25 |
26 |
27 | def subdirs(dname):
28 | return [d for d in os.listdir(dname)
29 | if os.path.isdir(os.path.join(dname, d))]
30 |
31 |
32 | def main(args):
33 | print(args)
34 | cudnn.benchmark = True
35 | torch.manual_seed(args.seed)
36 |
37 | solver = Solver(args)
38 |
39 | if args.mode == 'train':
40 | assert len(subdirs(args.train_img_dir)) == args.num_domains
41 | assert len(subdirs(args.val_img_dir)) == args.num_domains
42 | loaders = Munch(src=get_train_loader(root=args.train_img_dir,
43 | which='source',
44 | img_size=args.img_size,
45 | batch_size=args.batch_size,
46 | prob=args.randcrop_prob,
47 | num_workers=args.num_workers),
48 | ref=get_train_loader(root=args.train_img_dir,
49 | which='reference',
50 | img_size=args.img_size,
51 | batch_size=args.batch_size,
52 | prob=args.randcrop_prob,
53 | num_workers=args.num_workers),
54 | val=get_test_loader(root=args.val_img_dir,
55 | img_size=args.img_size,
56 | batch_size=args.val_batch_size,
57 | shuffle=True,
58 | num_workers=args.num_workers))
59 | solver.train(loaders)
60 | elif args.mode == 'sample':
61 | assert len(subdirs(args.src_dir)) == args.num_domains
62 | assert len(subdirs(args.ref_dir)) == args.num_domains
63 | loaders = Munch(src=get_test_loader(root=args.src_dir,
64 | img_size=args.img_size,
65 | batch_size=args.val_batch_size,
66 | shuffle=False,
67 | num_workers=args.num_workers),
68 | ref=get_test_loader(root=args.ref_dir,
69 | img_size=args.img_size,
70 | batch_size=args.val_batch_size,
71 | shuffle=False,
72 | num_workers=args.num_workers))
73 | solver.sample(loaders)
74 | elif args.mode == 'eval':
75 | solver.evaluate()
76 | elif args.mode == 'align':
77 | from core.wing import align_faces
78 | align_faces(args, args.inp_dir, args.out_dir)
79 | else:
80 | raise NotImplementedError
81 |
82 |
83 | if __name__ == '__main__':
84 | parser = argparse.ArgumentParser()
85 |
86 | # model arguments
87 | parser.add_argument('--img_size', type=int, default=256,
88 | help='Image resolution')
89 | parser.add_argument('--num_domains', type=int, default=2,
90 | help='Number of domains')
91 | parser.add_argument('--latent_dim', type=int, default=16,
92 | help='Latent vector dimension')
93 | parser.add_argument('--hidden_dim', type=int, default=512,
94 | help='Hidden dimension of mapping network')
95 | parser.add_argument('--style_dim', type=int, default=64,
96 | help='Style code dimension')
97 |
98 | # weight for objective functions
99 | parser.add_argument('--lambda_reg', type=float, default=1,
100 | help='Weight for R1 regularization')
101 | parser.add_argument('--lambda_cyc', type=float, default=1,
102 | help='Weight for cyclic consistency loss')
103 | parser.add_argument('--lambda_sty', type=float, default=1,
104 | help='Weight for style reconstruction loss')
105 | parser.add_argument('--lambda_ds', type=float, default=1,
106 | help='Weight for diversity sensitive loss')
107 | parser.add_argument('--ds_iter', type=int, default=100000,
108 | help='Number of iterations to optimize diversity sensitive loss')
109 | parser.add_argument('--w_hpf', type=float, default=1,
110 |                         help='Weight for high-pass filtering')
111 |
112 | # training arguments
113 | parser.add_argument('--randcrop_prob', type=float, default=0.5,
114 |                         help='Probability of using random-resized cropping')
115 | parser.add_argument('--total_iters', type=int, default=100000,
116 | help='Number of total iterations')
117 | parser.add_argument('--resume_iter', type=int, default=0,
118 | help='Iterations to resume training/testing')
119 | parser.add_argument('--batch_size', type=int, default=8,
120 | help='Batch size for training')
121 | parser.add_argument('--val_batch_size', type=int, default=32,
122 | help='Batch size for validation')
123 | parser.add_argument('--lr', type=float, default=1e-4,
124 | help='Learning rate for D, E and G')
125 | parser.add_argument('--f_lr', type=float, default=1e-6,
126 | help='Learning rate for F')
127 | parser.add_argument('--beta1', type=float, default=0.0,
128 | help='Decay rate for 1st moment of Adam')
129 | parser.add_argument('--beta2', type=float, default=0.99,
130 | help='Decay rate for 2nd moment of Adam')
131 | parser.add_argument('--weight_decay', type=float, default=1e-4,
132 | help='Weight decay for optimizer')
133 | parser.add_argument('--num_outs_per_domain', type=int, default=10,
134 | help='Number of generated images per domain during sampling')
135 |
136 | # misc
137 | parser.add_argument('--mode', type=str, required=True,
138 | choices=['train', 'sample', 'eval', 'align'],
139 |                         help='Mode of operation: train, sample, eval, or align')
140 | parser.add_argument('--num_workers', type=int, default=4,
141 | help='Number of workers used in DataLoader')
142 | parser.add_argument('--seed', type=int, default=777,
143 | help='Seed for random number generator')
144 |
145 | # directory for training
146 | parser.add_argument('--train_img_dir', type=str, default='data/celeba_hq/train',
147 | help='Directory containing training images')
148 | parser.add_argument('--val_img_dir', type=str, default='data/celeba_hq/val',
149 | help='Directory containing validation images')
150 | parser.add_argument('--sample_dir', type=str, default='expr/samples',
151 | help='Directory for saving generated images')
152 | parser.add_argument('--checkpoint_dir', type=str, default='expr/checkpoints',
153 | help='Directory for saving network checkpoints')
154 |
155 | # directory for calculating metrics
156 | parser.add_argument('--eval_dir', type=str, default='expr/eval',
157 | help='Directory for saving metrics, i.e., FID and LPIPS')
158 |
159 | # directory for testing
160 | parser.add_argument('--result_dir', type=str, default='expr/results',
161 | help='Directory for saving generated images and videos')
162 | parser.add_argument('--src_dir', type=str, default='assets/representative/celeba_hq/src',
163 | help='Directory containing input source images')
164 | parser.add_argument('--ref_dir', type=str, default='assets/representative/celeba_hq/ref',
165 | help='Directory containing input reference images')
166 | parser.add_argument('--inp_dir', type=str, default='assets/representative/custom/female',
167 | help='input directory when aligning faces')
168 | parser.add_argument('--out_dir', type=str, default='assets/representative/celeba_hq/src/female',
169 | help='output directory when aligning faces')
170 |
171 | # face alignment
172 | parser.add_argument('--wing_path', type=str, default='expr/checkpoints/wing.ckpt')
173 | parser.add_argument('--lm_path', type=str, default='expr/checkpoints/celeba_lm_mean.npz')
174 |
175 | # step size
176 | parser.add_argument('--print_every', type=int, default=10)
177 | parser.add_argument('--sample_every', type=int, default=5000)
178 | parser.add_argument('--save_every', type=int, default=10000)
179 | parser.add_argument('--eval_every', type=int, default=50000)
180 |
181 | args = parser.parse_args()
182 | main(args)
183 |
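184 | # Illustrative invocations (paths follow the argument defaults above and the
185 | # checkpoint locations used by download.sh; adjust them to your setup):
186 | #   python main.py --mode sample --num_domains 2 --resume_iter 100000 --w_hpf 1 \
187 | #                  --checkpoint_dir expr/checkpoints/celeba_hq \
188 | #                  --result_dir expr/results/celeba_hq \
189 | #                  --src_dir assets/representative/celeba_hq/src \
190 | #                  --ref_dir assets/representative/celeba_hq/ref
191 | #   python main.py --mode align \
192 | #                  --inp_dir assets/representative/custom/female \
193 | #                  --out_dir assets/representative/celeba_hq/src/female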
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/metrics/__init__.py
--------------------------------------------------------------------------------
/metrics/eval.py:
--------------------------------------------------------------------------------
1 | """
2 | StarGAN v2
3 | Copyright (c) 2020-present NAVER Corp.
4 |
5 | This work is licensed under the Creative Commons Attribution-NonCommercial
6 | 4.0 International License. To view a copy of this license, visit
7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
9 | """
10 |
11 | import os
12 | import shutil
13 | from collections import OrderedDict
14 | from tqdm import tqdm
15 |
16 | import numpy as np
17 | import torch
18 |
19 | from metrics.fid import calculate_fid_given_paths
20 | from metrics.lpips import calculate_lpips_given_images
21 | from core.data_loader import get_eval_loader
22 | from core import utils
23 |
24 |
25 | @torch.no_grad()
26 | def calculate_metrics(nets, args, step, mode):
27 | print('Calculating evaluation metrics...')
28 | assert mode in ['latent', 'reference']
29 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30 |
31 | domains = os.listdir(args.val_img_dir)
32 | domains.sort()
33 | num_domains = len(domains)
34 | print('Number of domains: %d' % num_domains)
35 |
36 | lpips_dict = OrderedDict()
37 | for trg_idx, trg_domain in enumerate(domains):
38 | src_domains = [x for x in domains if x != trg_domain]
39 |
40 | if mode == 'reference':
41 | path_ref = os.path.join(args.val_img_dir, trg_domain)
42 | loader_ref = get_eval_loader(root=path_ref,
43 | img_size=args.img_size,
44 | batch_size=args.val_batch_size,
45 | imagenet_normalize=False,
46 | drop_last=True)
47 |
48 | for src_idx, src_domain in enumerate(src_domains):
49 | path_src = os.path.join(args.val_img_dir, src_domain)
50 | loader_src = get_eval_loader(root=path_src,
51 | img_size=args.img_size,
52 | batch_size=args.val_batch_size,
53 | imagenet_normalize=False)
54 |
55 | task = '%s2%s' % (src_domain, trg_domain)
56 | path_fake = os.path.join(args.eval_dir, task)
57 | shutil.rmtree(path_fake, ignore_errors=True)
58 | os.makedirs(path_fake)
59 |
60 | lpips_values = []
61 | print('Generating images and calculating LPIPS for %s...' % task)
62 | for i, x_src in enumerate(tqdm(loader_src, total=len(loader_src))):
63 | N = x_src.size(0)
64 | x_src = x_src.to(device)
65 | y_trg = torch.tensor([trg_idx] * N).to(device)
66 | masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None
67 |
68 |                 # generate num_outs_per_domain outputs from the same input
69 | group_of_images = []
70 | for j in range(args.num_outs_per_domain):
71 | if mode == 'latent':
72 | z_trg = torch.randn(N, args.latent_dim).to(device)
73 | s_trg = nets.mapping_network(z_trg, y_trg)
74 | else:
75 | try:
76 | x_ref = next(iter_ref).to(device)
77 |                         except (NameError, StopIteration):  # first use or exhausted reference loader
78 | iter_ref = iter(loader_ref)
79 | x_ref = next(iter_ref).to(device)
80 |
81 | if x_ref.size(0) > N:
82 | x_ref = x_ref[:N]
83 | s_trg = nets.style_encoder(x_ref, y_trg)
84 |
85 | x_fake = nets.generator(x_src, s_trg, masks=masks)
86 | group_of_images.append(x_fake)
87 |
88 | # save generated images to calculate FID later
89 | for k in range(N):
90 | filename = os.path.join(
91 | path_fake,
92 | '%.4i_%.2i.png' % (i*args.val_batch_size+(k+1), j+1))
93 | utils.save_image(x_fake[k], ncol=1, filename=filename)
94 |
95 | lpips_value = calculate_lpips_given_images(group_of_images)
96 | lpips_values.append(lpips_value)
97 |
98 | # calculate LPIPS for each task (e.g. cat2dog, dog2cat)
99 | lpips_mean = np.array(lpips_values).mean()
100 | lpips_dict['LPIPS_%s/%s' % (mode, task)] = lpips_mean
101 |
102 | # delete dataloaders
103 | del loader_src
104 | if mode == 'reference':
105 | del loader_ref
106 | del iter_ref
107 |
108 | # calculate the average LPIPS for all tasks
109 | lpips_mean = 0
110 | for _, value in lpips_dict.items():
111 | lpips_mean += value / len(lpips_dict)
112 | lpips_dict['LPIPS_%s/mean' % mode] = lpips_mean
113 |
114 | # report LPIPS values
115 | filename = os.path.join(args.eval_dir, 'LPIPS_%.5i_%s.json' % (step, mode))
116 | utils.save_json(lpips_dict, filename)
117 |
118 | # calculate and report fid values
119 | calculate_fid_for_all_tasks(args, domains, step=step, mode=mode)
120 |
121 |
122 | def calculate_fid_for_all_tasks(args, domains, step, mode):
123 | print('Calculating FID for all tasks...')
124 | fid_values = OrderedDict()
125 | for trg_domain in domains:
126 | src_domains = [x for x in domains if x != trg_domain]
127 |
128 | for src_domain in src_domains:
129 | task = '%s2%s' % (src_domain, trg_domain)
130 | path_real = os.path.join(args.train_img_dir, trg_domain)
131 | path_fake = os.path.join(args.eval_dir, task)
132 | print('Calculating FID for %s...' % task)
133 | fid_value = calculate_fid_given_paths(
134 | paths=[path_real, path_fake],
135 | img_size=args.img_size,
136 | batch_size=args.val_batch_size)
137 | fid_values['FID_%s/%s' % (mode, task)] = fid_value
138 |
139 | # calculate the average FID for all tasks
140 | fid_mean = 0
141 | for _, value in fid_values.items():
142 | fid_mean += value / len(fid_values)
143 | fid_values['FID_%s/mean' % mode] = fid_mean
144 |
145 | # report FID values
146 | filename = os.path.join(args.eval_dir, 'FID_%.5i_%s.json' % (step, mode))
147 | utils.save_json(fid_values, filename)
148 |
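149 | # Illustrative sketch (not part of the original file): the solver is expected to
150 | # invoke calculate_metrics() for both evaluation modes, e.g.
151 | #   calculate_metrics(nets_ema, args, step=args.resume_iter, mode='latent')
152 | #   calculate_metrics(nets_ema, args, step=args.resume_iter, mode='reference')
153 | # where nets_ema stands for whichever Munch of networks the caller maintains.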
--------------------------------------------------------------------------------
/metrics/fid.py:
--------------------------------------------------------------------------------
1 | """
2 | StarGAN v2
3 | Copyright (c) 2020-present NAVER Corp.
4 |
5 | This work is licensed under the Creative Commons Attribution-NonCommercial
6 | 4.0 International License. To view a copy of this license, visit
7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
9 | """
10 |
11 | import os
12 | import argparse
13 |
14 | import torch
15 | import torch.nn as nn
16 | import numpy as np
17 | from torchvision import models
18 | from scipy import linalg
19 | from core.data_loader import get_eval_loader
20 |
21 | try:
22 | from tqdm import tqdm
23 | except ImportError:
24 |     def tqdm(x, *args, **kwargs): return x  # fallback must accept tqdm's kwargs (e.g. total)
25 |
26 |
27 | class InceptionV3(nn.Module):
28 | def __init__(self):
29 | super().__init__()
30 | inception = models.inception_v3(pretrained=True)
31 | self.block1 = nn.Sequential(
32 | inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3,
33 | inception.Conv2d_2b_3x3,
34 | nn.MaxPool2d(kernel_size=3, stride=2))
35 | self.block2 = nn.Sequential(
36 | inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3,
37 | nn.MaxPool2d(kernel_size=3, stride=2))
38 | self.block3 = nn.Sequential(
39 | inception.Mixed_5b, inception.Mixed_5c,
40 | inception.Mixed_5d, inception.Mixed_6a,
41 | inception.Mixed_6b, inception.Mixed_6c,
42 | inception.Mixed_6d, inception.Mixed_6e)
43 | self.block4 = nn.Sequential(
44 | inception.Mixed_7a, inception.Mixed_7b,
45 | inception.Mixed_7c,
46 | nn.AdaptiveAvgPool2d(output_size=(1, 1)))
47 |
48 | def forward(self, x):
49 | x = self.block1(x)
50 | x = self.block2(x)
51 | x = self.block3(x)
52 | x = self.block4(x)
53 | return x.view(x.size(0), -1)
54 |
55 |
56 | def frechet_distance(mu, cov, mu2, cov2):
57 | cc, _ = linalg.sqrtm(np.dot(cov, cov2), disp=False)
58 |     dist = np.sum((mu - mu2)**2) + np.trace(cov + cov2 - 2*cc)
59 | return np.real(dist)
60 |
61 |
62 | @torch.no_grad()
63 | def calculate_fid_given_paths(paths, img_size=256, batch_size=50):
64 | print('Calculating FID given paths %s and %s...' % (paths[0], paths[1]))
65 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
66 | inception = InceptionV3().eval().to(device)
67 | loaders = [get_eval_loader(path, img_size, batch_size) for path in paths]
68 |
69 | mu, cov = [], []
70 | for loader in loaders:
71 | actvs = []
72 | for x in tqdm(loader, total=len(loader)):
73 | actv = inception(x.to(device))
74 | actvs.append(actv)
75 | actvs = torch.cat(actvs, dim=0).cpu().detach().numpy()
76 | mu.append(np.mean(actvs, axis=0))
77 | cov.append(np.cov(actvs, rowvar=False))
78 | fid_value = frechet_distance(mu[0], cov[0], mu[1], cov[1])
79 | return fid_value
80 |
81 |
82 | if __name__ == '__main__':
83 | parser = argparse.ArgumentParser()
84 | parser.add_argument('--paths', type=str, nargs=2, help='paths to real and fake images')
85 | parser.add_argument('--img_size', type=int, default=256, help='image resolution')
86 | parser.add_argument('--batch_size', type=int, default=64, help='batch size to use')
87 | args = parser.parse_args()
88 | fid_value = calculate_fid_given_paths(args.paths, args.img_size, args.batch_size)
89 | print('FID: ', fid_value)
90 |
91 | # python -m metrics.fid --paths PATH_REAL PATH_FAKE
--------------------------------------------------------------------------------
/metrics/lpips.py:
--------------------------------------------------------------------------------
1 | """
2 | StarGAN v2
3 | Copyright (c) 2020-present NAVER Corp.
4 |
5 | This work is licensed under the Creative Commons Attribution-NonCommercial
6 | 4.0 International License. To view a copy of this license, visit
7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torchvision import models
14 |
15 |
16 | def normalize(x, eps=1e-10):
17 | return x * torch.rsqrt(torch.sum(x**2, dim=1, keepdim=True) + eps)
18 |
19 |
20 | class AlexNet(nn.Module):
21 | def __init__(self):
22 | super().__init__()
23 | self.layers = models.alexnet(pretrained=True).features
24 | self.channels = []
25 | for layer in self.layers:
26 | if isinstance(layer, nn.Conv2d):
27 | self.channels.append(layer.out_channels)
28 |
29 | def forward(self, x):
30 | fmaps = []
31 | for layer in self.layers:
32 | x = layer(x)
33 | if isinstance(layer, nn.ReLU):
34 | fmaps.append(x)
35 | return fmaps
36 |
37 |
38 | class Conv1x1(nn.Module):
39 | def __init__(self, in_channels, out_channels=1):
40 | super().__init__()
41 | self.main = nn.Sequential(
42 | nn.Dropout(0.5),
43 | nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False))
44 |
45 | def forward(self, x):
46 | return self.main(x)
47 |
48 |
49 | class LPIPS(nn.Module):
50 | def __init__(self):
51 | super().__init__()
52 | self.alexnet = AlexNet()
53 | self.lpips_weights = nn.ModuleList()
54 | for channels in self.alexnet.channels:
55 | self.lpips_weights.append(Conv1x1(channels, 1))
56 | self._load_lpips_weights()
57 |         # imagenet normalization for range [-1, 1]; buffers move with .to(device)
58 |         self.register_buffer('mu', torch.tensor([-0.03, -0.088, -0.188]).view(1, 3, 1, 1))
59 |         self.register_buffer('sigma', torch.tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1))
60 |
61 | def _load_lpips_weights(self):
62 | own_state_dict = self.state_dict()
63 | if torch.cuda.is_available():
64 | state_dict = torch.load('metrics/lpips_weights.ckpt')
65 | else:
66 | state_dict = torch.load('metrics/lpips_weights.ckpt',
67 | map_location=torch.device('cpu'))
68 | for name, param in state_dict.items():
69 | if name in own_state_dict:
70 | own_state_dict[name].copy_(param)
71 |
72 | def forward(self, x, y):
73 | x = (x - self.mu) / self.sigma
74 | y = (y - self.mu) / self.sigma
75 | x_fmaps = self.alexnet(x)
76 | y_fmaps = self.alexnet(y)
77 | lpips_value = 0
78 | for x_fmap, y_fmap, conv1x1 in zip(x_fmaps, y_fmaps, self.lpips_weights):
79 | x_fmap = normalize(x_fmap)
80 | y_fmap = normalize(y_fmap)
81 | lpips_value += torch.mean(conv1x1((x_fmap - y_fmap)**2))
82 | return lpips_value
83 |
84 |
85 | @torch.no_grad()
86 | def calculate_lpips_given_images(group_of_images):
87 | # group_of_images = [torch.randn(N, C, H, W) for _ in range(10)]
88 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
89 | lpips = LPIPS().eval().to(device)
90 | lpips_values = []
91 | num_rand_outputs = len(group_of_images)
92 |
93 | # calculate the average of pairwise distances among all random outputs
94 | for i in range(num_rand_outputs-1):
95 | for j in range(i+1, num_rand_outputs):
96 | lpips_values.append(lpips(group_of_images[i], group_of_images[j]))
97 | lpips_value = torch.mean(torch.stack(lpips_values, dim=0))
98 | return lpips_value.item()
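99 | 
100 | 
101 | # Illustrative sketch (not part of the original file): eval.py builds
102 | # group_of_images from num_outs_per_domain translations of the same source batch
103 | # and reports the mean pairwise distance, e.g.
104 | #   group_of_images = [nets.generator(x_src, s_trg, masks=masks) for s_trg in styles]
105 | #   diversity = calculate_lpips_given_images(group_of_images)
106 | # where `styles` stands for a list of style codes sampled for one target domain.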
--------------------------------------------------------------------------------
/metrics/lpips_weights.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/stargan-v2/875b70a150609e8a678ed8482562e7074cdce7e5/metrics/lpips_weights.ckpt
--------------------------------------------------------------------------------