├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── MANIFEST.in
├── README.md
├── VERSION
├── VGGFace2-HQ.png
├── VGGFace2-HQ.pptx
├── docs
└── img
│ ├── 2.png
│ ├── VGGFace2-HQ.png
│ ├── girl2-RGB.png
│ ├── girl2.gif
│ ├── logo.png
│ ├── simswap.png
│ ├── title.png
│ └── vggface2_hq_compare.png
├── gfpgan
├── __init__.py
├── archs
│ ├── __init__.py
│ ├── arcface_arch.py
│ ├── gfpganv1_arch.py
│ ├── gfpganv1_clean_arch.py
│ └── stylegan2_clean_arch.py
├── data
│ ├── __init__.py
│ └── ffhq_degradation_dataset.py
├── models
│ ├── __init__.py
│ └── gfpgan_model.py
├── train.py
├── utils.py
└── weights
│ └── README.md
├── inference_gfpgan.py
├── insightface_func
├── __init__.py
├── face_detect_crop.py
├── face_detect_crop_ffhq_newarcAlign.py
└── utils
│ ├── face_align.py
│ └── face_align_ffhqandnewarc.py
├── options
├── train_gfpgan_v1.yml
└── train_gfpgan_v1_simple.yml
├── requirements.txt
├── scripts
├── crop_align_vggface2_FFHQalign.py
├── crop_align_vggface2_FFHQalignandNewarcalign.py
├── inference_gfpgan_forvggface2.py
└── vggface_dataset.py
├── setup.cfg
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # ignored folders
2 | datasets/*
3 | experiments/*
4 | results/*
5 | tb_logger/*
6 | wandb/*
7 | tmp/*
8 |
9 | version.py
10 |
11 | # Byte-compiled / optimized / DLL files
12 | __pycache__/
13 | *.py[cod]
14 | *$py.class
15 |
16 | # C extensions
17 | *.so
18 |
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | pip-wheel-metadata/
34 | share/python-wheels/
35 | *.egg-info/
36 | .installed.cfg
37 | *.egg
38 | MANIFEST
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .nox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | *.py,cover
61 | .hypothesis/
62 | .pytest_cache/
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # Django stuff:
69 | *.log
70 | local_settings.py
71 | db.sqlite3
72 | db.sqlite3-journal
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 |
94 | # pyenv
95 | .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105 | __pypackages__/
106 |
107 | # Celery stuff
108 | celerybeat-schedule
109 | celerybeat.pid
110 |
111 | # SageMath parsed files
112 | *.sage.py
113 |
114 | # Environments
115 | .env
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 |
138 | # Pyre type checker
139 | .pyre/
140 | *.ppt
141 | *.pptx
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | # flake8
3 | - repo: https://github.com/PyCQA/flake8
4 | rev: 3.8.3
5 | hooks:
6 | - id: flake8
7 | args: ["--config=setup.cfg", "--ignore=W504, W503"]
8 |
9 | # modify known_third_party
10 | - repo: https://github.com/asottile/seed-isort-config
11 | rev: v2.2.0
12 | hooks:
13 | - id: seed-isort-config
14 |
15 | # isort
16 | - repo: https://github.com/timothycrosley/isort
17 | rev: 5.2.2
18 | hooks:
19 | - id: isort
20 |
21 | # yapf
22 | - repo: https://github.com/pre-commit/mirrors-yapf
23 | rev: v0.30.0
24 | hooks:
25 | - id: yapf
26 |
27 | # codespell
28 | - repo: https://github.com/codespell-project/codespell
29 | rev: v2.1.0
30 | hooks:
31 | - id: codespell
32 |
33 | # pre-commit-hooks
34 | - repo: https://github.com/pre-commit/pre-commit-hooks
35 | rev: v3.2.0
36 | hooks:
37 | - id: trailing-whitespace # Trim trailing whitespace
38 | - id: check-yaml # Attempt to load all yaml files to verify syntax
39 | - id: check-merge-conflict # Check for files that contain merge conflict strings
40 | - id: double-quote-string-fixer # Replace double quoted strings with single quoted strings
41 | - id: end-of-file-fixer # Make sure files end in a newline and only a newline
42 | - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0
43 | - id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*-
44 | args: ["--remove"]
45 | - id: mixed-line-ending # Replace or check mixed line ending
46 | args: ["--fix=lf"]
47 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More_considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 | Section 1 -- Definitions.
71 |
72 | a. Adapted Material means material subject to Copyright and Similar
73 | Rights that is derived from or based upon the Licensed Material
74 | and in which the Licensed Material is translated, altered,
75 | arranged, transformed, or otherwise modified in a manner requiring
76 | permission under the Copyright and Similar Rights held by the
77 | Licensor. For purposes of this Public License, where the Licensed
78 | Material is a musical work, performance, or sound recording,
79 | Adapted Material is always produced where the Licensed Material is
80 | synched in timed relation with a moving image.
81 |
82 | b. Adapter's License means the license You apply to Your Copyright
83 | and Similar Rights in Your contributions to Adapted Material in
84 | accordance with the terms and conditions of this Public License.
85 |
86 | c. Copyright and Similar Rights means copyright and/or similar rights
87 | closely related to copyright including, without limitation,
88 | performance, broadcast, sound recording, and Sui Generis Database
89 | Rights, without regard to how the rights are labeled or
90 | categorized. For purposes of this Public License, the rights
91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
92 | Rights.
93 | d. Effective Technological Measures means those measures that, in the
94 | absence of proper authority, may not be circumvented under laws
95 | fulfilling obligations under Article 11 of the WIPO Copyright
96 | Treaty adopted on December 20, 1996, and/or similar international
97 | agreements.
98 |
99 | e. Exceptions and Limitations means fair use, fair dealing, and/or
100 | any other exception or limitation to Copyright and Similar Rights
101 | that applies to Your use of the Licensed Material.
102 |
103 | f. Licensed Material means the artistic or literary work, database,
104 | or other material to which the Licensor applied this Public
105 | License.
106 |
107 | g. Licensed Rights means the rights granted to You subject to the
108 | terms and conditions of this Public License, which are limited to
109 | all Copyright and Similar Rights that apply to Your use of the
110 | Licensed Material and that the Licensor has authority to license.
111 |
112 | h. Licensor means the individual(s) or entity(ies) granting rights
113 | under this Public License.
114 |
115 | i. NonCommercial means not primarily intended for or directed towards
116 | commercial advantage or monetary compensation. For purposes of
117 | this Public License, the exchange of the Licensed Material for
118 | other material subject to Copyright and Similar Rights by digital
119 | file-sharing or similar means is NonCommercial provided there is
120 | no payment of monetary compensation in connection with the
121 | exchange.
122 |
123 | j. Share means to provide material to the public by any means or
124 | process that requires permission under the Licensed Rights, such
125 | as reproduction, public display, public performance, distribution,
126 | dissemination, communication, or importation, and to make material
127 | available to the public including in ways that members of the
128 | public may access the material from a place and at a time
129 | individually chosen by them.
130 |
131 | k. Sui Generis Database Rights means rights other than copyright
132 | resulting from Directive 96/9/EC of the European Parliament and of
133 | the Council of 11 March 1996 on the legal protection of databases,
134 | as amended and/or succeeded, as well as other essentially
135 | equivalent rights anywhere in the world.
136 |
137 | l. You means the individual or entity exercising the Licensed Rights
138 | under this Public License. Your has a corresponding meaning.
139 |
140 | Section 2 -- Scope.
141 |
142 | a. License grant.
143 |
144 | 1. Subject to the terms and conditions of this Public License,
145 | the Licensor hereby grants You a worldwide, royalty-free,
146 | non-sublicensable, non-exclusive, irrevocable license to
147 | exercise the Licensed Rights in the Licensed Material to:
148 |
149 | a. reproduce and Share the Licensed Material, in whole or
150 | in part, for NonCommercial purposes only; and
151 |
152 | b. produce, reproduce, and Share Adapted Material for
153 | NonCommercial purposes only.
154 |
155 | 2. Exceptions and Limitations. For the avoidance of doubt, where
156 | Exceptions and Limitations apply to Your use, this Public
157 | License does not apply, and You do not need to comply with
158 | its terms and conditions.
159 |
160 | 3. Term. The term of this Public License is specified in Section
161 | 6(a).
162 |
163 | 4. Media and formats; technical modifications allowed. The
164 | Licensor authorizes You to exercise the Licensed Rights in
165 | all media and formats whether now known or hereafter created,
166 | and to make technical modifications necessary to do so. The
167 | Licensor waives and/or agrees not to assert any right or
168 | authority to forbid You from making technical modifications
169 | necessary to exercise the Licensed Rights, including
170 | technical modifications necessary to circumvent Effective
171 | Technological Measures. For purposes of this Public License,
172 | simply making modifications authorized by this Section 2(a)
173 | (4) never produces Adapted Material.
174 |
175 | 5. Downstream recipients.
176 |
177 | a. Offer from the Licensor -- Licensed Material. Every
178 | recipient of the Licensed Material automatically
179 | receives an offer from the Licensor to exercise the
180 | Licensed Rights under the terms and conditions of this
181 | Public License.
182 |
183 | b. No downstream restrictions. You may not offer or impose
184 | any additional or different terms or conditions on, or
185 | apply any Effective Technological Measures to, the
186 | Licensed Material if doing so restricts exercise of the
187 | Licensed Rights by any recipient of the Licensed
188 | Material.
189 |
190 | 6. No endorsement. Nothing in this Public License constitutes or
191 | may be construed as permission to assert or imply that You
192 | are, or that Your use of the Licensed Material is, connected
193 | with, or sponsored, endorsed, or granted official status by,
194 | the Licensor or others designated to receive attribution as
195 | provided in Section 3(a)(1)(A)(i).
196 |
197 | b. Other rights.
198 |
199 | 1. Moral rights, such as the right of integrity, are not
200 | licensed under this Public License, nor are publicity,
201 | privacy, and/or other similar personality rights; however, to
202 | the extent possible, the Licensor waives and/or agrees not to
203 | assert any such rights held by the Licensor to the limited
204 | extent necessary to allow You to exercise the Licensed
205 | Rights, but not otherwise.
206 |
207 | 2. Patent and trademark rights are not licensed under this
208 | Public License.
209 |
210 | 3. To the extent possible, the Licensor waives any right to
211 | collect royalties from You for the exercise of the Licensed
212 | Rights, whether directly or through a collecting society
213 | under any voluntary or waivable statutory or compulsory
214 | licensing scheme. In all other cases the Licensor expressly
215 | reserves any right to collect such royalties, including when
216 | the Licensed Material is used other than for NonCommercial
217 | purposes.
218 |
219 | Section 3 -- License Conditions.
220 |
221 | Your exercise of the Licensed Rights is expressly made subject to the
222 | following conditions.
223 |
224 | a. Attribution.
225 |
226 | 1. If You Share the Licensed Material (including in modified
227 | form), You must:
228 |
229 | a. retain the following if it is supplied by the Licensor
230 | with the Licensed Material:
231 |
232 | i. identification of the creator(s) of the Licensed
233 | Material and any others designated to receive
234 | attribution, in any reasonable manner requested by
235 | the Licensor (including by pseudonym if
236 | designated);
237 |
238 | ii. a copyright notice;
239 |
240 | iii. a notice that refers to this Public License;
241 |
242 | iv. a notice that refers to the disclaimer of
243 | warranties;
244 |
245 | v. a URI or hyperlink to the Licensed Material to the
246 | extent reasonably practicable;
247 |
248 | b. indicate if You modified the Licensed Material and
249 | retain an indication of any previous modifications; and
250 |
251 | c. indicate the Licensed Material is licensed under this
252 | Public License, and include the text of, or the URI or
253 | hyperlink to, this Public License.
254 |
255 | 2. You may satisfy the conditions in Section 3(a)(1) in any
256 | reasonable manner based on the medium, means, and context in
257 | which You Share the Licensed Material. For example, it may be
258 | reasonable to satisfy the conditions by providing a URI or
259 | hyperlink to a resource that includes the required
260 | information.
261 |
262 | 3. If requested by the Licensor, You must remove any of the
263 | information required by Section 3(a)(1)(A) to the extent
264 | reasonably practicable.
265 |
266 | 4. If You Share Adapted Material You produce, the Adapter's
267 | License You apply must not prevent recipients of the Adapted
268 | Material from complying with this Public License.
269 |
270 | Section 4 -- Sui Generis Database Rights.
271 |
272 | Where the Licensed Rights include Sui Generis Database Rights that
273 | apply to Your use of the Licensed Material:
274 |
275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276 | to extract, reuse, reproduce, and Share all or a substantial
277 | portion of the contents of the database for NonCommercial purposes
278 | only;
279 |
280 | b. if You include all or a substantial portion of the database
281 | contents in a database in which You have Sui Generis Database
282 | Rights, then the database in which You have Sui Generis Database
283 | Rights (but not its individual contents) is Adapted Material; and
284 |
285 | c. You must comply with the conditions in Section 3(a) if You Share
286 | all or a substantial portion of the contents of the database.
287 |
288 | For the avoidance of doubt, this Section 4 supplements and does not
289 | replace Your obligations under this Public License where the Licensed
290 | Rights include other Copyright and Similar Rights.
291 |
292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293 |
294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304 |
305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314 |
315 | c. The disclaimer of warranties and limitation of liability provided
316 | above shall be interpreted in a manner that, to the extent
317 | possible, most closely approximates an absolute disclaimer and
318 | waiver of all liability.
319 |
320 | Section 6 -- Term and Termination.
321 |
322 | a. This Public License applies for the term of the Copyright and
323 | Similar Rights licensed here. However, if You fail to comply with
324 | this Public License, then Your rights under this Public License
325 | terminate automatically.
326 |
327 | b. Where Your right to use the Licensed Material has terminated under
328 | Section 6(a), it reinstates:
329 |
330 | 1. automatically as of the date the violation is cured, provided
331 | it is cured within 30 days of Your discovery of the
332 | violation; or
333 |
334 | 2. upon express reinstatement by the Licensor.
335 |
336 | For the avoidance of doubt, this Section 6(b) does not affect any
337 | right the Licensor may have to seek remedies for Your violations
338 | of this Public License.
339 |
340 | c. For the avoidance of doubt, the Licensor may also offer the
341 | Licensed Material under separate terms or conditions or stop
342 | distributing the Licensed Material at any time; however, doing so
343 | will not terminate this Public License.
344 |
345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346 | License.
347 |
348 | Section 7 -- Other Terms and Conditions.
349 |
350 | a. The Licensor shall not be bound by any additional or different
351 | terms or conditions communicated by You unless expressly agreed.
352 |
353 | b. Any arrangements, understandings, or agreements regarding the
354 | Licensed Material not stated herein are separate from and
355 | independent of the terms and conditions of this Public License.
356 |
357 | Section 8 -- Interpretation.
358 |
359 | a. For the avoidance of doubt, this Public License does not, and
360 | shall not be interpreted to, reduce, limit, restrict, or impose
361 | conditions on any use of the Licensed Material that could lawfully
362 | be made without permission under this Public License.
363 |
364 | b. To the extent possible, if any provision of this Public License is
365 | deemed unenforceable, it shall be automatically reformed to the
366 | minimum extent necessary to make it enforceable. If the provision
367 | cannot be reformed, it shall be severed from this Public License
368 | without affecting the enforceability of the remaining terms and
369 | conditions.
370 |
371 | c. No term or condition of this Public License will be waived and no
372 | failure to comply consented to unless expressly agreed to by the
373 | Licensor.
374 |
375 | d. Nothing in this Public License constitutes or may be interpreted
376 | as a limitation upon, or waiver of, any privileges and immunities
377 | that apply to the Licensor or You, including from the legal
378 | processes of any jurisdiction or authority.
379 |
380 | =======================================================================
381 |
382 | Creative Commons is not a party to its public
383 | licenses. Notwithstanding, Creative Commons may elect to apply one of
384 | its public licenses to material it publishes and in those instances
385 | will be considered the “Licensor.” The text of the Creative Commons
386 | public licenses is dedicated to the public domain under the CC0 Public
387 | Domain Dedication. Except for the limited purpose of indicating that
388 | material is shared under a Creative Commons public license or as
389 | otherwise permitted by the Creative Commons policies published at
390 | creativecommons.org/policies, Creative Commons does not authorize the
391 | use of the trademark "Creative Commons" or any other trademark or logo
392 | of Creative Commons without its prior written consent including,
393 | without limitation, in connection with any unauthorized modifications
394 | to any of its public licenses or any other arrangements,
395 | understandings, or agreements concerning use of licensed material. For
396 | the avoidance of doubt, this paragraph does not form part of the
397 | public licenses.
398 |
399 | Creative Commons may be contacted at creativecommons.org.
400 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include assets/*
2 | include inputs/*
3 | include scripts/*.py
4 | include inference_gfpgan.py
5 | include VERSION
6 | include LICENSE
7 | include requirements.txt
8 | include gfpgan/weights/README.md
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VGGFace2-HQ
2 |
3 | Related paper: [TPAMI](https://github.com/neuralchen/SimSwapPlus)
4 |
5 | ## The first open source high resolution dataset for face swapping!!!
6 |
7 | A high-resolution version of [VGGFace2](https://github.com/ox-vgg/vgg_face2) for academic face editing purposes. This project uses [GFPGAN](https://github.com/TencentARC/GFPGAN) for image restoration and [insightface](https://github.com/deepinsight/insightface) for data preprocessing (crop and align).
8 |
9 | [](https://github.com/NNNNAI/VGGFace2-HQ)
10 |
11 | We provide a download link for users to download the data, and also provide guidance on how to generate the VGGFace2 dataset from scratch.
12 |
13 | If you find this project useful, please star it; that is the greatest encouragement for our work.
14 |
15 |
16 |
17 | # Get the VGGFace2-HQ dataset from cloud!
18 |
19 | We have uploaded the VGGFace2-HQ dataset to cloud storage; you can download it from either of the links below.
20 |
21 | ### Google Drive
22 |
23 | [[Google Drive]](https://drive.google.com/drive/folders/1ZHy7jrd6cGb2lUa4qYugXe41G_Ef9Ibw?usp=sharing)
24 |
25 | ***We are especially grateful to [Kairui Feng](https://scholar.google.com.hk/citations?user=4N5hE8YAAAAJ&hl=zh-CN), a PhD student at Princeton University.***
26 |
27 | ### Baidu Drive
28 |
29 | [[Baidu Drive]](https://pan.baidu.com/s/1LwPFhgbdBj5AeoPTXgoqDw) Password: ```sjtu```
30 |
31 |
32 | # Generate the HQ dataset by yourself (if you want to do so)
33 | ## Preparation
34 | ### Installation
35 | **We highly recommend that you use Anaconda for installation.**
36 | ```
37 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch
38 | pip install insightface==0.2.1 onnxruntime
39 | (optional) pip install onnxruntime-gpu==1.2.0
40 |
41 | pip install basicsr
42 | pip install facexlib
43 | pip install -r requirements.txt
44 | python setup.py develop
45 | ```
46 | - The PyTorch and CUDA versions above are the recommended ones; other versions may also work.
47 | - Using a different insightface version is not recommended; please use this specific version.
48 | - These settings have been tested on both Windows and Ubuntu. A quick environment check is sketched below.
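
A quick way to confirm the environment matches the recommendations above is a small check script (a minimal sketch; the expected versions are simply the ones from the install commands, not hard requirements):

```
# check_env.py -- minimal sketch to confirm the recommended environment above.
import torch
import torchvision
import insightface
import onnxruntime

print('torch          :', torch.__version__)        # install command above uses 1.8.0
print('torchvision    :', torchvision.__version__)  # install command above uses 0.9.0
print('insightface    :', insightface.__version__)  # this repo assumes 0.2.1
print('onnxruntime    :', onnxruntime.__version__)
print('CUDA available :', torch.cuda.is_available())
```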
49 | ### Pretrained model
50 | - We use the face detection and alignment methods from **[insightface](https://github.com/deepinsight/insightface)** for image preprocessing. Please download the relevant files from [this link](https://onedrive.live.com/?authkey=%21ADJ0aAOSsc90neY&cid=4A83B6B633B029CC&id=4A83B6B633B029CC%215837&parId=4A83B6B633B029CC%215834&action=locate) and unzip them to **./insightface_func/models**.
51 | - Download [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth) from the official GFPGAN repo and place it in **./experiments/pretrained_models** (a download sketch follows below).
52 |
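The GFPGAN checkpoint can also be fetched into the expected folder with a few lines of Python (a minimal sketch; the URL and target directory are the ones stated above, and downloading the file manually works just as well):

```
# fetch_gfpgan_weights.py -- minimal sketch; downloads the checkpoint linked above
# into ./experiments/pretrained_models, where the inference scripts expect it.
import os
import urllib.request

URL = ('https://github.com/TencentARC/GFPGAN/releases/download/'
       'v0.2.0/GFPGANCleanv1-NoCE-C2.pth')
TARGET_DIR = os.path.join('experiments', 'pretrained_models')

os.makedirs(TARGET_DIR, exist_ok=True)
target = os.path.join(TARGET_DIR, 'GFPGANCleanv1-NoCE-C2.pth')
if not os.path.isfile(target):
    print(f'Downloading {URL} -> {target}')
    urllib.request.urlretrieve(URL, target)
else:
    print(f'{target} already exists, skipping download.')
```
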
53 | ### Data preparation
54 | - Download the VGGFace2 dataset from [VGGFace2 Dataset for Face Recognition](https://github.com/ox-vgg/vgg_face2).
55 |
56 | ## Inference
57 |
58 | - First, preprocess all photos in VGGFace2: detect the faces and align them to the same alignment format as the FFHQ dataset.
59 | ```
60 | python scripts/crop_align_vggface2_FFHQalign.py --input_dir $DATAPATH$/VGGface2/train --output_dir_ffhqalign $ALIGN_OUTDIR$ --mode ffhq --crop_size 256
61 | ```
62 | - Then, restore the aligned photos with GFPGAN (an end-to-end sketch chaining both steps follows the command below).
63 | ```
64 | python scripts/inference_gfpgan_forvggface2.py --input_path $ALIGN_OUTDIR$ --batchSize 8 --save_dir $HQ_OUTDIR$
65 | ```
66 |
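For convenience, both steps can be chained in a single driver script (a minimal sketch using exactly the flags shown above; the three paths are placeholders you must point at your own directories):

```
# run_pipeline.py -- minimal sketch chaining the two commands above via subprocess.
import subprocess

vggface2_train_dir = '/path/to/VGGface2/train'      # placeholder
align_dir = '/path/to/vggface2_ffhq_align'          # placeholder
hq_dir = '/path/to/vggface2_hq'                     # placeholder

# Step 1: detect faces and align them to the FFHQ alignment format.
subprocess.run([
    'python', 'scripts/crop_align_vggface2_FFHQalign.py',
    '--input_dir', vggface2_train_dir,
    '--output_dir_ffhqalign', align_dir,
    '--mode', 'ffhq',
    '--crop_size', '256',
], check=True)

# Step 2: restore the aligned crops with GFPGAN.
subprocess.run([
    'python', 'scripts/inference_gfpgan_forvggface2.py',
    '--input_path', align_dir,
    '--batchSize', '8',
    '--save_dir', hq_dir,
], check=True)
```
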
67 | ## Citation
68 |
69 | If you find our work useful in your research, please consider citing:
70 |
71 | ```
72 | @Article{simswapplusplus,
73 | author = {Xuanhong Chen and
74 | Bingbing Ni and
75 | Yutian Liu and
76 | Naiyuan Liu and
77 | Zhilin Zeng and
78 | Hang Wang},
79 | title = {SimSwap++: Towards Faster and High-Quality Identity Swapping},
80 | journal = {{IEEE} Trans. Pattern Anal. Mach. Intell.},
81 | volume = {46},
82 | number = {1},
83 | pages = {576--592},
84 | year = {2024}
85 | }
86 | ```
87 |
88 | ## Related Projects
89 |
90 | ***Please visit our popular face swapping project***
91 |
92 | [](https://github.com/neuralchen/SimSwap)
93 |
94 | ***Please also visit our ACMMM2020 high-quality style transfer project***
95 |
96 | [](https://github.com/neuralchen/ASMAGAN)
97 |
98 | [](https://github.com/neuralchen/ASMAGAN)
99 |
100 | ***Please visit our AAAI2021 sketch-based rendering project***
101 |
102 | [](https://github.com/TZYSJTU/Sketch-Generation-with-Drawing-Process-Guided-by-Vector-Flow-and-Grayscale)
103 | [](https://github.com/TZYSJTU/Sketch-Generation-with-Drawing-Process-Guided-by-Vector-Flow-and-Grayscale)
104 |
105 | Learn about our other projects
106 |
107 | [[VGGFace2-HQ]](https://github.com/NNNNAI/VGGFace2-HQ);
108 |
109 | [[RainNet]](https://neuralchen.github.io/RainNet);
110 |
111 | [[Sketch Generation]](https://github.com/TZYSJTU/Sketch-Generation-with-Drawing-Process-Guided-by-Vector-Flow-and-Grayscale);
112 |
113 | [[CooGAN]](https://github.com/neuralchen/CooGAN);
114 |
115 | [[Knowledge Style Transfer]](https://github.com/AceSix/Knowledge_Transfer);
116 |
117 | [[SimSwap]](https://github.com/neuralchen/SimSwap);
118 |
119 | [[ASMA-GAN]](https://github.com/neuralchen/ASMAGAN);
120 |
121 | [[SNGAN-Projection-pytorch]](https://github.com/neuralchen/SNGAN_Projection)
122 |
123 | [[Pretrained_VGG19]](https://github.com/neuralchen/Pretrained_VGG19).
124 |
125 |
126 |
127 | # Acknowledgements
128 |
129 |
130 | * [GFPGAN](https://github.com/TencentARC/GFPGAN)
131 | * [Insightface](https://github.com/deepinsight/insightface)
132 | * [VGGFace2 Dataset for Face Recognition](https://github.com/ox-vgg/vgg_face2)
133 |
134 |
--------------------------------------------------------------------------------
/VERSION:
--------------------------------------------------------------------------------
1 | 0.2.3
2 |
--------------------------------------------------------------------------------
/VGGFace2-HQ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/VGGFace2-HQ.png
--------------------------------------------------------------------------------
/VGGFace2-HQ.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/VGGFace2-HQ.pptx
--------------------------------------------------------------------------------
/docs/img/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/2.png
--------------------------------------------------------------------------------
/docs/img/VGGFace2-HQ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/VGGFace2-HQ.png
--------------------------------------------------------------------------------
/docs/img/girl2-RGB.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/girl2-RGB.png
--------------------------------------------------------------------------------
/docs/img/girl2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/girl2.gif
--------------------------------------------------------------------------------
/docs/img/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/logo.png
--------------------------------------------------------------------------------
/docs/img/simswap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/simswap.png
--------------------------------------------------------------------------------
/docs/img/title.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/title.png
--------------------------------------------------------------------------------
/docs/img/vggface2_hq_compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/docs/img/vggface2_hq_compare.png
--------------------------------------------------------------------------------
/gfpgan/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from .archs import *
3 | from .data import *
4 | from .models import *
5 | from .utils import *
6 | from .version import __gitsha__, __version__
7 |
--------------------------------------------------------------------------------
/gfpgan/archs/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from basicsr.utils import scandir
3 | from os import path as osp
4 |
5 | # automatically scan and import arch modules for registry
6 | # scan all the files that end with '_arch.py' under the archs folder
7 | arch_folder = osp.dirname(osp.abspath(__file__))
8 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
9 | # import all the arch modules
10 | _arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames]
11 |
--------------------------------------------------------------------------------
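The scan above means every class decorated with `@ARCH_REGISTRY.register()` in a `*_arch.py` file becomes retrievable by name once `gfpgan` is imported. A minimal sketch of that lookup (assuming `basicsr` and this package are installed; the class names are the ones defined in the files below):

```
# registry_lookup.py -- minimal sketch; importing gfpgan triggers the auto-import above.
import gfpgan  # noqa: F401  (side effect: imports all gfpgan.archs.*_arch modules)
from basicsr.utils.registry import ARCH_REGISTRY

# Registered architectures can now be fetched by class name.
for name in ('GFPGANv1', 'ResNetArcFace', 'FacialComponentDiscriminator'):
    print(name, '->', ARCH_REGISTRY.get(name))
```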
/gfpgan/archs/arcface_arch.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from basicsr.utils.registry import ARCH_REGISTRY
3 |
4 |
5 | def conv3x3(in_planes, out_planes, stride=1):
6 | """3x3 convolution with padding"""
7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
8 |
9 |
10 | class BasicBlock(nn.Module):
11 | expansion = 1
12 |
13 | def __init__(self, inplanes, planes, stride=1, downsample=None):
14 | super(BasicBlock, self).__init__()
15 | self.conv1 = conv3x3(inplanes, planes, stride)
16 | self.bn1 = nn.BatchNorm2d(planes)
17 | self.relu = nn.ReLU(inplace=True)
18 | self.conv2 = conv3x3(planes, planes)
19 | self.bn2 = nn.BatchNorm2d(planes)
20 | self.downsample = downsample
21 | self.stride = stride
22 |
23 | def forward(self, x):
24 | residual = x
25 |
26 | out = self.conv1(x)
27 | out = self.bn1(out)
28 | out = self.relu(out)
29 |
30 | out = self.conv2(out)
31 | out = self.bn2(out)
32 |
33 | if self.downsample is not None:
34 | residual = self.downsample(x)
35 |
36 | out += residual
37 | out = self.relu(out)
38 |
39 | return out
40 |
41 |
42 | class IRBlock(nn.Module):
43 | expansion = 1
44 |
45 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
46 | super(IRBlock, self).__init__()
47 | self.bn0 = nn.BatchNorm2d(inplanes)
48 | self.conv1 = conv3x3(inplanes, inplanes)
49 | self.bn1 = nn.BatchNorm2d(inplanes)
50 | self.prelu = nn.PReLU()
51 | self.conv2 = conv3x3(inplanes, planes, stride)
52 | self.bn2 = nn.BatchNorm2d(planes)
53 | self.downsample = downsample
54 | self.stride = stride
55 | self.use_se = use_se
56 | if self.use_se:
57 | self.se = SEBlock(planes)
58 |
59 | def forward(self, x):
60 | residual = x
61 | out = self.bn0(x)
62 | out = self.conv1(out)
63 | out = self.bn1(out)
64 | out = self.prelu(out)
65 |
66 | out = self.conv2(out)
67 | out = self.bn2(out)
68 | if self.use_se:
69 | out = self.se(out)
70 |
71 | if self.downsample is not None:
72 | residual = self.downsample(x)
73 |
74 | out += residual
75 | out = self.prelu(out)
76 |
77 | return out
78 |
79 |
80 | class Bottleneck(nn.Module):
81 | expansion = 4
82 |
83 | def __init__(self, inplanes, planes, stride=1, downsample=None):
84 | super(Bottleneck, self).__init__()
85 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
86 | self.bn1 = nn.BatchNorm2d(planes)
87 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
88 | self.bn2 = nn.BatchNorm2d(planes)
89 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
90 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
91 | self.relu = nn.ReLU(inplace=True)
92 | self.downsample = downsample
93 | self.stride = stride
94 |
95 | def forward(self, x):
96 | residual = x
97 |
98 | out = self.conv1(x)
99 | out = self.bn1(out)
100 | out = self.relu(out)
101 |
102 | out = self.conv2(out)
103 | out = self.bn2(out)
104 | out = self.relu(out)
105 |
106 | out = self.conv3(out)
107 | out = self.bn3(out)
108 |
109 | if self.downsample is not None:
110 | residual = self.downsample(x)
111 |
112 | out += residual
113 | out = self.relu(out)
114 |
115 | return out
116 |
117 |
118 | class SEBlock(nn.Module):
119 |
120 | def __init__(self, channel, reduction=16):
121 | super(SEBlock, self).__init__()
122 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
123 | self.fc = nn.Sequential(
124 | nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
125 | nn.Sigmoid())
126 |
127 | def forward(self, x):
128 | b, c, _, _ = x.size()
129 | y = self.avg_pool(x).view(b, c)
130 | y = self.fc(y).view(b, c, 1, 1)
131 | return x * y
132 |
133 |
134 | @ARCH_REGISTRY.register()
135 | class ResNetArcFace(nn.Module):
136 |
137 | def __init__(self, block, layers, use_se=True):
138 | if block == 'IRBlock':
139 | block = IRBlock
140 | self.inplanes = 64
141 | self.use_se = use_se
142 | super(ResNetArcFace, self).__init__()
143 | self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
144 | self.bn1 = nn.BatchNorm2d(64)
145 | self.prelu = nn.PReLU()
146 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
147 | self.layer1 = self._make_layer(block, 64, layers[0])
148 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
149 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
150 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
151 | self.bn4 = nn.BatchNorm2d(512)
152 | self.dropout = nn.Dropout()
153 | self.fc5 = nn.Linear(512 * 8 * 8, 512)
154 | self.bn5 = nn.BatchNorm1d(512)
155 |
156 | for m in self.modules():
157 | if isinstance(m, nn.Conv2d):
158 | nn.init.xavier_normal_(m.weight)
159 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
160 | nn.init.constant_(m.weight, 1)
161 | nn.init.constant_(m.bias, 0)
162 | elif isinstance(m, nn.Linear):
163 | nn.init.xavier_normal_(m.weight)
164 | nn.init.constant_(m.bias, 0)
165 |
166 | def _make_layer(self, block, planes, blocks, stride=1):
167 | downsample = None
168 | if stride != 1 or self.inplanes != planes * block.expansion:
169 | downsample = nn.Sequential(
170 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
171 | nn.BatchNorm2d(planes * block.expansion),
172 | )
173 | layers = []
174 | layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
175 | self.inplanes = planes
176 | for _ in range(1, blocks):
177 | layers.append(block(self.inplanes, planes, use_se=self.use_se))
178 |
179 | return nn.Sequential(*layers)
180 |
181 | def forward(self, x):
182 | x = self.conv1(x)
183 | x = self.bn1(x)
184 | x = self.prelu(x)
185 | x = self.maxpool(x)
186 |
187 | x = self.layer1(x)
188 | x = self.layer2(x)
189 | x = self.layer3(x)
190 | x = self.layer4(x)
191 | x = self.bn4(x)
192 | x = self.dropout(x)
193 | x = x.view(x.size(0), -1)
194 | x = self.fc5(x)
195 | x = self.bn5(x)
196 |
197 | return x
198 |
--------------------------------------------------------------------------------
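As a quick illustration of this identity network, the sketch below builds it from IRBlocks and pushes a dummy grayscale 128x128 crop through to obtain a 512-dimensional embedding (a minimal sketch; `layers=[2, 2, 2, 2]` is just an example depth):

```
# arcface_demo.py -- minimal sketch exercising ResNetArcFace on a dummy input.
import torch
from gfpgan.archs.arcface_arch import ResNetArcFace

# Build the network from IRBlocks; layers=[2, 2, 2, 2] is an example configuration.
net = ResNetArcFace(block='IRBlock', layers=[2, 2, 2, 2], use_se=False).eval()

# fc5 expects an 8x8 spatial map, which a 1x128x128 grayscale input produces after
# the stride-2 maxpool and the three stride-2 stages (128 -> 64 -> 32 -> 16 -> 8).
dummy = torch.randn(2, 1, 128, 128)
with torch.no_grad():
    embedding = net(dummy)
print(embedding.shape)  # torch.Size([2, 512])
```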
/gfpgan/archs/gfpganv1_arch.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import torch
4 | from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
5 | StyleGAN2Generator)
6 | from basicsr.ops.fused_act import FusedLeakyReLU
7 | from basicsr.utils.registry import ARCH_REGISTRY
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 |
12 | class StyleGAN2GeneratorSFT(StyleGAN2Generator):
13 | """StyleGAN2 Generator.
14 |
15 | Args:
16 | out_size (int): The spatial size of outputs.
17 | num_style_feat (int): Channel number of style features. Default: 512.
18 | num_mlp (int): Layer number of MLP style layers. Default: 8.
19 | channel_multiplier (int): Channel multiplier for large networks of
20 | StyleGAN2. Default: 2.
21 | resample_kernel (list[int]): A list indicating the 1D resample kernel
22 | magnitude. An outer product will be applied to extend the 1D resample
23 | kernel to a 2D resample kernel. Default: [1, 3, 3, 1].
24 | lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
25 | """
26 |
27 | def __init__(self,
28 | out_size,
29 | num_style_feat=512,
30 | num_mlp=8,
31 | channel_multiplier=2,
32 | resample_kernel=(1, 3, 3, 1),
33 | lr_mlp=0.01,
34 | narrow=1,
35 | sft_half=False):
36 | super(StyleGAN2GeneratorSFT, self).__init__(
37 | out_size,
38 | num_style_feat=num_style_feat,
39 | num_mlp=num_mlp,
40 | channel_multiplier=channel_multiplier,
41 | resample_kernel=resample_kernel,
42 | lr_mlp=lr_mlp,
43 | narrow=narrow)
44 | self.sft_half = sft_half
45 |
46 | def forward(self,
47 | styles,
48 | conditions,
49 | input_is_latent=False,
50 | noise=None,
51 | randomize_noise=True,
52 | truncation=1,
53 | truncation_latent=None,
54 | inject_index=None,
55 | return_latents=False):
56 | """Forward function for StyleGAN2Generator.
57 |
58 | Args:
59 | styles (list[Tensor]): Sample codes of styles.
60 | input_is_latent (bool): Whether input is latent style.
61 | Default: False.
62 | noise (Tensor | None): Input noise or None. Default: None.
63 | randomize_noise (bool): Randomize noise, used when 'noise' is
64 | None. Default: True.
65 | truncation (float): The truncation ratio for the truncation trick. Default: 1.
66 | truncation_latent (Tensor | None): The truncation center latent. Default: None.
67 | inject_index (int | None): The injection index for mixing noise.
68 | Default: None.
69 | return_latents (bool): Whether to return style latents.
70 | Default: False.
71 | """
72 | # style codes -> latents with Style MLP layer
73 | if not input_is_latent:
74 | styles = [self.style_mlp(s) for s in styles]
75 | # noises
76 | if noise is None:
77 | if randomize_noise:
78 | noise = [None] * self.num_layers # for each style conv layer
79 | else: # use the stored noise
80 | noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
81 | # style truncation
82 | if truncation < 1:
83 | style_truncation = []
84 | for style in styles:
85 | style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
86 | styles = style_truncation
87 | # get style latent with injection
88 | if len(styles) == 1:
89 | inject_index = self.num_latent
90 |
91 | if styles[0].ndim < 3:
92 | # repeat latent code for all the layers
93 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
94 | else: # used for encoder with different latent code for each layer
95 | latent = styles[0]
96 | elif len(styles) == 2: # mixing noises
97 | if inject_index is None:
98 | inject_index = random.randint(1, self.num_latent - 1)
99 | latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
100 | latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
101 | latent = torch.cat([latent1, latent2], 1)
102 |
103 | # main generation
104 | out = self.constant_input(latent.shape[0])
105 | out = self.style_conv1(out, latent[:, 0], noise=noise[0])
106 | skip = self.to_rgb1(out, latent[:, 1])
107 |
108 | i = 1
109 | for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
110 | noise[2::2], self.to_rgbs):
111 | out = conv1(out, latent[:, i], noise=noise1)
112 |
113 | # the conditions may have fewer levels
114 | if i < len(conditions):
115 | # SFT part to combine the conditions
116 | if self.sft_half:
117 | out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
118 | out_sft = out_sft * conditions[i - 1] + conditions[i]
119 | out = torch.cat([out_same, out_sft], dim=1)
120 | else:
121 | out = out * conditions[i - 1] + conditions[i]
122 |
123 | out = conv2(out, latent[:, i + 1], noise=noise2)
124 | skip = to_rgb(out, latent[:, i + 2], skip)
125 | i += 2
126 |
127 | image = skip
128 |
129 | if return_latents:
130 | return image, latent
131 | else:
132 | return image, None
133 |
134 |
135 | class ConvUpLayer(nn.Module):
136 | """Conv Up Layer. Bilinear upsample + Conv.
137 |
138 | Args:
139 | in_channels (int): Channel number of the input.
140 | out_channels (int): Channel number of the output.
141 | kernel_size (int): Size of the convolving kernel.
142 | stride (int): Stride of the convolution. Default: 1
143 | padding (int): Zero-padding added to both sides of the input.
144 | Default: 0.
145 | bias (bool): If ``True``, adds a learnable bias to the output.
146 | Default: ``True``.
147 | bias_init_val (float): Bias initialized value. Default: 0.
148 | activate (bool): Whether to use activation. Default: True.
149 | """
150 |
151 | def __init__(self,
152 | in_channels,
153 | out_channels,
154 | kernel_size,
155 | stride=1,
156 | padding=0,
157 | bias=True,
158 | bias_init_val=0,
159 | activate=True):
160 | super(ConvUpLayer, self).__init__()
161 | self.in_channels = in_channels
162 | self.out_channels = out_channels
163 | self.kernel_size = kernel_size
164 | self.stride = stride
165 | self.padding = padding
166 | self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
167 |
168 | self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
169 |
170 | if bias and not activate:
171 | self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
172 | else:
173 | self.register_parameter('bias', None)
174 |
175 | # activation
176 | if activate:
177 | if bias:
178 | self.activation = FusedLeakyReLU(out_channels)
179 | else:
180 | self.activation = ScaledLeakyReLU(0.2)
181 | else:
182 | self.activation = None
183 |
184 | def forward(self, x):
185 | # bilinear upsample
186 | out = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
187 | # conv
188 | out = F.conv2d(
189 | out,
190 | self.weight * self.scale,
191 | bias=self.bias,
192 | stride=self.stride,
193 | padding=self.padding,
194 | )
195 | # activation
196 | if self.activation is not None:
197 | out = self.activation(out)
198 | return out
199 |
200 |
201 | class ResUpBlock(nn.Module):
202 | """Residual block with upsampling.
203 |
204 | Args:
205 | in_channels (int): Channel number of the input.
206 | out_channels (int): Channel number of the output.
207 | """
208 |
209 | def __init__(self, in_channels, out_channels):
210 | super(ResUpBlock, self).__init__()
211 |
212 | self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
213 | self.conv2 = ConvUpLayer(in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True)
214 | self.skip = ConvUpLayer(in_channels, out_channels, 1, bias=False, activate=False)
215 |
216 | def forward(self, x):
217 | out = self.conv1(x)
218 | out = self.conv2(out)
219 | skip = self.skip(x)
220 | out = (out + skip) / math.sqrt(2)
221 | return out
222 |
223 |
224 | @ARCH_REGISTRY.register()
225 | class GFPGANv1(nn.Module):
226 | """Unet + StyleGAN2 decoder with SFT."""
227 |
228 | def __init__(
229 | self,
230 | out_size,
231 | num_style_feat=512,
232 | channel_multiplier=1,
233 | resample_kernel=(1, 3, 3, 1),
234 | decoder_load_path=None,
235 | fix_decoder=True,
236 | # for stylegan decoder
237 | num_mlp=8,
238 | lr_mlp=0.01,
239 | input_is_latent=False,
240 | different_w=False,
241 | narrow=1,
242 | sft_half=False):
243 |
244 | super(GFPGANv1, self).__init__()
245 | self.input_is_latent = input_is_latent
246 | self.different_w = different_w
247 | self.num_style_feat = num_style_feat
248 |
249 | unet_narrow = narrow * 0.5
250 | channels = {
251 | '4': int(512 * unet_narrow),
252 | '8': int(512 * unet_narrow),
253 | '16': int(512 * unet_narrow),
254 | '32': int(512 * unet_narrow),
255 | '64': int(256 * channel_multiplier * unet_narrow),
256 | '128': int(128 * channel_multiplier * unet_narrow),
257 | '256': int(64 * channel_multiplier * unet_narrow),
258 | '512': int(32 * channel_multiplier * unet_narrow),
259 | '1024': int(16 * channel_multiplier * unet_narrow)
260 | }
261 |
262 | self.log_size = int(math.log(out_size, 2))
263 | first_out_size = 2**(int(math.log(out_size, 2)))
264 |
265 | self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)
266 |
267 | # downsample
268 | in_channels = channels[f'{first_out_size}']
269 | self.conv_body_down = nn.ModuleList()
270 | for i in range(self.log_size, 2, -1):
271 | out_channels = channels[f'{2**(i - 1)}']
272 | self.conv_body_down.append(ResBlock(in_channels, out_channels, resample_kernel))
273 | in_channels = out_channels
274 |
275 | self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)
276 |
277 | # upsample
278 | in_channels = channels['4']
279 | self.conv_body_up = nn.ModuleList()
280 | for i in range(3, self.log_size + 1):
281 | out_channels = channels[f'{2**i}']
282 | self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
283 | in_channels = out_channels
284 |
285 | # to RGB
286 | self.toRGB = nn.ModuleList()
287 | for i in range(3, self.log_size + 1):
288 | self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))
289 |
290 | if different_w:
291 | linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
292 | else:
293 | linear_out_channel = num_style_feat
294 |
295 | self.final_linear = EqualLinear(
296 | channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
297 |
298 | self.stylegan_decoder = StyleGAN2GeneratorSFT(
299 | out_size=out_size,
300 | num_style_feat=num_style_feat,
301 | num_mlp=num_mlp,
302 | channel_multiplier=channel_multiplier,
303 | resample_kernel=resample_kernel,
304 | lr_mlp=lr_mlp,
305 | narrow=narrow,
306 | sft_half=sft_half)
307 |
308 | if decoder_load_path:
309 | self.stylegan_decoder.load_state_dict(
310 | torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
311 | if fix_decoder:
312 | for _, param in self.stylegan_decoder.named_parameters():
313 | param.requires_grad = False
314 |
315 | # for SFT
316 | self.condition_scale = nn.ModuleList()
317 | self.condition_shift = nn.ModuleList()
318 | for i in range(3, self.log_size + 1):
319 | out_channels = channels[f'{2**i}']
320 | if sft_half:
321 | sft_out_channels = out_channels
322 | else:
323 | sft_out_channels = out_channels * 2
324 | self.condition_scale.append(
325 | nn.Sequential(
326 | EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
327 | ScaledLeakyReLU(0.2),
328 | EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
329 | self.condition_shift.append(
330 | nn.Sequential(
331 | EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
332 | ScaledLeakyReLU(0.2),
333 | EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
334 |
335 | def forward(self,
336 | x,
337 | return_latents=False,
338 | save_feat_path=None,
339 | load_feat_path=None,
340 | return_rgb=True,
341 | randomize_noise=True):
342 | conditions = []
343 | unet_skips = []
344 | out_rgbs = []
345 |
346 | # encoder
347 | feat = self.conv_body_first(x)
348 | for i in range(self.log_size - 2):
349 | feat = self.conv_body_down[i](feat)
350 | unet_skips.insert(0, feat)
351 |
352 | feat = self.final_conv(feat)
353 |
354 | # style code
355 | style_code = self.final_linear(feat.view(feat.size(0), -1))
356 | if self.different_w:
357 | style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
358 |
359 | # decode
360 | for i in range(self.log_size - 2):
361 | # add unet skip
362 | feat = feat + unet_skips[i]
363 | # ResUpLayer
364 | feat = self.conv_body_up[i](feat)
365 | # generate scale and shift for SFT layer
366 | scale = self.condition_scale[i](feat)
367 | conditions.append(scale.clone())
368 | shift = self.condition_shift[i](feat)
369 | conditions.append(shift.clone())
370 | # generate rgb images
371 | if return_rgb:
372 | out_rgbs.append(self.toRGB[i](feat))
373 |
374 | if save_feat_path is not None:
375 | torch.save(conditions, save_feat_path)
376 | if load_feat_path is not None:
377 | conditions = torch.load(load_feat_path)
378 | conditions = [v.cuda() for v in conditions]
379 |
380 | # decoder
381 | image, _ = self.stylegan_decoder([style_code],
382 | conditions,
383 | return_latents=return_latents,
384 | input_is_latent=self.input_is_latent,
385 | randomize_noise=randomize_noise)
386 |
387 | return image, out_rgbs
388 |
389 |
390 | @ARCH_REGISTRY.register()
391 | class FacialComponentDiscriminator(nn.Module):
392 |
393 | def __init__(self):
394 | super(FacialComponentDiscriminator, self).__init__()
395 |
396 | self.conv1 = ConvLayer(3, 64, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
397 | self.conv2 = ConvLayer(64, 128, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
398 | self.conv3 = ConvLayer(128, 128, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
399 | self.conv4 = ConvLayer(128, 256, 3, downsample=True, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
400 | self.conv5 = ConvLayer(256, 256, 3, downsample=False, resample_kernel=(1, 3, 3, 1), bias=True, activate=True)
401 | self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
402 |
403 | def forward(self, x, return_feats=False):
404 | feat = self.conv1(x)
405 | feat = self.conv3(self.conv2(feat))
406 | rlt_feats = []
407 | if return_feats:
408 | rlt_feats.append(feat.clone())
409 | feat = self.conv5(self.conv4(feat))
410 | if return_feats:
411 | rlt_feats.append(feat.clone())
412 | out = self.final_conv(feat)
413 |
414 | if return_feats:
415 | return out, rlt_feats
416 | else:
417 | return out, None
418 |
--------------------------------------------------------------------------------
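The key addition of GFPGANv1 on top of the StyleGAN2 decoder is the SFT modulation inside the generation loop: each condition pair supplies a spatial scale and shift that modulates (half of) the feature map. A stand-alone sketch of that step on dummy tensors (nothing beyond PyTorch is assumed; the shapes are illustrative):

```
# sft_demo.py -- minimal sketch of the SFT step in StyleGAN2GeneratorSFT.forward.
import torch

sft_half = True
out = torch.randn(1, 64, 32, 32)     # feature map from a style conv (illustrative shape)
scale = torch.randn(1, 32, 32, 32)   # conditions[i - 1], produced by condition_scale
shift = torch.randn(1, 32, 32, 32)   # conditions[i], produced by condition_shift

if sft_half:
    # Modulate only half of the channels; the other half passes through unchanged.
    out_same, out_sft = torch.split(out, out.size(1) // 2, dim=1)
    out_sft = out_sft * scale + shift
    out = torch.cat([out_same, out_sft], dim=1)
else:
    out = out * scale + shift        # scale/shift would then need all 64 channels

print(out.shape)  # torch.Size([1, 64, 32, 32])
```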
/gfpgan/archs/gfpganv1_clean_arch.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from .stylegan2_clean_arch import StyleGAN2GeneratorClean
8 |
9 |
10 | class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
11 |     """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
12 |
13 | Args:
14 | out_size (int): The spatial size of outputs.
15 | num_style_feat (int): Channel number of style features. Default: 512.
16 | num_mlp (int): Layer number of MLP style layers. Default: 8.
17 | channel_multiplier (int): Channel multiplier for large networks of
18 | StyleGAN2. Default: 2.
19 | """
20 |
21 | def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1, sft_half=False):
22 | super(StyleGAN2GeneratorCSFT, self).__init__(
23 | out_size,
24 | num_style_feat=num_style_feat,
25 | num_mlp=num_mlp,
26 | channel_multiplier=channel_multiplier,
27 | narrow=narrow)
28 |
29 | self.sft_half = sft_half
30 |
31 | def forward(self,
32 | styles,
33 | conditions,
34 | input_is_latent=False,
35 | noise=None,
36 | randomize_noise=True,
37 | truncation=1,
38 | truncation_latent=None,
39 | inject_index=None,
40 | return_latents=False):
41 |         """Forward function for StyleGAN2GeneratorCSFT.
42 |
43 | Args:
44 | styles (list[Tensor]): Sample codes of styles.
45 | input_is_latent (bool): Whether input is latent style.
46 | Default: False.
47 | noise (Tensor | None): Input noise or None. Default: None.
48 | randomize_noise (bool): Randomize noise, used when 'noise' is
49 |                 None. Default: True.
50 |             truncation (float): The truncation ratio. Default: 1.
51 |             truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
52 | inject_index (int | None): The injection index for mixing noise.
53 | Default: None.
54 | return_latents (bool): Whether to return style latents.
55 | Default: False.
56 | """
57 | # style codes -> latents with Style MLP layer
58 | if not input_is_latent:
59 | styles = [self.style_mlp(s) for s in styles]
60 | # noises
61 | if noise is None:
62 | if randomize_noise:
63 | noise = [None] * self.num_layers # for each style conv layer
64 | else: # use the stored noise
65 | noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
66 | # style truncation
67 | if truncation < 1:
68 | style_truncation = []
69 | for style in styles:
70 | style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
71 | styles = style_truncation
72 | # get style latent with injection
73 | if len(styles) == 1:
74 | inject_index = self.num_latent
75 |
76 | if styles[0].ndim < 3:
77 | # repeat latent code for all the layers
78 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
79 | else: # used for encoder with different latent code for each layer
80 | latent = styles[0]
81 | elif len(styles) == 2: # mixing noises
82 | if inject_index is None:
83 | inject_index = random.randint(1, self.num_latent - 1)
84 | latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
85 | latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
86 | latent = torch.cat([latent1, latent2], 1)
87 |
88 | # main generation
89 | out = self.constant_input(latent.shape[0])
90 | out = self.style_conv1(out, latent[:, 0], noise=noise[0])
91 | skip = self.to_rgb1(out, latent[:, 1])
92 |
93 | i = 1
94 | for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
95 | noise[2::2], self.to_rgbs):
96 | out = conv1(out, latent[:, i], noise=noise1)
97 |
98 | # the conditions may have fewer levels
99 | if i < len(conditions):
100 | # SFT part to combine the conditions
101 | if self.sft_half:
102 | out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
103 | out_sft = out_sft * conditions[i - 1] + conditions[i]
104 | out = torch.cat([out_same, out_sft], dim=1)
105 | else:
106 | out = out * conditions[i - 1] + conditions[i]
107 |
108 | out = conv2(out, latent[:, i + 1], noise=noise2)
109 | skip = to_rgb(out, latent[:, i + 2], skip)
110 | i += 2
111 |
112 | image = skip
113 |
114 | if return_latents:
115 | return image, latent
116 | else:
117 | return image, None
118 |
119 |
120 | class ResBlock(nn.Module):
121 | """Residual block with upsampling/downsampling.
122 |
123 | Args:
124 | in_channels (int): Channel number of the input.
125 | out_channels (int): Channel number of the output.
126 | """
127 |
128 | def __init__(self, in_channels, out_channels, mode='down'):
129 | super(ResBlock, self).__init__()
130 |
131 | self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
132 | self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
133 | self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
134 | if mode == 'down':
135 | self.scale_factor = 0.5
136 | elif mode == 'up':
137 | self.scale_factor = 2
138 |
139 | def forward(self, x):
140 | out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
141 | # upsample/downsample
142 | out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
143 | out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
144 | # skip
145 | x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
146 | skip = self.skip(x)
147 | out = out + skip
148 | return out
149 |
150 |
151 | class GFPGANv1Clean(nn.Module):
152 |     """GFPGANv1 clean version (no custom CUDA ops): a U-Net encoder whose SFT conditions modulate a StyleGAN2 decoder."""
153 |
154 | def __init__(
155 | self,
156 | out_size,
157 | num_style_feat=512,
158 | channel_multiplier=1,
159 | decoder_load_path=None,
160 | fix_decoder=True,
161 | # for stylegan decoder
162 | num_mlp=8,
163 | input_is_latent=False,
164 | different_w=False,
165 | narrow=1,
166 | sft_half=False):
167 |
168 | super(GFPGANv1Clean, self).__init__()
169 | self.input_is_latent = input_is_latent
170 | self.different_w = different_w
171 | self.num_style_feat = num_style_feat
172 |
173 | unet_narrow = narrow * 0.5
174 | channels = {
175 | '4': int(512 * unet_narrow),
176 | '8': int(512 * unet_narrow),
177 | '16': int(512 * unet_narrow),
178 | '32': int(512 * unet_narrow),
179 | '64': int(256 * channel_multiplier * unet_narrow),
180 | '128': int(128 * channel_multiplier * unet_narrow),
181 | '256': int(64 * channel_multiplier * unet_narrow),
182 | '512': int(32 * channel_multiplier * unet_narrow),
183 | '1024': int(16 * channel_multiplier * unet_narrow)
184 | }
185 |
186 | self.log_size = int(math.log(out_size, 2))
187 | first_out_size = 2**(int(math.log(out_size, 2)))
188 |
189 | self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)
190 |
191 | # downsample
192 | in_channels = channels[f'{first_out_size}']
193 | self.conv_body_down = nn.ModuleList()
194 | for i in range(self.log_size, 2, -1):
195 | out_channels = channels[f'{2**(i - 1)}']
196 | self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
197 | in_channels = out_channels
198 |
199 | self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
200 |
201 | # upsample
202 | in_channels = channels['4']
203 | self.conv_body_up = nn.ModuleList()
204 | for i in range(3, self.log_size + 1):
205 | out_channels = channels[f'{2**i}']
206 | self.conv_body_up.append(ResBlock(in_channels, out_channels, mode='up'))
207 | in_channels = out_channels
208 |
209 | # to RGB
210 | self.toRGB = nn.ModuleList()
211 | for i in range(3, self.log_size + 1):
212 | self.toRGB.append(nn.Conv2d(channels[f'{2**i}'], 3, 1))
213 |
214 | if different_w:
215 | linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
216 | else:
217 | linear_out_channel = num_style_feat
218 |
219 | self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
220 |
221 | self.stylegan_decoder = StyleGAN2GeneratorCSFT(
222 | out_size=out_size,
223 | num_style_feat=num_style_feat,
224 | num_mlp=num_mlp,
225 | channel_multiplier=channel_multiplier,
226 | narrow=narrow,
227 | sft_half=sft_half)
228 |
229 | if decoder_load_path:
230 | self.stylegan_decoder.load_state_dict(
231 | torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
232 | if fix_decoder:
233 | for _, param in self.stylegan_decoder.named_parameters():
234 | param.requires_grad = False
235 |
236 | # for SFT
237 | self.condition_scale = nn.ModuleList()
238 | self.condition_shift = nn.ModuleList()
239 | for i in range(3, self.log_size + 1):
240 | out_channels = channels[f'{2**i}']
241 | if sft_half:
242 | sft_out_channels = out_channels
243 | else:
244 | sft_out_channels = out_channels * 2
245 | self.condition_scale.append(
246 | nn.Sequential(
247 | nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
248 | nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
249 | self.condition_shift.append(
250 | nn.Sequential(
251 | nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.LeakyReLU(0.2, True),
252 | nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1)))
253 |
254 | def forward(self,
255 | x,
256 | return_latents=False,
257 | save_feat_path=None,
258 | load_feat_path=None,
259 | return_rgb=True,
260 | randomize_noise=True):
261 | conditions = []
262 | unet_skips = []
263 | out_rgbs = []
264 |
265 | # encoder
266 | feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
267 | for i in range(self.log_size - 2):
268 | feat = self.conv_body_down[i](feat)
269 | unet_skips.insert(0, feat)
270 | feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
271 |
272 | # style code
273 | style_code = self.final_linear(feat.view(feat.size(0), -1))
274 | if self.different_w:
275 | style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
276 | # decode
277 | for i in range(self.log_size - 2):
278 | # add unet skip
279 | feat = feat + unet_skips[i]
280 | # ResUpLayer
281 | feat = self.conv_body_up[i](feat)
282 | # generate scale and shift for SFT layer
283 | scale = self.condition_scale[i](feat)
284 | conditions.append(scale.clone())
285 | shift = self.condition_shift[i](feat)
286 | conditions.append(shift.clone())
287 | # generate rgb images
288 | if return_rgb:
289 | out_rgbs.append(self.toRGB[i](feat))
290 |
291 | if save_feat_path is not None:
292 | torch.save(conditions, save_feat_path)
293 | if load_feat_path is not None:
294 | conditions = torch.load(load_feat_path)
295 | conditions = [v.cuda() for v in conditions]
296 |
297 | # decoder
298 | image, _ = self.stylegan_decoder([style_code],
299 | conditions,
300 | return_latents=return_latents,
301 | input_is_latent=self.input_is_latent,
302 | randomize_noise=randomize_noise)
303 |
304 | return image, out_rgbs
305 |
--------------------------------------------------------------------------------
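As a quick sanity check of the architecture above, the following is a minimal sketch of a forward pass through `GFPGANv1Clean` with random weights on the CPU; the constructor arguments match those used by `GFPGANer` in `gfpgan/utils.py`, and the input is assumed to be a 512x512 face already normalized to [-1, 1]. In real use the pretrained `params_ema` checkpoint would be loaded first.

```python
import torch

from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean

net = GFPGANv1Clean(
    out_size=512,
    num_style_feat=512,
    channel_multiplier=2,
    decoder_load_path=None,   # no pretrained StyleGAN2 decoder for this sketch
    fix_decoder=False,
    num_mlp=8,
    input_is_latent=True,
    different_w=True,
    narrow=1,
    sft_half=True)
net.eval()

face = torch.randn(1, 3, 512, 512)  # degraded face, normalized to [-1, 1]
with torch.no_grad():
    restored, _ = net(face, return_rgb=False)
print(restored.shape)  # (1, 3, 512, 512)
```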
/gfpgan/archs/stylegan2_clean_arch.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import torch
4 | from basicsr.archs.arch_util import default_init_weights
5 | from basicsr.utils.registry import ARCH_REGISTRY
6 | from torch import nn
7 | from torch.nn import functional as F
8 |
9 |
10 | class NormStyleCode(nn.Module):
11 |
12 | def forward(self, x):
13 | """Normalize the style codes.
14 |
15 | Args:
16 | x (Tensor): Style codes with shape (b, c).
17 |
18 | Returns:
19 | Tensor: Normalized tensor.
20 | """
21 | return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
22 |
23 |
24 | class ModulatedConv2d(nn.Module):
25 | """Modulated Conv2d used in StyleGAN2.
26 |
27 | There is no bias in ModulatedConv2d.
28 |
29 | Args:
30 | in_channels (int): Channel number of the input.
31 | out_channels (int): Channel number of the output.
32 | kernel_size (int): Size of the convolving kernel.
33 | num_style_feat (int): Channel number of style features.
34 | demodulate (bool): Whether to demodulate in the conv layer.
35 | Default: True.
36 | sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
37 | Default: None.
38 | eps (float): A value added to the denominator for numerical stability.
39 | Default: 1e-8.
40 | """
41 |
42 | def __init__(self,
43 | in_channels,
44 | out_channels,
45 | kernel_size,
46 | num_style_feat,
47 | demodulate=True,
48 | sample_mode=None,
49 | eps=1e-8):
50 | super(ModulatedConv2d, self).__init__()
51 | self.in_channels = in_channels
52 | self.out_channels = out_channels
53 | self.kernel_size = kernel_size
54 | self.demodulate = demodulate
55 | self.sample_mode = sample_mode
56 | self.eps = eps
57 |
58 | # modulation inside each modulated conv
59 | self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
60 | # initialization
61 | default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear')
62 |
63 | self.weight = nn.Parameter(
64 | torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) /
65 | math.sqrt(in_channels * kernel_size**2))
66 | self.padding = kernel_size // 2
67 |
68 | def forward(self, x, style):
69 | """Forward function.
70 |
71 | Args:
72 | x (Tensor): Tensor with shape (b, c, h, w).
73 | style (Tensor): Tensor with shape (b, num_style_feat).
74 |
75 | Returns:
76 | Tensor: Modulated tensor after convolution.
77 | """
78 | b, c, h, w = x.shape # c = c_in
79 | # weight modulation
80 | style = self.modulation(style).view(b, 1, c, 1, 1)
81 | # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
82 | weight = self.weight * style # (b, c_out, c_in, k, k)
83 |
84 | if self.demodulate:
85 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
86 | weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
87 |
88 | weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
89 |
90 | if self.sample_mode == 'upsample':
91 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
92 | elif self.sample_mode == 'downsample':
93 | x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
94 |
95 | b, c, h, w = x.shape
96 | x = x.view(1, b * c, h, w)
97 | # weight: (b*c_out, c_in, k, k), groups=b
98 | out = F.conv2d(x, weight, padding=self.padding, groups=b)
99 | out = out.view(b, self.out_channels, *out.shape[2:4])
100 |
101 | return out
102 |
103 | def __repr__(self):
104 | return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
105 | f'out_channels={self.out_channels}, '
106 | f'kernel_size={self.kernel_size}, '
107 | f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
108 |
109 |
110 | class StyleConv(nn.Module):
111 | """Style conv.
112 |
113 | Args:
114 | in_channels (int): Channel number of the input.
115 | out_channels (int): Channel number of the output.
116 | kernel_size (int): Size of the convolving kernel.
117 | num_style_feat (int): Channel number of style features.
118 |         demodulate (bool): Whether to demodulate in the conv layer. Default: True.
119 | sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
120 | Default: None.
121 | """
122 |
123 | def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None):
124 | super(StyleConv, self).__init__()
125 | self.modulated_conv = ModulatedConv2d(
126 | in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode)
127 | self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
128 | self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
129 | self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
130 |
131 | def forward(self, x, style, noise=None):
132 | # modulate
133 |         out = self.modulated_conv(x, style) * 2**0.5  # sqrt(2) compensates for the scaled (fused) leaky ReLU used in the original ops
134 | # noise injection
135 | if noise is None:
136 | b, _, h, w = out.shape
137 | noise = out.new_empty(b, 1, h, w).normal_()
138 | out = out + self.weight * noise
139 | # add bias
140 | out = out + self.bias
141 | # activation
142 | out = self.activate(out)
143 | return out
144 |
145 |
146 | class ToRGB(nn.Module):
147 | """To RGB from features.
148 |
149 | Args:
150 | in_channels (int): Channel number of input.
151 | num_style_feat (int): Channel number of style features.
152 | upsample (bool): Whether to upsample. Default: True.
153 | """
154 |
155 | def __init__(self, in_channels, num_style_feat, upsample=True):
156 | super(ToRGB, self).__init__()
157 | self.upsample = upsample
158 | self.modulated_conv = ModulatedConv2d(
159 | in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
160 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
161 |
162 | def forward(self, x, style, skip=None):
163 | """Forward function.
164 |
165 | Args:
166 | x (Tensor): Feature tensor with shape (b, c, h, w).
167 | style (Tensor): Tensor with shape (b, num_style_feat).
168 | skip (Tensor): Base/skip tensor. Default: None.
169 |
170 | Returns:
171 | Tensor: RGB images.
172 | """
173 | out = self.modulated_conv(x, style)
174 | out = out + self.bias
175 | if skip is not None:
176 | if self.upsample:
177 | skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False)
178 | out = out + skip
179 | return out
180 |
181 |
182 | class ConstantInput(nn.Module):
183 | """Constant input.
184 |
185 | Args:
186 | num_channel (int): Channel number of constant input.
187 | size (int): Spatial size of constant input.
188 | """
189 |
190 | def __init__(self, num_channel, size):
191 | super(ConstantInput, self).__init__()
192 | self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
193 |
194 | def forward(self, batch):
195 | out = self.weight.repeat(batch, 1, 1, 1)
196 | return out
197 |
198 |
199 | @ARCH_REGISTRY.register()
200 | class StyleGAN2GeneratorClean(nn.Module):
201 | """Clean version of StyleGAN2 Generator.
202 |
203 | Args:
204 | out_size (int): The spatial size of outputs.
205 | num_style_feat (int): Channel number of style features. Default: 512.
206 | num_mlp (int): Layer number of MLP style layers. Default: 8.
207 | channel_multiplier (int): Channel multiplier for large networks of
208 | StyleGAN2. Default: 2.
209 | narrow (float): Narrow ratio for channels. Default: 1.0.
210 | """
211 |
212 | def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
213 | super(StyleGAN2GeneratorClean, self).__init__()
214 | # Style MLP layers
215 | self.num_style_feat = num_style_feat
216 | style_mlp_layers = [NormStyleCode()]
217 | for i in range(num_mlp):
218 | style_mlp_layers.extend(
219 | [nn.Linear(num_style_feat, num_style_feat, bias=True),
220 | nn.LeakyReLU(negative_slope=0.2, inplace=True)])
221 | self.style_mlp = nn.Sequential(*style_mlp_layers)
222 | # initialization
223 | default_init_weights(self.style_mlp, scale=1, bias_fill=0, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
224 |
225 | channels = {
226 | '4': int(512 * narrow),
227 | '8': int(512 * narrow),
228 | '16': int(512 * narrow),
229 | '32': int(512 * narrow),
230 | '64': int(256 * channel_multiplier * narrow),
231 | '128': int(128 * channel_multiplier * narrow),
232 | '256': int(64 * channel_multiplier * narrow),
233 | '512': int(32 * channel_multiplier * narrow),
234 | '1024': int(16 * channel_multiplier * narrow)
235 | }
236 | self.channels = channels
237 |
238 | self.constant_input = ConstantInput(channels['4'], size=4)
239 | self.style_conv1 = StyleConv(
240 | channels['4'],
241 | channels['4'],
242 | kernel_size=3,
243 | num_style_feat=num_style_feat,
244 | demodulate=True,
245 | sample_mode=None)
246 | self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False)
247 |
248 | self.log_size = int(math.log(out_size, 2))
249 | self.num_layers = (self.log_size - 2) * 2 + 1
250 | self.num_latent = self.log_size * 2 - 2
251 |
252 | self.style_convs = nn.ModuleList()
253 | self.to_rgbs = nn.ModuleList()
254 | self.noises = nn.Module()
255 |
256 | in_channels = channels['4']
257 | # noise
258 | for layer_idx in range(self.num_layers):
259 | resolution = 2**((layer_idx + 5) // 2)
260 | shape = [1, 1, resolution, resolution]
261 | self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
262 | # style convs and to_rgbs
263 | for i in range(3, self.log_size + 1):
264 | out_channels = channels[f'{2**i}']
265 | self.style_convs.append(
266 | StyleConv(
267 | in_channels,
268 | out_channels,
269 | kernel_size=3,
270 | num_style_feat=num_style_feat,
271 | demodulate=True,
272 | sample_mode='upsample'))
273 | self.style_convs.append(
274 | StyleConv(
275 | out_channels,
276 | out_channels,
277 | kernel_size=3,
278 | num_style_feat=num_style_feat,
279 | demodulate=True,
280 | sample_mode=None))
281 | self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
282 | in_channels = out_channels
283 |
284 | def make_noise(self):
285 | """Make noise for noise injection."""
286 | device = self.constant_input.weight.device
287 | noises = [torch.randn(1, 1, 4, 4, device=device)]
288 |
289 | for i in range(3, self.log_size + 1):
290 | for _ in range(2):
291 | noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
292 |
293 | return noises
294 |
295 | def get_latent(self, x):
296 | return self.style_mlp(x)
297 |
298 | def mean_latent(self, num_latent):
299 | latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
300 | latent = self.style_mlp(latent_in).mean(0, keepdim=True)
301 | return latent
302 |
303 | def forward(self,
304 | styles,
305 | input_is_latent=False,
306 | noise=None,
307 | randomize_noise=True,
308 | truncation=1,
309 | truncation_latent=None,
310 | inject_index=None,
311 | return_latents=False):
312 | """Forward function for StyleGAN2Generator.
313 |
314 | Args:
315 | styles (list[Tensor]): Sample codes of styles.
316 | input_is_latent (bool): Whether input is latent style.
317 | Default: False.
318 | noise (Tensor | None): Input noise or None. Default: None.
319 | randomize_noise (bool): Randomize noise, used when 'noise' is
320 |                 None. Default: True.
321 |             truncation (float): The truncation ratio. Default: 1.
322 |             truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
323 | inject_index (int | None): The injection index for mixing noise.
324 | Default: None.
325 | return_latents (bool): Whether to return style latents.
326 | Default: False.
327 | """
328 | # style codes -> latents with Style MLP layer
329 | if not input_is_latent:
330 | styles = [self.style_mlp(s) for s in styles]
331 | # noises
332 | if noise is None:
333 | if randomize_noise:
334 | noise = [None] * self.num_layers # for each style conv layer
335 | else: # use the stored noise
336 | noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
337 | # style truncation
338 | if truncation < 1:
339 | style_truncation = []
340 | for style in styles:
341 | style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
342 | styles = style_truncation
343 | # get style latent with injection
344 | if len(styles) == 1:
345 | inject_index = self.num_latent
346 |
347 | if styles[0].ndim < 3:
348 | # repeat latent code for all the layers
349 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
350 | else: # used for encoder with different latent code for each layer
351 | latent = styles[0]
352 | elif len(styles) == 2: # mixing noises
353 | if inject_index is None:
354 | inject_index = random.randint(1, self.num_latent - 1)
355 | latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
356 | latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
357 | latent = torch.cat([latent1, latent2], 1)
358 |
359 | # main generation
360 | out = self.constant_input(latent.shape[0])
361 | out = self.style_conv1(out, latent[:, 0], noise=noise[0])
362 | skip = self.to_rgb1(out, latent[:, 1])
363 |
364 | i = 1
365 | for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
366 | noise[2::2], self.to_rgbs):
367 | out = conv1(out, latent[:, i], noise=noise1)
368 | out = conv2(out, latent[:, i + 1], noise=noise2)
369 | skip = to_rgb(out, latent[:, i + 2], skip)
370 | i += 2
371 |
372 | image = skip
373 |
374 | if return_latents:
375 | return image, latent
376 | else:
377 | return image, None
378 |
--------------------------------------------------------------------------------
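The clean generator above can be exercised on its own, without any compiled ops. A minimal sketch, assuming random untrained weights and a CPU run (`out_size=256` keeps it light):

```python
import torch

from gfpgan.archs.stylegan2_clean_arch import StyleGAN2GeneratorClean

gen = StyleGAN2GeneratorClean(out_size=256, num_style_feat=512, num_mlp=8, channel_multiplier=2)
gen.eval()

z = torch.randn(2, 512)  # style codes z, shape (batch, num_style_feat)
with torch.no_grad():
    # z is mapped to w by the style MLP, then broadcast to all layers
    image, _ = gen([z], input_is_latent=False, randomize_noise=True)
print(image.shape)  # (2, 3, 256, 256)
```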
/gfpgan/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from basicsr.utils import scandir
3 | from os import path as osp
4 |
5 | # automatically scan and import dataset modules for registry
6 | # scan all the files that end with '_dataset.py' under the data folder
7 | data_folder = osp.dirname(osp.abspath(__file__))
8 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
9 | # import all the dataset modules
10 | _dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames]
11 |
--------------------------------------------------------------------------------
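The scan-and-import above exists only for its side effect: every `*_dataset.py` module registers its dataset class with basicsr's `DATASET_REGISTRY`. A minimal sketch of what that enables, assuming `gfpgan` and `basicsr` are installed:

```python
from basicsr.utils.registry import DATASET_REGISTRY

import gfpgan.data  # noqa: F401  triggers the automatic imports above

dataset_cls = DATASET_REGISTRY.get('FFHQDegradationDataset')
print(dataset_cls.__name__)  # basicsr's dataset builder performs the same lookup by the 'type' key in the YAML options
```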
/gfpgan/data/ffhq_degradation_dataset.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numpy as np
4 | import os.path as osp
5 | import torch
6 | import torch.utils.data as data
7 | from basicsr.data import degradations as degradations
8 | from basicsr.data.data_util import paths_from_folder
9 | from basicsr.data.transforms import augment
10 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
11 | from basicsr.utils.registry import DATASET_REGISTRY
12 | from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
13 | normalize)
14 |
15 |
16 | @DATASET_REGISTRY.register()
17 | class FFHQDegradationDataset(data.Dataset):
18 |
19 | def __init__(self, opt):
20 | super(FFHQDegradationDataset, self).__init__()
21 | self.opt = opt
22 | # file client (io backend)
23 | self.file_client = None
24 | self.io_backend_opt = opt['io_backend']
25 |
26 | self.gt_folder = opt['dataroot_gt']
27 | self.mean = opt['mean']
28 | self.std = opt['std']
29 | self.out_size = opt['out_size']
30 |
31 | self.crop_components = opt.get('crop_components', False) # facial components
32 | self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)
33 |
34 | if self.crop_components:
35 | self.components_list = torch.load(opt.get('component_path'))
36 |
37 | if self.io_backend_opt['type'] == 'lmdb':
38 | self.io_backend_opt['db_paths'] = self.gt_folder
39 | if not self.gt_folder.endswith('.lmdb'):
40 | raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
41 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
42 | self.paths = [line.split('.')[0] for line in fin]
43 | else:
44 | self.paths = paths_from_folder(self.gt_folder)
45 |
46 | # degradations
47 | self.blur_kernel_size = opt['blur_kernel_size']
48 | self.kernel_list = opt['kernel_list']
49 | self.kernel_prob = opt['kernel_prob']
50 | self.blur_sigma = opt['blur_sigma']
51 | self.downsample_range = opt['downsample_range']
52 | self.noise_range = opt['noise_range']
53 | self.jpeg_range = opt['jpeg_range']
54 |
55 | # color jitter
56 | self.color_jitter_prob = opt.get('color_jitter_prob')
57 | self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
58 | self.color_jitter_shift = opt.get('color_jitter_shift', 20)
59 | # to gray
60 | self.gray_prob = opt.get('gray_prob')
61 |
62 | logger = get_root_logger()
63 | logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
64 | f'sigma: [{", ".join(map(str, self.blur_sigma))}]')
65 | logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
66 | logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
67 | logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
68 |
69 | if self.color_jitter_prob is not None:
70 | logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, '
71 | f'shift: {self.color_jitter_shift}')
72 | if self.gray_prob is not None:
73 | logger.info(f'Use random gray. Prob: {self.gray_prob}')
74 |
75 | self.color_jitter_shift /= 255.
76 |
77 | @staticmethod
78 | def color_jitter(img, shift):
79 | jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
80 | img = img + jitter_val
81 | img = np.clip(img, 0, 1)
82 | return img
83 |
84 | @staticmethod
85 | def color_jitter_pt(img, brightness, contrast, saturation, hue):
86 | fn_idx = torch.randperm(4)
87 | for fn_id in fn_idx:
88 | if fn_id == 0 and brightness is not None:
89 | brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
90 | img = adjust_brightness(img, brightness_factor)
91 |
92 | if fn_id == 1 and contrast is not None:
93 | contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
94 | img = adjust_contrast(img, contrast_factor)
95 |
96 | if fn_id == 2 and saturation is not None:
97 | saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
98 | img = adjust_saturation(img, saturation_factor)
99 |
100 | if fn_id == 3 and hue is not None:
101 | hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
102 | img = adjust_hue(img, hue_factor)
103 | return img
104 |
105 | def get_component_coordinates(self, index, status):
106 | components_bbox = self.components_list[f'{index:08d}']
107 | if status[0]: # hflip
108 | # exchange right and left eye
109 | tmp = components_bbox['left_eye']
110 | components_bbox['left_eye'] = components_bbox['right_eye']
111 | components_bbox['right_eye'] = tmp
112 | # modify the width coordinate
113 | components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
114 | components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
115 | components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
116 |
117 | # get coordinates
118 | locations = []
119 | for part in ['left_eye', 'right_eye', 'mouth']:
120 | mean = components_bbox[part][0:2]
121 | half_len = components_bbox[part][2]
122 | if 'eye' in part:
123 | half_len *= self.eye_enlarge_ratio
124 | loc = np.hstack((mean - half_len + 1, mean + half_len))
125 | loc = torch.from_numpy(loc).float()
126 | locations.append(loc)
127 | return locations
128 |
129 | def __getitem__(self, index):
130 | if self.file_client is None:
131 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
132 |
133 | # load gt image
134 | gt_path = self.paths[index]
135 | img_bytes = self.file_client.get(gt_path)
136 | img_gt = imfrombytes(img_bytes, float32=True)
137 |
138 | # random horizontal flip
139 | img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
140 | h, w, _ = img_gt.shape
141 |
142 | if self.crop_components:
143 | locations = self.get_component_coordinates(index, status)
144 | loc_left_eye, loc_right_eye, loc_mouth = locations
145 |
146 | # ------------------------ generate lq image ------------------------ #
147 | # blur
148 | kernel = degradations.random_mixed_kernels(
149 | self.kernel_list,
150 | self.kernel_prob,
151 | self.blur_kernel_size,
152 | self.blur_sigma,
153 | self.blur_sigma, [-math.pi, math.pi],
154 | noise_range=None)
155 | img_lq = cv2.filter2D(img_gt, -1, kernel)
156 | # downsample
157 | scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
158 | img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
159 | # noise
160 | if self.noise_range is not None:
161 | img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
162 | # jpeg compression
163 | if self.jpeg_range is not None:
164 | img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
165 |
166 | # resize to original size
167 | img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
168 |
169 | # random color jitter (only for lq)
170 | if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
171 | img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
172 | # random to gray (only for lq)
173 | if self.gray_prob and np.random.uniform() < self.gray_prob:
174 | img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
175 | img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
176 | if self.opt.get('gt_gray'):
177 | img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
178 | img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])
179 |
180 | # BGR to RGB, HWC to CHW, numpy to tensor
181 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
182 |
183 | # random color jitter (pytorch version) (only for lq)
184 | if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
185 | brightness = self.opt.get('brightness', (0.5, 1.5))
186 | contrast = self.opt.get('contrast', (0.5, 1.5))
187 | saturation = self.opt.get('saturation', (0, 1.5))
188 | hue = self.opt.get('hue', (-0.1, 0.1))
189 | img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
190 |
191 | # round and clip
192 | img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
193 |
194 | # normalize
195 | normalize(img_gt, self.mean, self.std, inplace=True)
196 | normalize(img_lq, self.mean, self.std, inplace=True)
197 |
198 | if self.crop_components:
199 | return_dict = {
200 | 'lq': img_lq,
201 | 'gt': img_gt,
202 | 'gt_path': gt_path,
203 | 'loc_left_eye': loc_left_eye,
204 | 'loc_right_eye': loc_right_eye,
205 | 'loc_mouth': loc_mouth
206 | }
207 | return return_dict
208 | else:
209 | return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
210 |
211 | def __len__(self):
212 | return len(self.paths)
213 |
--------------------------------------------------------------------------------
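For reference, a minimal sketch of an `opt` dict covering only the keys read above. The folder path and the degradation ranges here are illustrative placeholders, not the official training configuration (that lives in `options/train_gfpgan_v1.yml`); the folder is assumed to contain aligned 512x512 face images.

```python
from gfpgan.data.ffhq_degradation_dataset import FFHQDegradationDataset

opt = {
    'dataroot_gt': 'datasets/ffhq_512',   # placeholder folder of aligned 512x512 faces
    'io_backend': {'type': 'disk'},
    'mean': [0.5, 0.5, 0.5],
    'std': [0.5, 0.5, 0.5],
    'out_size': 512,
    'use_hflip': True,
    # blur
    'blur_kernel_size': 41,
    'kernel_list': ['iso', 'aniso'],
    'kernel_prob': [0.5, 0.5],
    'blur_sigma': [0.1, 10],
    # downsample, noise and JPEG ranges (illustrative values)
    'downsample_range': [0.8, 8],
    'noise_range': [0, 20],
    'jpeg_range': [60, 100],
}

dataset = FFHQDegradationDataset(opt)
sample = dataset[0]                            # {'lq': ..., 'gt': ..., 'gt_path': ...}
print(sample['lq'].shape, sample['gt'].shape)  # both (3, 512, 512) tensors
```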
/gfpgan/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from basicsr.utils import scandir
3 | from os import path as osp
4 |
5 | # automatically scan and import model modules for registry
6 | # scan all the files that end with '_model.py' under the model folder
7 | model_folder = osp.dirname(osp.abspath(__file__))
8 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
9 | # import all the model modules
10 | _model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames]
11 |
--------------------------------------------------------------------------------
/gfpgan/models/gfpgan_model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os.path as osp
3 | import torch
4 | from basicsr.archs import build_network
5 | from basicsr.losses import build_loss
6 | from basicsr.losses.losses import r1_penalty
7 | from basicsr.metrics import calculate_metric
8 | from basicsr.models.base_model import BaseModel
9 | from basicsr.utils import get_root_logger, imwrite, tensor2img
10 | from basicsr.utils.registry import MODEL_REGISTRY
11 | from collections import OrderedDict
12 | from torch.nn import functional as F
13 | from torchvision.ops import roi_align
14 | from tqdm import tqdm
15 |
16 |
17 | @MODEL_REGISTRY.register()
18 | class GFPGANModel(BaseModel):
19 |     """GFPGAN model for blind face restoration."""
20 |
21 | def __init__(self, opt):
22 | super(GFPGANModel, self).__init__(opt)
23 | self.idx = 0
24 |
25 | # define network
26 | self.net_g = build_network(opt['network_g'])
27 | self.net_g = self.model_to_device(self.net_g)
28 | self.print_network(self.net_g)
29 |
30 | # load pretrained model
31 | load_path = self.opt['path'].get('pretrain_network_g', None)
32 | if load_path is not None:
33 | param_key = self.opt['path'].get('param_key_g', 'params')
34 | self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
35 |
36 | self.log_size = int(math.log(self.opt['network_g']['out_size'], 2))
37 |
38 | if self.is_train:
39 | self.init_training_settings()
40 |
41 | def init_training_settings(self):
42 | train_opt = self.opt['train']
43 |
44 | # ----------- define net_d ----------- #
45 | self.net_d = build_network(self.opt['network_d'])
46 | self.net_d = self.model_to_device(self.net_d)
47 | self.print_network(self.net_d)
48 | # load pretrained model
49 | load_path = self.opt['path'].get('pretrain_network_d', None)
50 | if load_path is not None:
51 | self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
52 |
53 | # ----------- define net_g with Exponential Moving Average (EMA) ----------- #
54 |         # net_g_ema is only used for testing on one GPU and for saving checkpoints
55 | # There is no need to wrap with DistributedDataParallel
56 | self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
57 | # load pretrained model
58 | load_path = self.opt['path'].get('pretrain_network_g', None)
59 | if load_path is not None:
60 | self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
61 | else:
62 | self.model_ema(0) # copy net_g weight
63 |
64 | self.net_g.train()
65 | self.net_d.train()
66 | self.net_g_ema.eval()
67 |
68 | # ----------- facial components networks ----------- #
69 | if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
70 | self.use_facial_disc = True
71 | else:
72 | self.use_facial_disc = False
73 |
74 | if self.use_facial_disc:
75 | # left eye
76 | self.net_d_left_eye = build_network(self.opt['network_d_left_eye'])
77 | self.net_d_left_eye = self.model_to_device(self.net_d_left_eye)
78 | self.print_network(self.net_d_left_eye)
79 | load_path = self.opt['path'].get('pretrain_network_d_left_eye')
80 | if load_path is not None:
81 | self.load_network(self.net_d_left_eye, load_path, True, 'params')
82 | # right eye
83 | self.net_d_right_eye = build_network(self.opt['network_d_right_eye'])
84 | self.net_d_right_eye = self.model_to_device(self.net_d_right_eye)
85 | self.print_network(self.net_d_right_eye)
86 | load_path = self.opt['path'].get('pretrain_network_d_right_eye')
87 | if load_path is not None:
88 | self.load_network(self.net_d_right_eye, load_path, True, 'params')
89 | # mouth
90 | self.net_d_mouth = build_network(self.opt['network_d_mouth'])
91 | self.net_d_mouth = self.model_to_device(self.net_d_mouth)
92 | self.print_network(self.net_d_mouth)
93 | load_path = self.opt['path'].get('pretrain_network_d_mouth')
94 | if load_path is not None:
95 | self.load_network(self.net_d_mouth, load_path, True, 'params')
96 |
97 | self.net_d_left_eye.train()
98 | self.net_d_right_eye.train()
99 | self.net_d_mouth.train()
100 |
101 | # ----------- define facial component gan loss ----------- #
102 | self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)
103 |
104 | # ----------- define losses ----------- #
105 | if train_opt.get('pixel_opt'):
106 | self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
107 | else:
108 | self.cri_pix = None
109 |
110 | if train_opt.get('perceptual_opt'):
111 | self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
112 | else:
113 | self.cri_perceptual = None
114 |
115 | # L1 loss used in pyramid loss, component style loss and identity loss
116 | self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)
117 |
118 | # gan loss (wgan)
119 | self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
120 |
121 | # ----------- define identity loss ----------- #
122 | if 'network_identity' in self.opt:
123 | self.use_identity = True
124 | else:
125 | self.use_identity = False
126 |
127 | if self.use_identity:
128 | # define identity network
129 | self.network_identity = build_network(self.opt['network_identity'])
130 | self.network_identity = self.model_to_device(self.network_identity)
131 | self.print_network(self.network_identity)
132 | load_path = self.opt['path'].get('pretrain_network_identity')
133 | if load_path is not None:
134 | self.load_network(self.network_identity, load_path, True, None)
135 | self.network_identity.eval()
136 | for param in self.network_identity.parameters():
137 | param.requires_grad = False
138 |
139 | # regularization weights
140 | self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator
141 | self.net_d_iters = train_opt.get('net_d_iters', 1)
142 | self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
143 | self.net_d_reg_every = train_opt['net_d_reg_every']
144 |
145 | # set up optimizers and schedulers
146 | self.setup_optimizers()
147 | self.setup_schedulers()
148 |
149 | def setup_optimizers(self):
150 | train_opt = self.opt['train']
151 |
152 | # ----------- optimizer g ----------- #
153 | net_g_reg_ratio = 1
154 | normal_params = []
155 | for _, param in self.net_g.named_parameters():
156 | normal_params.append(param)
157 | optim_params_g = [{ # add normal params first
158 | 'params': normal_params,
159 | 'lr': train_opt['optim_g']['lr']
160 | }]
161 | optim_type = train_opt['optim_g'].pop('type')
162 | lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
163 | betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
164 | self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
165 | self.optimizers.append(self.optimizer_g)
166 |
167 | # ----------- optimizer d ----------- #
168 | net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
169 | normal_params = []
170 | for _, param in self.net_d.named_parameters():
171 | normal_params.append(param)
172 | optim_params_d = [{ # add normal params first
173 | 'params': normal_params,
174 | 'lr': train_opt['optim_d']['lr']
175 | }]
176 | optim_type = train_opt['optim_d'].pop('type')
177 | lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
178 | betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
179 | self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
180 | self.optimizers.append(self.optimizer_d)
181 |
182 | if self.use_facial_disc:
183 | # setup optimizers for facial component discriminators
184 | optim_type = train_opt['optim_component'].pop('type')
185 | lr = train_opt['optim_component']['lr']
186 | # left eye
187 | self.optimizer_d_left_eye = self.get_optimizer(
188 | optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99))
189 | self.optimizers.append(self.optimizer_d_left_eye)
190 | # right eye
191 | self.optimizer_d_right_eye = self.get_optimizer(
192 | optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99))
193 | self.optimizers.append(self.optimizer_d_right_eye)
194 | # mouth
195 | self.optimizer_d_mouth = self.get_optimizer(
196 | optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99))
197 | self.optimizers.append(self.optimizer_d_mouth)
198 |
199 | def feed_data(self, data):
200 | self.lq = data['lq'].to(self.device)
201 | if 'gt' in data:
202 | self.gt = data['gt'].to(self.device)
203 |
204 | if 'loc_left_eye' in data:
205 | # get facial component locations, shape (batch, 4)
206 | self.loc_left_eyes = data['loc_left_eye']
207 | self.loc_right_eyes = data['loc_right_eye']
208 | self.loc_mouths = data['loc_mouth']
209 |
210 | # uncomment to check data
211 | # import torchvision
212 | # if self.opt['rank'] == 0:
213 | # import os
214 | # os.makedirs('tmp/gt', exist_ok=True)
215 | # os.makedirs('tmp/lq', exist_ok=True)
216 | # print(self.idx)
217 | # torchvision.utils.save_image(
218 | # self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
219 | # torchvision.utils.save_image(
220 | # self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
221 | # self.idx = self.idx + 1
222 |
223 | def construct_img_pyramid(self):
224 | pyramid_gt = [self.gt]
225 | down_img = self.gt
226 | for _ in range(0, self.log_size - 3):
227 | down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False)
228 | pyramid_gt.insert(0, down_img)
229 | return pyramid_gt
230 |
231 | def get_roi_regions(self, eye_out_size=80, mouth_out_size=120):
232 |         # sizes are hard-coded for a 512x512 output face and scaled below for other out_size values
233 | face_ratio = int(self.opt['network_g']['out_size'] / 512)
234 | eye_out_size *= face_ratio
235 | mouth_out_size *= face_ratio
236 |
237 | rois_eyes = []
238 | rois_mouths = []
239 | for b in range(self.loc_left_eyes.size(0)): # loop for batch size
240 | # left eye and right eye
241 | img_inds = self.loc_left_eyes.new_full((2, 1), b)
242 | bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0) # shape: (2, 4)
243 | rois = torch.cat([img_inds, bbox], dim=-1) # shape: (2, 5)
244 | rois_eyes.append(rois)
245 |             # mouth
246 | img_inds = self.loc_left_eyes.new_full((1, 1), b)
247 | rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1) # shape: (1, 5)
248 | rois_mouths.append(rois)
249 |
250 | rois_eyes = torch.cat(rois_eyes, 0).to(self.device)
251 | rois_mouths = torch.cat(rois_mouths, 0).to(self.device)
252 |
253 | # real images
254 | all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
255 | self.left_eyes_gt = all_eyes[0::2, :, :, :]
256 | self.right_eyes_gt = all_eyes[1::2, :, :, :]
257 | self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
258 | # output
259 | all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
260 | self.left_eyes = all_eyes[0::2, :, :, :]
261 | self.right_eyes = all_eyes[1::2, :, :, :]
262 | self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio
263 |
264 | def _gram_mat(self, x):
265 | """Calculate Gram matrix.
266 |
267 | Args:
268 | x (torch.Tensor): Tensor with shape of (n, c, h, w).
269 |
270 | Returns:
271 | torch.Tensor: Gram matrix.
272 | """
273 | n, c, h, w = x.size()
274 | features = x.view(n, c, w * h)
275 | features_t = features.transpose(1, 2)
276 | gram = features.bmm(features_t) / (c * h * w)
277 | return gram
278 |
279 | def gray_resize_for_identity(self, out, size=128):
280 | out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
281 | out_gray = out_gray.unsqueeze(1)
282 | out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
283 | return out_gray
284 |
285 | def optimize_parameters(self, current_iter):
286 | # optimize net_g
287 | for p in self.net_d.parameters():
288 | p.requires_grad = False
289 | self.optimizer_g.zero_grad()
290 |
291 | if self.use_facial_disc:
292 | for p in self.net_d_left_eye.parameters():
293 | p.requires_grad = False
294 | for p in self.net_d_right_eye.parameters():
295 | p.requires_grad = False
296 | for p in self.net_d_mouth.parameters():
297 | p.requires_grad = False
298 |
299 | # image pyramid loss weight
300 | if current_iter < self.opt['train'].get('remove_pyramid_loss', float('inf')):
301 | pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 1)
302 | else:
303 | pyramid_loss_weight = 1e-12 # very small loss
304 | if pyramid_loss_weight > 0:
305 | self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
306 | pyramid_gt = self.construct_img_pyramid()
307 | else:
308 | self.output, out_rgbs = self.net_g(self.lq, return_rgb=False)
309 |
310 | # get roi-align regions
311 | if self.use_facial_disc:
312 | self.get_roi_regions(eye_out_size=80, mouth_out_size=120)
313 |
314 | l_g_total = 0
315 | loss_dict = OrderedDict()
316 | if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
317 | # pixel loss
318 | if self.cri_pix:
319 | l_g_pix = self.cri_pix(self.output, self.gt)
320 | l_g_total += l_g_pix
321 | loss_dict['l_g_pix'] = l_g_pix
322 |
323 | # image pyramid loss
324 | if pyramid_loss_weight > 0:
325 | for i in range(0, self.log_size - 2):
326 | l_pyramid = self.cri_l1(out_rgbs[i], pyramid_gt[i]) * pyramid_loss_weight
327 | l_g_total += l_pyramid
328 | loss_dict[f'l_p_{2**(i+3)}'] = l_pyramid
329 |
330 | # perceptual loss
331 | if self.cri_perceptual:
332 | l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
333 | if l_g_percep is not None:
334 | l_g_total += l_g_percep
335 | loss_dict['l_g_percep'] = l_g_percep
336 | if l_g_style is not None:
337 | l_g_total += l_g_style
338 | loss_dict['l_g_style'] = l_g_style
339 |
340 | # gan loss
341 | fake_g_pred = self.net_d(self.output)
342 | l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
343 | l_g_total += l_g_gan
344 | loss_dict['l_g_gan'] = l_g_gan
345 |
346 | # facial component loss
347 | if self.use_facial_disc:
348 | # left eye
349 | fake_left_eye, fake_left_eye_feats = self.net_d_left_eye(self.left_eyes, return_feats=True)
350 | l_g_gan = self.cri_component(fake_left_eye, True, is_disc=False)
351 | l_g_total += l_g_gan
352 | loss_dict['l_g_gan_left_eye'] = l_g_gan
353 | # right eye
354 | fake_right_eye, fake_right_eye_feats = self.net_d_right_eye(self.right_eyes, return_feats=True)
355 | l_g_gan = self.cri_component(fake_right_eye, True, is_disc=False)
356 | l_g_total += l_g_gan
357 | loss_dict['l_g_gan_right_eye'] = l_g_gan
358 | # mouth
359 | fake_mouth, fake_mouth_feats = self.net_d_mouth(self.mouths, return_feats=True)
360 | l_g_gan = self.cri_component(fake_mouth, True, is_disc=False)
361 | l_g_total += l_g_gan
362 | loss_dict['l_g_gan_mouth'] = l_g_gan
363 |
364 | if self.opt['train'].get('comp_style_weight', 0) > 0:
365 | # get gt feat
366 | _, real_left_eye_feats = self.net_d_left_eye(self.left_eyes_gt, return_feats=True)
367 | _, real_right_eye_feats = self.net_d_right_eye(self.right_eyes_gt, return_feats=True)
368 | _, real_mouth_feats = self.net_d_mouth(self.mouths_gt, return_feats=True)
369 |
370 | def _comp_style(feat, feat_gt, criterion):
371 | return criterion(self._gram_mat(feat[0]), self._gram_mat(
372 | feat_gt[0].detach())) * 0.5 + criterion(
373 | self._gram_mat(feat[1]), self._gram_mat(feat_gt[1].detach()))
374 |
375 | # facial component style loss
376 | comp_style_loss = 0
377 | comp_style_loss += _comp_style(fake_left_eye_feats, real_left_eye_feats, self.cri_l1)
378 | comp_style_loss += _comp_style(fake_right_eye_feats, real_right_eye_feats, self.cri_l1)
379 | comp_style_loss += _comp_style(fake_mouth_feats, real_mouth_feats, self.cri_l1)
380 | comp_style_loss = comp_style_loss * self.opt['train']['comp_style_weight']
381 | l_g_total += comp_style_loss
382 | loss_dict['l_g_comp_style_loss'] = comp_style_loss
383 |
384 | # identity loss
385 | if self.use_identity:
386 | identity_weight = self.opt['train']['identity_weight']
387 | # get gray images and resize
388 | out_gray = self.gray_resize_for_identity(self.output)
389 | gt_gray = self.gray_resize_for_identity(self.gt)
390 |
391 | identity_gt = self.network_identity(gt_gray).detach()
392 | identity_out = self.network_identity(out_gray)
393 | l_identity = self.cri_l1(identity_out, identity_gt) * identity_weight
394 | l_g_total += l_identity
395 | loss_dict['l_identity'] = l_identity
396 |
397 | l_g_total.backward()
398 | self.optimizer_g.step()
399 |
400 | # EMA
401 | self.model_ema(decay=0.5**(32 / (10 * 1000)))
402 |
403 | # ----------- optimize net_d ----------- #
404 | for p in self.net_d.parameters():
405 | p.requires_grad = True
406 | self.optimizer_d.zero_grad()
407 | if self.use_facial_disc:
408 | for p in self.net_d_left_eye.parameters():
409 | p.requires_grad = True
410 | for p in self.net_d_right_eye.parameters():
411 | p.requires_grad = True
412 | for p in self.net_d_mouth.parameters():
413 | p.requires_grad = True
414 | self.optimizer_d_left_eye.zero_grad()
415 | self.optimizer_d_right_eye.zero_grad()
416 | self.optimizer_d_mouth.zero_grad()
417 |
418 | fake_d_pred = self.net_d(self.output.detach())
419 | real_d_pred = self.net_d(self.gt)
420 | l_d = self.cri_gan(real_d_pred, True, is_disc=True) + self.cri_gan(fake_d_pred, False, is_disc=True)
421 | loss_dict['l_d'] = l_d
422 |         # In WGAN, real_score should be positive and fake_score should be negative
423 | loss_dict['real_score'] = real_d_pred.detach().mean()
424 | loss_dict['fake_score'] = fake_d_pred.detach().mean()
425 | l_d.backward()
426 |
427 | if current_iter % self.net_d_reg_every == 0:
428 | self.gt.requires_grad = True
429 | real_pred = self.net_d(self.gt)
430 | l_d_r1 = r1_penalty(real_pred, self.gt)
431 | l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
432 | loss_dict['l_d_r1'] = l_d_r1.detach().mean()
433 | l_d_r1.backward()
434 |
435 | self.optimizer_d.step()
436 |
437 | if self.use_facial_disc:
438 |             # left eye
439 | fake_d_pred, _ = self.net_d_left_eye(self.left_eyes.detach())
440 | real_d_pred, _ = self.net_d_left_eye(self.left_eyes_gt)
441 | l_d_left_eye = self.cri_component(
442 | real_d_pred, True, is_disc=True) + self.cri_gan(
443 | fake_d_pred, False, is_disc=True)
444 | loss_dict['l_d_left_eye'] = l_d_left_eye
445 | l_d_left_eye.backward()
446 | # right eye
447 | fake_d_pred, _ = self.net_d_right_eye(self.right_eyes.detach())
448 | real_d_pred, _ = self.net_d_right_eye(self.right_eyes_gt)
449 | l_d_right_eye = self.cri_component(
450 | real_d_pred, True, is_disc=True) + self.cri_gan(
451 | fake_d_pred, False, is_disc=True)
452 | loss_dict['l_d_right_eye'] = l_d_right_eye
453 | l_d_right_eye.backward()
454 | # mouth
455 | fake_d_pred, _ = self.net_d_mouth(self.mouths.detach())
456 | real_d_pred, _ = self.net_d_mouth(self.mouths_gt)
457 | l_d_mouth = self.cri_component(
458 | real_d_pred, True, is_disc=True) + self.cri_gan(
459 | fake_d_pred, False, is_disc=True)
460 | loss_dict['l_d_mouth'] = l_d_mouth
461 | l_d_mouth.backward()
462 |
463 | self.optimizer_d_left_eye.step()
464 | self.optimizer_d_right_eye.step()
465 | self.optimizer_d_mouth.step()
466 |
467 | self.log_dict = self.reduce_loss_dict(loss_dict)
468 |
469 | def test(self):
470 | with torch.no_grad():
471 | if hasattr(self, 'net_g_ema'):
472 | self.net_g_ema.eval()
473 | self.output, _ = self.net_g_ema(self.lq)
474 | else:
475 | logger = get_root_logger()
476 | logger.warning('Do not have self.net_g_ema, use self.net_g.')
477 | self.net_g.eval()
478 | self.output, _ = self.net_g(self.lq)
479 | self.net_g.train()
480 |
481 | def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
482 | if self.opt['rank'] == 0:
483 | self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
484 |
485 | def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
486 | dataset_name = dataloader.dataset.opt['name']
487 | with_metrics = self.opt['val'].get('metrics') is not None
488 | if with_metrics:
489 | self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
490 | pbar = tqdm(total=len(dataloader), unit='image')
491 |
492 | for idx, val_data in enumerate(dataloader):
493 | img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
494 | self.feed_data(val_data)
495 | self.test()
496 |
497 | visuals = self.get_current_visuals()
498 | sr_img = tensor2img([visuals['sr']], min_max=(-1, 1))
499 | gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
500 |
501 | if 'gt' in visuals:
502 | gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
503 | del self.gt
504 |             # tentative: free tensors to avoid running out of GPU memory
505 | del self.lq
506 | del self.output
507 | torch.cuda.empty_cache()
508 |
509 | if save_img:
510 | if self.opt['is_train']:
511 | save_img_path = osp.join(self.opt['path']['visualization'], img_name,
512 | f'{img_name}_{current_iter}.png')
513 | else:
514 | if self.opt['val']['suffix']:
515 | save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
516 | f'{img_name}_{self.opt["val"]["suffix"]}.png')
517 | else:
518 | save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
519 | f'{img_name}_{self.opt["name"]}.png')
520 | imwrite(sr_img, save_img_path)
521 |
522 | if with_metrics:
523 | # calculate metrics
524 | for name, opt_ in self.opt['val']['metrics'].items():
525 | metric_data = dict(img1=sr_img, img2=gt_img)
526 | self.metric_results[name] += calculate_metric(metric_data, opt_)
527 | pbar.update(1)
528 | pbar.set_description(f'Test {img_name}')
529 | pbar.close()
530 |
531 | if with_metrics:
532 | for metric in self.metric_results.keys():
533 | self.metric_results[metric] /= (idx + 1)
534 |
535 | self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
536 |
537 | def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
538 | log_str = f'Validation {dataset_name}\n'
539 | for metric, value in self.metric_results.items():
540 | log_str += f'\t # {metric}: {value:.4f}\n'
541 | logger = get_root_logger()
542 | logger.info(log_str)
543 | if tb_logger:
544 | for metric, value in self.metric_results.items():
545 | tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
546 |
547 | def get_current_visuals(self):
548 | out_dict = OrderedDict()
549 | out_dict['gt'] = self.gt.detach().cpu()
550 | out_dict['sr'] = self.output.detach().cpu()
551 | return out_dict
552 |
553 | def save(self, epoch, current_iter):
554 | self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
555 | self.save_network(self.net_d, 'net_d', current_iter)
556 | # save component discriminators
557 | if self.use_facial_disc:
558 | self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
559 | self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
560 | self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)
561 | self.save_training_state(epoch, current_iter)
562 |
--------------------------------------------------------------------------------
/gfpgan/train.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | import os.path as osp
3 | from basicsr.train import train_pipeline
4 |
5 | import gfpgan.archs
6 | import gfpgan.data
7 | import gfpgan.models
8 |
9 | if __name__ == '__main__':
10 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
11 | train_pipeline(root_path)
12 |
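Note: gfpgan/train.py only registers the project's archs, data and models and then hands control to basicsr's `train_pipeline`, so training is driven entirely by an option file. A single-GPU run would look roughly like `python gfpgan/train.py -opt options/train_gfpgan_v1_simple.yml` (flags follow basicsr's conventions; see the distributed example after options/train_gfpgan_v1.yml below).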
--------------------------------------------------------------------------------
/gfpgan/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import torch
4 | from basicsr.utils import img2tensor, tensor2img
5 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper
6 | from torch.hub import download_url_to_file, get_dir
7 | from torchvision.transforms.functional import normalize
8 | from urllib.parse import urlparse
9 |
10 | from gfpgan.archs.gfpganv1_arch import GFPGANv1
11 | from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
12 |
13 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
14 |
15 |
16 | class GFPGANer():
17 |
18 | def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
19 | self.upscale = upscale
20 | self.bg_upsampler = bg_upsampler
21 |
22 | # initialize model
23 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24 | # initialize the GFP-GAN
25 | if arch == 'clean':
26 | self.gfpgan = GFPGANv1Clean(
27 | out_size=512,
28 | num_style_feat=512,
29 | channel_multiplier=channel_multiplier,
30 | decoder_load_path=None,
31 | fix_decoder=False,
32 | num_mlp=8,
33 | input_is_latent=True,
34 | different_w=True,
35 | narrow=1,
36 | sft_half=True)
37 | else:
38 | self.gfpgan = GFPGANv1(
39 | out_size=512,
40 | num_style_feat=512,
41 | channel_multiplier=channel_multiplier,
42 | decoder_load_path=None,
43 | fix_decoder=True,
44 | num_mlp=8,
45 | input_is_latent=True,
46 | different_w=True,
47 | narrow=1,
48 | sft_half=True)
49 | # initialize face helper
50 | self.face_helper = FaceRestoreHelper(
51 | upscale,
52 | face_size=512,
53 | crop_ratio=(1, 1),
54 | det_model='retinaface_resnet50',
55 | save_ext='png',
56 | device=self.device)
57 |
58 | if model_path.startswith('https://'):
59 | model_path = load_file_from_url(url=model_path, model_dir='gfpgan/weights', progress=True, file_name=None)
60 | loadnet = torch.load(model_path)
61 | if 'params_ema' in loadnet:
62 | keyname = 'params_ema'
63 | else:
64 | keyname = 'params'
65 | self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
66 | self.gfpgan.eval()
67 | self.gfpgan = self.gfpgan.to(self.device)
68 |
69 | @torch.no_grad()
70 | def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
71 | self.face_helper.clean_all()
72 |
73 | if has_aligned:
74 | img = cv2.resize(img, (512, 512))
75 | self.face_helper.cropped_faces = [img]
76 | else:
77 | self.face_helper.read_image(img)
78 | # get face landmarks for each face
79 | self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
80 | # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
81 | # align and warp each face
82 | self.face_helper.align_warp_face()
83 |
84 | # face restoration
85 | for cropped_face in self.face_helper.cropped_faces:
86 | # prepare data
87 | cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
88 | normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
89 | cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
90 |
91 | try:
92 | output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
93 | # convert to image
94 | restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
95 | except RuntimeError as error:
96 | print(f'\tFailed inference for GFPGAN: {error}.')
97 | restored_face = cropped_face
98 |
99 | restored_face = restored_face.astype('uint8')
100 | self.face_helper.add_restored_face(restored_face)
101 |
102 | if not has_aligned and paste_back:
103 |
104 | if self.bg_upsampler is not None:
105 | # Now only support RealESRGAN
106 | bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
107 | else:
108 | bg_img = None
109 |
110 | self.face_helper.get_inverse_affine(None)
111 | # paste each restored face to the input image
112 | restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
113 | return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
114 | else:
115 | return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
116 |
117 |
118 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
119 | """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
120 | """
121 | if model_dir is None:
122 | hub_dir = get_dir()
123 | model_dir = os.path.join(hub_dir, 'checkpoints')
124 |
125 | os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
126 |
127 | parts = urlparse(url)
128 | filename = os.path.basename(parts.path)
129 | if file_name is not None:
130 | filename = file_name
131 | cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
132 | if not os.path.exists(cached_file):
133 | print(f'Downloading: "{url}" to {cached_file}\n')
134 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
135 | return cached_file
136 |
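For orientation, here is a minimal usage sketch of the `GFPGANer` helper defined above. All file paths are illustrative, and the checkpoint name simply mirrors the default used by `inference_gfpgan.py`:

```python
import cv2

from gfpgan import GFPGANer

# Illustrative paths; only a valid GFPGAN checkpoint is actually required.
restorer = GFPGANer(
    model_path='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth',
    upscale=2,
    arch='clean',
    channel_multiplier=2,
    bg_upsampler=None)  # pass a RealESRGANer instance here to also upsample the background

img = cv2.imread('my_photo.jpg', cv2.IMREAD_COLOR)  # hypothetical input image
cropped_faces, restored_faces, restored_img = restorer.enhance(
    img, has_aligned=False, only_center_face=False, paste_back=True)

if restored_img is not None:  # None when paste_back=False or has_aligned=True
    cv2.imwrite('my_photo_restored.png', restored_img)
```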
--------------------------------------------------------------------------------
/gfpgan/weights/README.md:
--------------------------------------------------------------------------------
1 | # Weights
2 |
3 | Put the downloaded weights into this folder.
4 |
--------------------------------------------------------------------------------
/inference_gfpgan.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import glob
4 | import numpy as np
5 | import os
6 | import torch
7 | from basicsr.utils import imwrite
8 |
9 | from gfpgan import GFPGANer
10 |
11 |
12 | def main():
13 | parser = argparse.ArgumentParser()
14 |
15 | parser.add_argument('--upscale', type=int, default=2)
16 | parser.add_argument('--arch', type=str, default='clean')
17 | parser.add_argument('--channel', type=int, default=2)
18 | parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
19 | parser.add_argument('--bg_upsampler', type=str, default='realesrgan')
20 | parser.add_argument('--bg_tile', type=int, default=400)
21 | parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
22 | parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
23 | parser.add_argument('--only_center_face', action='store_true')
24 | parser.add_argument('--aligned', action='store_true')
25 | parser.add_argument('--paste_back', action='store_false')
26 | parser.add_argument('--save_root', type=str, default='results')
27 | parser.add_argument(
28 | '--ext',
29 | type=str,
30 | default='auto',
31 | help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
32 | args = parser.parse_args()
33 |
35 | if args.test_path.endswith('/'):
36 | args.test_path = args.test_path[:-1]
37 | os.makedirs(args.save_root, exist_ok=True)
38 |
39 | # background upsampler
40 | if args.bg_upsampler == 'realesrgan':
41 | if not torch.cuda.is_available(): # CPU
42 | import warnings
 43 |             warnings.warn('The unoptimized RealESRGAN is very slow on CPU, so the background upsampler is disabled. '
 44 |                           'If you really want to use it on CPU, please modify the corresponding code.')
45 | bg_upsampler = None
46 | else:
47 | from realesrgan import RealESRGANer
48 | bg_upsampler = RealESRGANer(
49 | scale=2,
50 | model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
51 | tile=args.bg_tile,
52 | tile_pad=10,
53 | pre_pad=0,
54 | half=True) # need to set False in CPU mode
55 | else:
56 | bg_upsampler = None
57 | # set up GFPGAN restorer
58 | restorer = GFPGANer(
59 | model_path=args.model_path,
60 | upscale=args.upscale,
61 | arch=args.arch,
62 | channel_multiplier=args.channel,
63 | bg_upsampler=bg_upsampler)
64 |
65 | img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
66 | for img_path in img_list:
67 | # read image
68 | img_name = os.path.basename(img_path)
69 | print(f'Processing {img_name} ...')
70 | basename, ext = os.path.splitext(img_name)
71 | input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
72 |
73 | cropped_faces, restored_faces, restored_img = restorer.enhance(
74 | input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)
75 |
76 | # save faces
77 | for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
78 | # save cropped face
79 | save_crop_path = os.path.join(args.save_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
80 | imwrite(cropped_face, save_crop_path)
81 | # save restored face
82 | if args.suffix is not None:
83 | save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png'
84 | else:
85 | save_face_name = f'{basename}_{idx:02d}.png'
86 | save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name)
87 | imwrite(restored_face, save_restore_path)
88 | # save cmp image
89 | cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
90 | imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png'))
91 |
92 | # save restored img
93 | if restored_img is not None:
94 | if args.ext == 'auto':
95 | extension = ext[1:]
96 | else:
97 | extension = args.ext
98 |
99 | if args.suffix is not None:
100 | save_restore_path = os.path.join(args.save_root, 'restored_imgs',
101 | f'{basename}_{args.suffix}.{extension}')
102 | else:
103 | save_restore_path = os.path.join(args.save_root, 'restored_imgs', f'{basename}.{extension}')
104 | imwrite(restored_img, save_restore_path)
105 |
106 | print(f'Results are in the [{args.save_root}] folder.')
107 |
108 |
109 | if __name__ == '__main__':
110 | main()
111 |
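A typical invocation, assuming the pretrained checkpoint sits at the default `--model_path`, is `python inference_gfpgan.py --upscale 2 --test_path inputs/whole_imgs --save_root results`. Note that `--paste_back` is declared with `action='store_false'`: pasting the restored faces back into the full image is on by default, and passing the flag turns it off.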
--------------------------------------------------------------------------------
/insightface_func/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NNNNAI/VGGFace2-HQ/36d5b72cb7b7b5a17a8daa7e8f94c79ff3ef32ec/insightface_func/__init__.py
--------------------------------------------------------------------------------
/insightface_func/face_detect_crop.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Organization : insightface.ai
3 | # @Author : Jia Guo
4 | # @Time : 2021-05-04
5 | # @Function :
6 |
7 |
8 | from __future__ import division
9 | import collections
10 | import numpy as np
11 | import glob
12 | import os
13 | import os.path as osp
14 | from numpy.linalg import norm
15 | from insightface.model_zoo import model_zoo
16 | from insightface_func.utils import face_align
17 |
18 | __all__ = ['Face_detect_crop', 'Face']
19 |
20 | Face = collections.namedtuple('Face', [
21 | 'bbox', 'kps', 'det_score', 'embedding', 'gender', 'age',
22 | 'embedding_norm', 'normed_embedding',
23 | 'landmark'
24 | ])
25 |
26 | Face.__new__.__defaults__ = (None, ) * len(Face._fields)
27 |
28 |
29 | class Face_detect_crop:
30 | def __init__(self, name, root='~/.insightface_func/models'):
31 | self.models = {}
32 | root = os.path.expanduser(root)
33 | onnx_files = glob.glob(osp.join(root, name, '*.onnx'))
34 | onnx_files = sorted(onnx_files)
35 | for onnx_file in onnx_files:
36 | if onnx_file.find('_selfgen_')>0:
37 | #print('ignore:', onnx_file)
38 | continue
39 | model = model_zoo.get_model(onnx_file)
40 | if model.taskname not in self.models:
41 | print('find model:', onnx_file, model.taskname)
42 | self.models[model.taskname] = model
43 | else:
44 | print('duplicated model task type, ignore:', onnx_file, model.taskname)
45 | del model
46 | assert 'detection' in self.models
47 | self.det_model = self.models['detection']
48 |
49 |
50 | def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)):
51 | self.det_thresh = det_thresh
52 | assert det_size is not None
53 | print('set det-size:', det_size)
54 | self.det_size = det_size
55 | for taskname, model in self.models.items():
56 | if taskname=='detection':
57 | model.prepare(ctx_id, input_size=det_size)
58 | else:
59 | model.prepare(ctx_id)
60 |
61 | def get(self, img, crop_size, max_num=0, mode = 'None'):
62 | bboxes, kpss = self.det_model.detect(img,
63 | threshold=self.det_thresh,
64 | max_num=max_num,
65 | metric='default')
66 | if bboxes.shape[0] == 0:
67 | return []
68 | ret = []
69 | if mode == 'Both':
70 | for i in range(bboxes.shape[0]):
71 | kps = None
72 | if kpss is not None:
73 | kps = kpss[i]
74 | aimg_None,aimg_arface = face_align.norm_crop(img, kps,crop_size,mode =mode)
75 | return [aimg_None,aimg_arface]
76 |
77 | else:
78 | for i in range(bboxes.shape[0]):
79 | kps = None
80 | if kpss is not None:
81 | kps = kpss[i]
82 | aimg = face_align.norm_crop(img, kps,crop_size,mode =mode)
83 | return [aimg]
84 |
85 | def draw_on(self, img, faces):
86 | import cv2
87 | for i in range(len(faces)):
88 | face = faces[i]
 89 |             box = face.bbox.astype(int)
90 | color = (0, 0, 255)
91 | cv2.rectangle(img, (box[0], box[1]), (box[2], box[3]), color, 2)
92 | if face.kps is not None:
 93 |                 kps = face.kps.astype(int)
94 | #print(landmark.shape)
95 | for l in range(kps.shape[0]):
96 | color = (0, 0, 255)
97 | if l == 0 or l == 3:
98 | color = (0, 255, 0)
99 | cv2.circle(img, (kps[l][0], kps[l][1]), 1, color,
100 | 2)
101 | return img
102 |
103 |
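A short usage sketch of `Face_detect_crop`; it assumes the insightface 'antelope' ONNX models have already been downloaded to the `root` directory, and the image path is hypothetical:

```python
import cv2

from insightface_func.face_detect_crop import Face_detect_crop

# Assumes ./insightface_func/models/antelope contains the detection ONNX files,
# the same convention used by the crop scripts in scripts/.
app = Face_detect_crop(name='antelope', root='./insightface_func/models')
app.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640))

img = cv2.imread('sample.jpg')       # hypothetical input image
crops = app.get(img, crop_size=512)  # default mode='None' -> one aligned crop
if crops:                            # get() returns [] when no face is detected
    # note: get() returns at most one crop (the first detected face)
    cv2.imwrite('sample_aligned.png', crops[0])
```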
--------------------------------------------------------------------------------
/insightface_func/face_detect_crop_ffhq_newarcAlign.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-15 19:42:42
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-16 15:35:02
7 | Description:
8 | '''
9 |
10 | from __future__ import division
11 | import collections
12 | import numpy as np
13 | import glob
14 | import os
15 | import os.path as osp
16 | from insightface.model_zoo import model_zoo
17 | from insightface_func.utils import face_align_ffhqandnewarc as face_align
18 |
19 | __all__ = ['Face_detect_crop', 'Face']
20 |
21 | Face = collections.namedtuple('Face', [
22 | 'bbox', 'kps', 'det_score', 'embedding', 'gender', 'age',
23 | 'embedding_norm', 'normed_embedding',
24 | 'landmark'
25 | ])
26 |
27 | Face.__new__.__defaults__ = (None, ) * len(Face._fields)
28 |
29 |
30 | class Face_detect_crop:
31 | def __init__(self, name, root='~/.insightface_func/models'):
32 | self.models = {}
33 | root = os.path.expanduser(root)
34 | onnx_files = glob.glob(osp.join(root, name, '*.onnx'))
35 | onnx_files = sorted(onnx_files)
36 | for onnx_file in onnx_files:
37 | if onnx_file.find('_selfgen_')>0:
38 | #print('ignore:', onnx_file)
39 | continue
40 | model = model_zoo.get_model(onnx_file)
41 | if model.taskname not in self.models:
42 | print('find model:', onnx_file, model.taskname)
43 | self.models[model.taskname] = model
44 | else:
45 | print('duplicated model task type, ignore:', onnx_file, model.taskname)
46 | del model
47 | assert 'detection' in self.models
48 | self.det_model = self.models['detection']
49 |
50 |
51 | def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)):
52 | self.det_thresh = det_thresh
53 | assert det_size is not None
54 | print('set det-size:', det_size)
55 | self.det_size = det_size
56 | for taskname, model in self.models.items():
57 | if taskname=='detection':
58 | model.prepare(ctx_id, input_size=det_size)
59 | else:
60 | model.prepare(ctx_id)
61 |
62 | def get(self, img, crop_size, max_num=0, mode = 'ffhq'):
63 | bboxes, kpss = self.det_model.detect(img,
64 | threshold=self.det_thresh,
65 | max_num=max_num,
66 | metric='default')
67 | if bboxes.shape[0] == 0:
68 | return []
69 | ret = []
70 | if mode == 'Both':
71 | for i in range(bboxes.shape[0]):
72 | kps = None
73 | if kpss is not None:
74 | kps = kpss[i]
75 | aimg_ffhq,aimg_None = face_align.norm_crop(img, kps,crop_size,mode =mode)
76 | return [aimg_ffhq,aimg_None]
77 |
78 | else:
79 | for i in range(bboxes.shape[0]):
80 | kps = None
81 | if kpss is not None:
82 | kps = kpss[i]
83 | aimg = face_align.norm_crop(img, kps,crop_size,mode =mode)
84 | return [aimg]
85 |
86 | def draw_on(self, img, faces):
87 | import cv2
88 | for i in range(len(faces)):
89 | face = faces[i]
 90 |             box = face.bbox.astype(int)
91 | color = (0, 0, 255)
92 | cv2.rectangle(img, (box[0], box[1]), (box[2], box[3]), color, 2)
93 | if face.kps is not None:
 94 |                 kps = face.kps.astype(int)
95 | #print(landmark.shape)
96 | for l in range(kps.shape[0]):
97 | color = (0, 0, 255)
98 | if l == 0 or l == 3:
99 | color = (0, 255, 0)
100 | cv2.circle(img, (kps[l][0], kps[l][1]), 1, color,
101 | 2)
102 | return img
103 |
104 |
--------------------------------------------------------------------------------
/insightface_func/utils/face_align.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Organization : insightface.ai
3 | # @Author : Jia Guo
4 | # @Time : 2021-05-04
5 | # @Function :
6 |
7 | import cv2
8 | import numpy as np
9 | from skimage import transform as trans
10 |
11 | src1 = np.array([[51.642, 50.115], [57.617, 49.990], [35.740, 69.007],
12 | [51.157, 89.050], [57.025, 89.702]],
13 | dtype=np.float32)
14 | #<--left
15 | src2 = np.array([[45.031, 50.118], [65.568, 50.872], [39.677, 68.111],
16 | [45.177, 86.190], [64.246, 86.758]],
17 | dtype=np.float32)
18 |
19 | #---frontal
20 | src3 = np.array([[39.730, 51.138], [72.270, 51.138], [56.000, 68.493],
21 | [42.463, 87.010], [69.537, 87.010]],
22 | dtype=np.float32)
23 |
24 | #-->right
25 | src4 = np.array([[46.845, 50.872], [67.382, 50.118], [72.737, 68.111],
26 | [48.167, 86.758], [67.236, 86.190]],
27 | dtype=np.float32)
28 |
29 | #-->right profile
30 | src5 = np.array([[54.796, 49.990], [60.771, 50.115], [76.673, 69.007],
31 | [55.388, 89.702], [61.257, 89.050]],
32 | dtype=np.float32)
33 |
34 | src = np.array([src1, src2, src3, src4, src5])
35 | src_map = src
36 |
37 | arcface_src = np.array(
38 | [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
39 | [41.5493, 92.3655], [70.7299, 92.2041]],
40 | dtype=np.float32)
41 |
42 | arcface_src = np.expand_dims(arcface_src, axis=0)
43 |
45 |
46 |
47 | # lmk is prediction; src is template
48 | def estimate_norm(lmk, image_size=112, mode='arcface'):
49 | assert lmk.shape == (5, 2)
50 | tform = trans.SimilarityTransform()
51 | lmk_tran = np.insert(lmk, 2, values=np.ones(5), axis=1)
52 | min_M = []
53 | min_index = []
54 | min_error = float('inf')
55 | if mode == 'arcface':
56 | assert image_size == 112
57 | src = arcface_src
58 | else:
59 | src = src_map * image_size / 112
60 | for i in np.arange(src.shape[0]):
61 | tform.estimate(lmk, src[i])
62 | M = tform.params[0:2, :]
63 | results = np.dot(M, lmk_tran.T)
64 | results = results.T
65 | error = np.sum(np.sqrt(np.sum((results - src[i])**2, axis=1)))
66 | # print(error)
67 | if error < min_error:
68 | min_error = error
69 | min_M = M
70 | min_index = i
71 | return min_M, min_index
72 |
73 |
74 | def norm_crop(img, landmark, image_size=112, mode='arcface'):
75 | if mode == 'Both':
76 | M_None, _ = estimate_norm(landmark, image_size, mode = 'None')
77 | M_arcface, _ = estimate_norm(landmark, 112, mode='arcface')
78 | warped_None = cv2.warpAffine(img, M_None, (image_size, image_size), borderValue=0.0)
79 | warped_arcface = cv2.warpAffine(img, M_arcface, (112, 112), borderValue=0.0)
80 | return warped_None, warped_arcface
81 | else:
82 | M, pose_index = estimate_norm(landmark, image_size, mode)
83 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
84 | return warped
85 |
86 | def square_crop(im, S):
87 | if im.shape[0] > im.shape[1]:
88 | height = S
89 | width = int(float(im.shape[1]) / im.shape[0] * S)
90 | scale = float(S) / im.shape[0]
91 | else:
92 | width = S
93 | height = int(float(im.shape[0]) / im.shape[1] * S)
94 | scale = float(S) / im.shape[1]
95 | resized_im = cv2.resize(im, (width, height))
96 | det_im = np.zeros((S, S, 3), dtype=np.uint8)
97 | det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im
98 | return det_im, scale
99 |
100 |
101 | def transform(data, center, output_size, scale, rotation):
102 | scale_ratio = scale
103 | rot = float(rotation) * np.pi / 180.0
104 | #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
105 | t1 = trans.SimilarityTransform(scale=scale_ratio)
106 | cx = center[0] * scale_ratio
107 | cy = center[1] * scale_ratio
108 | t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
109 | t3 = trans.SimilarityTransform(rotation=rot)
110 | t4 = trans.SimilarityTransform(translation=(output_size / 2,
111 | output_size / 2))
112 | t = t1 + t2 + t3 + t4
113 | M = t.params[0:2]
114 | cropped = cv2.warpAffine(data,
115 | M, (output_size, output_size),
116 | borderValue=0.0)
117 | return cropped, M
118 |
119 |
120 | def trans_points2d(pts, M):
121 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
122 | for i in range(pts.shape[0]):
123 | pt = pts[i]
124 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
125 | new_pt = np.dot(M, new_pt)
126 | #print('new_pt', new_pt.shape, new_pt)
127 | new_pts[i] = new_pt[0:2]
128 |
129 | return new_pts
130 |
131 |
132 | def trans_points3d(pts, M):
133 | scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
134 | #print(scale)
135 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
136 | for i in range(pts.shape[0]):
137 | pt = pts[i]
138 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
139 | new_pt = np.dot(M, new_pt)
140 | #print('new_pt', new_pt.shape, new_pt)
141 | new_pts[i][0:2] = new_pt[0:2]
142 | new_pts[i][2] = pts[i][2] * scale
143 |
144 | return new_pts
145 |
146 |
147 | def trans_points(pts, M):
148 | if pts.shape[1] == 2:
149 | return trans_points2d(pts, M)
150 | else:
151 | return trans_points3d(pts, M)
152 |
153 |
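A small sketch of how `estimate_norm` / `norm_crop` are typically called; the landmark coordinates below are made-up values for a roughly frontal face, whereas in this repository they come from the insightface detector:

```python
import cv2
import numpy as np

from insightface_func.utils.face_align import estimate_norm, norm_crop

# Five (x, y) landmarks: left eye, right eye, nose tip, left/right mouth corner.
lmk = np.array([[70., 80.], [130., 80.], [100., 115.],
                [78., 150.], [122., 150.]], dtype=np.float32)

img = cv2.imread('face.jpg')  # hypothetical input image
aligned = norm_crop(img, lmk, image_size=112, mode='arcface')       # 112x112 ArcFace-aligned crop
M, pose_index = estimate_norm(lmk, image_size=112, mode='arcface')  # the underlying 2x3 similarity transform
```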
--------------------------------------------------------------------------------
/insightface_func/utils/face_align_ffhqandnewarc.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-15 19:42:42
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-15 20:01:47
7 | Description:
8 | '''
9 |
10 | import cv2
11 | import numpy as np
12 | from skimage import transform as trans
13 |
14 | src1 = np.array([[51.642, 50.115], [57.617, 49.990], [35.740, 69.007],
15 | [51.157, 89.050], [57.025, 89.702]],
16 | dtype=np.float32)
17 | #<--left
18 | src2 = np.array([[45.031, 50.118], [65.568, 50.872], [39.677, 68.111],
19 | [45.177, 86.190], [64.246, 86.758]],
20 | dtype=np.float32)
21 |
22 | #---frontal
23 | src3 = np.array([[39.730, 51.138], [72.270, 51.138], [56.000, 68.493],
24 | [42.463, 87.010], [69.537, 87.010]],
25 | dtype=np.float32)
26 |
27 | #-->right
28 | src4 = np.array([[46.845, 50.872], [67.382, 50.118], [72.737, 68.111],
29 | [48.167, 86.758], [67.236, 86.190]],
30 | dtype=np.float32)
31 |
32 | #-->right profile
33 | src5 = np.array([[54.796, 49.990], [60.771, 50.115], [76.673, 69.007],
34 | [55.388, 89.702], [61.257, 89.050]],
35 | dtype=np.float32)
36 |
37 | src = np.array([src1, src2, src3, src4, src5])
38 | src_map = src
39 |
40 | ffhq_src = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
41 | [201.26117, 371.41043], [313.08905, 371.15118]])
42 | ffhq_src = np.expand_dims(ffhq_src, axis=0)
43 |
44 | # arcface_src = np.array(
45 | # [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
46 | # [41.5493, 92.3655], [70.7299, 92.2041]],
47 | # dtype=np.float32)
48 |
49 | # arcface_src = np.expand_dims(arcface_src, axis=0)
50 |
52 |
53 |
54 | # lmk is prediction; src is template
55 | def estimate_norm(lmk, image_size=112, mode='ffhq'):
56 | assert lmk.shape == (5, 2)
57 | tform = trans.SimilarityTransform()
58 | lmk_tran = np.insert(lmk, 2, values=np.ones(5), axis=1)
59 | min_M = []
60 | min_index = []
61 | min_error = float('inf')
62 | if mode == 'ffhq':
63 | # assert image_size == 112
64 | src = ffhq_src * image_size / 512
65 | else:
66 | src = src_map * image_size / 112
67 | for i in np.arange(src.shape[0]):
68 | tform.estimate(lmk, src[i])
69 | M = tform.params[0:2, :]
70 | results = np.dot(M, lmk_tran.T)
71 | results = results.T
72 | error = np.sum(np.sqrt(np.sum((results - src[i])**2, axis=1)))
73 | # print(error)
74 | if error < min_error:
75 | min_error = error
76 | min_M = M
77 | min_index = i
78 | return min_M, min_index
79 |
80 |
81 | def norm_crop(img, landmark, image_size=112, mode='ffhq'):
82 | if mode == 'Both':
83 | M_None, _ = estimate_norm(landmark, image_size, mode = 'newarc')
84 | M_ffhq, _ = estimate_norm(landmark, image_size, mode='ffhq')
85 | warped_None = cv2.warpAffine(img, M_None, (image_size, image_size), borderValue=0.0)
86 | warped_ffhq = cv2.warpAffine(img, M_ffhq, (image_size, image_size), borderValue=0.0)
87 | return warped_ffhq, warped_None
88 | else:
89 | M, pose_index = estimate_norm(landmark, image_size, mode)
90 | warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
91 | return warped
92 |
93 | def square_crop(im, S):
94 | if im.shape[0] > im.shape[1]:
95 | height = S
96 | width = int(float(im.shape[1]) / im.shape[0] * S)
97 | scale = float(S) / im.shape[0]
98 | else:
99 | width = S
100 | height = int(float(im.shape[0]) / im.shape[1] * S)
101 | scale = float(S) / im.shape[1]
102 | resized_im = cv2.resize(im, (width, height))
103 | det_im = np.zeros((S, S, 3), dtype=np.uint8)
104 | det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im
105 | return det_im, scale
106 |
107 |
108 | def transform(data, center, output_size, scale, rotation):
109 | scale_ratio = scale
110 | rot = float(rotation) * np.pi / 180.0
111 | #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
112 | t1 = trans.SimilarityTransform(scale=scale_ratio)
113 | cx = center[0] * scale_ratio
114 | cy = center[1] * scale_ratio
115 | t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
116 | t3 = trans.SimilarityTransform(rotation=rot)
117 | t4 = trans.SimilarityTransform(translation=(output_size / 2,
118 | output_size / 2))
119 | t = t1 + t2 + t3 + t4
120 | M = t.params[0:2]
121 | cropped = cv2.warpAffine(data,
122 | M, (output_size, output_size),
123 | borderValue=0.0)
124 | return cropped, M
125 |
126 |
127 | def trans_points2d(pts, M):
128 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
129 | for i in range(pts.shape[0]):
130 | pt = pts[i]
131 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
132 | new_pt = np.dot(M, new_pt)
133 | #print('new_pt', new_pt.shape, new_pt)
134 | new_pts[i] = new_pt[0:2]
135 |
136 | return new_pts
137 |
138 |
139 | def trans_points3d(pts, M):
140 | scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
141 | #print(scale)
142 | new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
143 | for i in range(pts.shape[0]):
144 | pt = pts[i]
145 | new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
146 | new_pt = np.dot(M, new_pt)
147 | #print('new_pt', new_pt.shape, new_pt)
148 | new_pts[i][0:2] = new_pt[0:2]
149 | new_pts[i][2] = pts[i][2] * scale
150 |
151 | return new_pts
152 |
153 |
154 | def trans_points(pts, M):
155 | if pts.shape[1] == 2:
156 | return trans_points2d(pts, M)
157 | else:
158 | return trans_points3d(pts, M)
159 |
160 |
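The FFHQ-style variant is used the same way; mode 'Both' returns the FFHQ alignment together with the 'newarc' alignment, which is what the two-output crop script relies on. The landmarks below are again illustrative:

```python
import cv2
import numpy as np

from insightface_func.utils.face_align_ffhqandnewarc import norm_crop

# Illustrative five-point landmarks, roughly matching the FFHQ template (512x512 reference frame).
lmk = np.array([[193., 240.], [319., 240.], [257., 314.],
                [201., 371.], [313., 371.]], dtype=np.float32)

img = cv2.imread('face.jpg')  # hypothetical input image
ffhq_crop = norm_crop(img, lmk, image_size=512, mode='ffhq')             # FFHQ-aligned 512x512 crop
ffhq_256, newarc_256 = norm_crop(img, lmk, image_size=256, mode='Both')  # both templates at 256x256
```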
--------------------------------------------------------------------------------
/options/train_gfpgan_v1.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: train_GFPGANv1_512
3 | model_type: GFPGANModel
4 | num_gpu: 4
5 | manual_seed: 0
6 |
7 | # dataset and data loader settings
8 | datasets:
9 | train:
10 | name: FFHQ
11 | type: FFHQDegradationDataset
12 | # dataroot_gt: datasets/ffhq/ffhq_512.lmdb
13 | dataroot_gt: datasets/ffhq/ffhq_512
14 | io_backend:
15 | # type: lmdb
16 | type: disk
17 |
18 | use_hflip: true
19 | mean: [0.5, 0.5, 0.5]
20 | std: [0.5, 0.5, 0.5]
21 | out_size: 512
22 |
23 | blur_kernel_size: 41
24 | kernel_list: ['iso', 'aniso']
25 | kernel_prob: [0.5, 0.5]
26 | blur_sigma: [0.1, 10]
27 | downsample_range: [0.8, 8]
28 | noise_range: [0, 20]
29 | jpeg_range: [60, 100]
30 |
31 | # color jitter and gray
32 | color_jitter_prob: 0.3
33 | color_jitter_shift: 20
34 | color_jitter_pt_prob: 0.3
35 | gray_prob: 0.01
36 |
37 | # If you do not want colorization, please set
38 | # color_jitter_prob: ~
39 | # color_jitter_pt_prob: ~
40 | # gray_prob: 0.01
41 | # gt_gray: True
42 |
43 | crop_components: true
44 | component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
45 | eye_enlarge_ratio: 1.4
46 |
47 | # data loader
48 | use_shuffle: true
49 | num_worker_per_gpu: 6
50 | batch_size_per_gpu: 3
51 | dataset_enlarge_ratio: 1
52 | prefetch_mode: ~
53 |
54 | val:
 55 |     # Please modify accordingly to use your own validation set,
 56 |     # or comment out the val block if validation is not needed during training
57 | name: validation
58 | type: PairedImageDataset
59 | dataroot_lq: datasets/faces/validation/input
60 | dataroot_gt: datasets/faces/validation/reference
61 | io_backend:
62 | type: disk
63 | mean: [0.5, 0.5, 0.5]
64 | std: [0.5, 0.5, 0.5]
65 | scale: 1
66 |
67 | # network structures
68 | network_g:
69 | type: GFPGANv1
70 | out_size: 512
71 | num_style_feat: 512
72 | channel_multiplier: 1
73 | resample_kernel: [1, 3, 3, 1]
74 | decoder_load_path: experiments/pretrained_models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth
75 | fix_decoder: true
76 | num_mlp: 8
77 | lr_mlp: 0.01
78 | input_is_latent: true
79 | different_w: true
80 | narrow: 1
81 | sft_half: true
82 |
83 | network_d:
84 | type: StyleGAN2Discriminator
85 | out_size: 512
86 | channel_multiplier: 1
87 | resample_kernel: [1, 3, 3, 1]
88 |
89 | network_d_left_eye:
90 | type: FacialComponentDiscriminator
91 |
92 | network_d_right_eye:
93 | type: FacialComponentDiscriminator
94 |
95 | network_d_mouth:
96 | type: FacialComponentDiscriminator
97 |
98 | network_identity:
99 | type: ResNetArcFace
100 | block: IRBlock
101 | layers: [2, 2, 2, 2]
102 | use_se: False
103 |
104 | # path
105 | path:
106 | pretrain_network_g: ~
107 | param_key_g: params_ema
108 | strict_load_g: ~
109 | pretrain_network_d: ~
110 | pretrain_network_d_left_eye: ~
111 | pretrain_network_d_right_eye: ~
112 | pretrain_network_d_mouth: ~
113 | pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth
114 | # resume
115 | resume_state: ~
116 | ignore_resume_networks: ['network_identity']
117 |
118 | # training settings
119 | train:
120 | optim_g:
121 | type: Adam
122 | lr: !!float 2e-3
123 | optim_d:
124 | type: Adam
125 | lr: !!float 2e-3
126 | optim_component:
127 | type: Adam
128 | lr: !!float 2e-3
129 |
130 | scheduler:
131 | type: MultiStepLR
132 | milestones: [600000, 700000]
133 | gamma: 0.5
134 |
135 | total_iter: 800000
136 | warmup_iter: -1 # no warm up
137 |
138 | # losses
139 | # pixel loss
140 | pixel_opt:
141 | type: L1Loss
142 | loss_weight: !!float 1e-1
143 | reduction: mean
144 | # L1 loss used in pyramid loss, component style loss and identity loss
145 | L1_opt:
146 | type: L1Loss
147 | loss_weight: 1
148 | reduction: mean
149 |
150 | # image pyramid loss
151 | pyramid_loss_weight: 1
152 | remove_pyramid_loss: 50000
153 | # perceptual loss (content and style losses)
154 | perceptual_opt:
155 | type: PerceptualLoss
156 | layer_weights:
157 | # before relu
158 | 'conv1_2': 0.1
159 | 'conv2_2': 0.1
160 | 'conv3_4': 1
161 | 'conv4_4': 1
162 | 'conv5_4': 1
163 | vgg_type: vgg19
164 | use_input_norm: true
165 | perceptual_weight: !!float 1
166 | style_weight: 50
167 | range_norm: true
168 | criterion: l1
169 | # gan loss
170 | gan_opt:
171 | type: GANLoss
172 | gan_type: wgan_softplus
173 | loss_weight: !!float 1e-1
174 | # r1 regularization for discriminator
175 | r1_reg_weight: 10
176 | # facial component loss
177 | gan_component_opt:
178 | type: GANLoss
179 | gan_type: vanilla
180 | real_label_val: 1.0
181 | fake_label_val: 0.0
182 | loss_weight: !!float 1
183 | comp_style_weight: 200
184 | # identity loss
185 | identity_weight: 10
186 |
187 | net_d_iters: 1
188 | net_d_init_iters: 0
189 | net_d_reg_every: 16
190 |
191 | # validation settings
192 | val:
193 | val_freq: !!float 5e3
194 | save_img: true
195 |
196 | metrics:
197 | psnr: # metric name, can be arbitrary
198 | type: calculate_psnr
199 | crop_border: 0
200 | test_y_channel: false
201 |
202 | # logging settings
203 | logger:
204 | print_freq: 100
205 | save_checkpoint_freq: !!float 5e3
206 | use_tb_logger: true
207 | wandb:
208 | project: ~
209 | resume_id: ~
210 |
211 | # dist training settings
212 | dist_params:
213 | backend: nccl
214 | port: 29500
215 |
216 | find_unused_parameters: true
217 |
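This option file assumes `num_gpu: 4`, so it is normally combined with a distributed launcher in basicsr's usual style, e.g. `python -m torch.distributed.launch --nproc_per_node=4 --master_port=29500 gfpgan/train.py -opt options/train_gfpgan_v1.yml --launcher pytorch` (GPU count and port are illustrative).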
--------------------------------------------------------------------------------
/options/train_gfpgan_v1_simple.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: train_GFPGANv1_512_simple
3 | model_type: GFPGANModel
4 | num_gpu: 4
5 | manual_seed: 0
6 |
7 | # dataset and data loader settings
8 | datasets:
9 | train:
10 | name: FFHQ
11 | type: FFHQDegradationDataset
12 | # dataroot_gt: datasets/ffhq/ffhq_512.lmdb
13 | dataroot_gt: datasets/ffhq/ffhq_512
14 | io_backend:
15 | # type: lmdb
16 | type: disk
17 |
18 | use_hflip: true
19 | mean: [0.5, 0.5, 0.5]
20 | std: [0.5, 0.5, 0.5]
21 | out_size: 512
22 |
23 | blur_kernel_size: 41
24 | kernel_list: ['iso', 'aniso']
25 | kernel_prob: [0.5, 0.5]
26 | blur_sigma: [0.1, 10]
27 | downsample_range: [0.8, 8]
28 | noise_range: [0, 20]
29 | jpeg_range: [60, 100]
30 |
31 | # color jitter and gray
32 | color_jitter_prob: 0.3
33 | color_jitter_shift: 20
34 | color_jitter_pt_prob: 0.3
35 | gray_prob: 0.01
36 |
37 | # If you do not want colorization, please set
38 | # color_jitter_prob: ~
39 | # color_jitter_pt_prob: ~
40 | # gray_prob: 0.01
41 | # gt_gray: True
42 |
43 | # crop_components: false
44 | # component_path: experiments/pretrained_models/FFHQ_eye_mouth_landmarks_512.pth
45 | # eye_enlarge_ratio: 1.4
46 |
47 | # data loader
48 | use_shuffle: true
49 | num_worker_per_gpu: 6
50 | batch_size_per_gpu: 3
51 | dataset_enlarge_ratio: 1
52 | prefetch_mode: ~
53 |
54 | val:
 55 |     # Please modify accordingly to use your own validation set,
 56 |     # or comment out the val block if validation is not needed during training
57 | name: validation
58 | type: PairedImageDataset
59 | dataroot_lq: datasets/faces/validation/input
60 | dataroot_gt: datasets/faces/validation/reference
61 | io_backend:
62 | type: disk
63 | mean: [0.5, 0.5, 0.5]
64 | std: [0.5, 0.5, 0.5]
65 | scale: 1
66 |
67 | # network structures
68 | network_g:
69 | type: GFPGANv1
70 | out_size: 512
71 | num_style_feat: 512
72 | channel_multiplier: 1
73 | resample_kernel: [1, 3, 3, 1]
74 | decoder_load_path: experiments/pretrained_models/StyleGAN2_512_Cmul1_FFHQ_B12G4_scratch_800k.pth
75 | fix_decoder: true
76 | num_mlp: 8
77 | lr_mlp: 0.01
78 | input_is_latent: true
79 | different_w: true
80 | narrow: 1
81 | sft_half: true
82 |
83 | network_d:
84 | type: StyleGAN2Discriminator
85 | out_size: 512
86 | channel_multiplier: 1
87 | resample_kernel: [1, 3, 3, 1]
88 |
89 | # network_d_left_eye:
90 | # type: FacialComponentDiscriminator
91 |
92 | # network_d_right_eye:
93 | # type: FacialComponentDiscriminator
94 |
95 | # network_d_mouth:
96 | # type: FacialComponentDiscriminator
97 |
98 | network_identity:
99 | type: ResNetArcFace
100 | block: IRBlock
101 | layers: [2, 2, 2, 2]
102 | use_se: False
103 |
104 | # path
105 | path:
106 | pretrain_network_g: ~
107 | param_key_g: params_ema
108 | strict_load_g: ~
109 | pretrain_network_d: ~
110 | # pretrain_network_d_left_eye: ~
111 | # pretrain_network_d_right_eye: ~
112 | # pretrain_network_d_mouth: ~
113 | pretrain_network_identity: experiments/pretrained_models/arcface_resnet18.pth
114 | # resume
115 | resume_state: ~
116 | ignore_resume_networks: ['network_identity']
117 |
118 | # training settings
119 | train:
120 | optim_g:
121 | type: Adam
122 | lr: !!float 2e-3
123 | optim_d:
124 | type: Adam
125 | lr: !!float 2e-3
126 | optim_component:
127 | type: Adam
128 | lr: !!float 2e-3
129 |
130 | scheduler:
131 | type: MultiStepLR
132 | milestones: [600000, 700000]
133 | gamma: 0.5
134 |
135 | total_iter: 800000
136 | warmup_iter: -1 # no warm up
137 |
138 | # losses
139 | # pixel loss
140 | pixel_opt:
141 | type: L1Loss
142 | loss_weight: !!float 1e-1
143 | reduction: mean
144 | # L1 loss used in pyramid loss, component style loss and identity loss
145 | L1_opt:
146 | type: L1Loss
147 | loss_weight: 1
148 | reduction: mean
149 |
150 | # image pyramid loss
151 | pyramid_loss_weight: 1
152 | remove_pyramid_loss: 50000
153 | # perceptual loss (content and style losses)
154 | perceptual_opt:
155 | type: PerceptualLoss
156 | layer_weights:
157 | # before relu
158 | 'conv1_2': 0.1
159 | 'conv2_2': 0.1
160 | 'conv3_4': 1
161 | 'conv4_4': 1
162 | 'conv5_4': 1
163 | vgg_type: vgg19
164 | use_input_norm: true
165 | perceptual_weight: !!float 1
166 | style_weight: 50
167 | range_norm: true
168 | criterion: l1
169 | # gan loss
170 | gan_opt:
171 | type: GANLoss
172 | gan_type: wgan_softplus
173 | loss_weight: !!float 1e-1
174 | # r1 regularization for discriminator
175 | r1_reg_weight: 10
176 | # facial component loss
177 | # gan_component_opt:
178 | # type: GANLoss
179 | # gan_type: vanilla
180 | # real_label_val: 1.0
181 | # fake_label_val: 0.0
182 | # loss_weight: !!float 1
183 | # comp_style_weight: 200
184 | # identity loss
185 | identity_weight: 10
186 |
187 | net_d_iters: 1
188 | net_d_init_iters: 0
189 | net_d_reg_every: 16
190 |
191 | # validation settings
192 | val:
193 | val_freq: !!float 5e3
194 | save_img: true
195 |
196 | metrics:
197 | psnr: # metric name, can be arbitrary
198 | type: calculate_psnr
199 | crop_border: 0
200 | test_y_channel: false
201 |
202 | # logging settings
203 | logger:
204 | print_freq: 100
205 | save_checkpoint_freq: !!float 5e3
206 | use_tb_logger: true
207 | wandb:
208 | project: ~
209 | resume_id: ~
210 |
211 | # dist training settings
212 | dist_params:
213 | backend: nccl
214 | port: 29500
215 |
216 | find_unused_parameters: true
217 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.7
2 | numpy<1.21 # numba requires numpy<1.21,>=1.17
3 | opencv-python
4 | torchvision
5 | scipy
6 | tqdm
7 | basicsr>=1.3.4.0
8 | facexlib>=0.2.0.3
9 | lmdb
10 | pyyaml
11 | tb-nightly
12 | yapf
13 |
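A typical setup is `pip install -r requirements.txt` followed by `python setup.py develop`, which installs the `gfpgan` package in editable mode. Note that the alignment and VGGFace2 scripts import packages that are not listed here, in particular `insightface` (plus an ONNX runtime for its models) and `realesrgan` for the optional background upsampler in `inference_gfpgan.py`; install those separately if you use the corresponding code paths.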
--------------------------------------------------------------------------------
/scripts/crop_align_vggface2_FFHQalign.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-16 15:34:14
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-19 16:14:01
7 | Description:
8 | '''
9 | from __future__ import absolute_import
10 | from __future__ import division
11 | from __future__ import print_function
12 | import os
13 | import argparse
14 | import cv2
15 | import glob
16 | from tqdm import tqdm
17 | from insightface_func.face_detect_crop_ffhq_newarcAlign import Face_detect_crop
19 |
20 | def align_image_dir(dir_name_tmp):
21 | ori_path_tmp = os.path.join(input_dir, dir_name_tmp)
22 | image_filenames = glob.glob(os.path.join(ori_path_tmp,'*'))
23 | save_dir_ffhqalign = os.path.join(output_dir_ffhqalign,dir_name_tmp)
24 | if not os.path.exists(save_dir_ffhqalign):
25 | os.makedirs(save_dir_ffhqalign)
26 |
27 |
28 | for file in image_filenames:
29 | image_file = os.path.basename(file)
30 |
31 | image_file_name_ffhqalign = os.path.join(save_dir_ffhqalign, image_file)
32 | if os.path.exists(image_file_name_ffhqalign):
33 | continue
34 |
35 | face_img = cv2.imread(file)
36 | if face_img.shape[0]<250 or face_img.shape[1]<250:
37 | continue
38 | ret = app.get(face_img,crop_size,mode=mode)
39 | if len(ret)!=0 :
40 | cv2.imwrite(image_file_name_ffhqalign, ret[0])
41 | else:
42 | continue
43 |
44 |
45 | if __name__ == "__main__":
46 | parser = argparse.ArgumentParser()
47 |
48 | parser.add_argument('--input_dir',type=str,default = '/Data/VGGface2/train')
49 | parser.add_argument('--output_dir_ffhqalign',type=str,default = '/Data/VGGface2_FFHQalign')
50 | parser.add_argument('--crop_size',type=int,default = 256)
51 | parser.add_argument('--mode',type=str,default = 'ffhq',choices=['ffhq','newarc','both'])
52 |
53 | args = parser.parse_args()
54 | input_dir = args.input_dir
55 | output_dir_ffhqalign = args.output_dir_ffhqalign
56 | crop_size = args.crop_size
57 | mode = args.mode
58 |
59 | app = Face_detect_crop(name='antelope', root='./insightface_func/models')
60 |
61 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(320,320))
62 |
63 | dirs = sorted(os.listdir(input_dir))
64 | handle_dir_list = dirs
65 | for handle_dir_list_tmp in tqdm(handle_dir_list):
66 | align_image_dir(handle_dir_list_tmp)
67 |
68 |
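A representative invocation is `python scripts/crop_align_vggface2_FFHQalign.py --input_dir /Data/VGGface2/train --output_dir_ffhqalign /Data/VGGface2_FFHQalign --crop_size 256 --mode ffhq` (paths are illustrative). The script expects the insightface 'antelope' models under `./insightface_func/models`, skips images smaller than 250 px on either side, and configures `align_image_dir` through module-level globals set in the `__main__` block, so it is meant to be run as a script rather than imported.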
--------------------------------------------------------------------------------
/scripts/crop_align_vggface2_FFHQalignandNewarcalign.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-15 19:42:42
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-19 16:17:54
7 | Description:
8 | '''
9 | from __future__ import absolute_import
10 | from __future__ import division
11 | from __future__ import print_function
12 | import os
13 | import argparse
14 | import cv2
15 | import glob
16 | from tqdm import tqdm
17 | from insightface_func.face_detect_crop_ffhq_newarcAlign import Face_detect_crop
19 |
20 | def align_image_dir(dir_name_tmp):
21 | ori_path_tmp = os.path.join(input_dir, dir_name_tmp)
22 | image_filenames = glob.glob(os.path.join(ori_path_tmp,'*'))
23 | save_dir_newarcalign = os.path.join(output_dir_newarcalign,dir_name_tmp)
24 | save_dir_ffhqalign = os.path.join(output_dir_ffhqalign,dir_name_tmp)
25 | if not os.path.exists(save_dir_newarcalign):
26 | os.makedirs(save_dir_newarcalign)
27 | if not os.path.exists(save_dir_ffhqalign):
28 | os.makedirs(save_dir_ffhqalign)
29 |
30 |
31 | for file in image_filenames:
32 | image_file = os.path.basename(file)
33 |
34 | image_file_name_newarcalign = os.path.join(save_dir_newarcalign, image_file)
35 | image_file_name_ffhqalign = os.path.join(save_dir_ffhqalign, image_file)
36 | if os.path.exists(image_file_name_newarcalign) and os.path.exists(image_file_name_ffhqalign):
37 | continue
38 |
39 | face_img = cv2.imread(file)
40 | if face_img.shape[0]<250 or face_img.shape[1]<250:
41 | continue
42 | ret = app.get(face_img,crop_size,mode=mode)
43 | if len(ret)!=0 :
44 | cv2.imwrite(image_file_name_ffhqalign, ret[0])
45 | cv2.imwrite(image_file_name_newarcalign, ret[1])
46 | else:
47 | continue
48 |
49 |
50 | if __name__ == "__main__":
51 | parser = argparse.ArgumentParser()
52 |
53 | parser.add_argument('--input_dir',type=str,default = '/home/gdp/harddisk/Data1/VGGface2/train')
54 | parser.add_argument('--output_dir_ffhqalign',type=str,default = '/home/gdp/harddisk/Data1/VGGface2_ffhq_align')
55 | parser.add_argument('--output_dir_newarcalign',type=str,default = '/home/gdp/harddisk/Data1/VGGface2_newarc_align')
56 | parser.add_argument('--crop_size',type=int,default = 256)
57 | parser.add_argument('--mode',type=str,default = 'Both')
58 |
59 | args = parser.parse_args()
60 | input_dir = args.input_dir
61 | output_dir_newarcalign = args.output_dir_newarcalign
62 | output_dir_ffhqalign = args.output_dir_ffhqalign
63 | crop_size = args.crop_size
64 | mode = args.mode
65 |
66 | app = Face_detect_crop(name='antelope', root='./insightface_func/models')
67 |
68 | app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(320,320))
69 |
70 | dirs = sorted(os.listdir(input_dir))
71 | handle_dir_list = dirs
72 | for handle_dir_list_tmp in tqdm(handle_dir_list):
73 | align_image_dir(handle_dir_list_tmp)
74 |
75 |
--------------------------------------------------------------------------------
/scripts/inference_gfpgan_forvggface2.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-16 19:30:52
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-20 17:57:55
7 | Description:
8 | '''
9 | import os
10 | import torch
11 | import argparse
12 |
13 | from tqdm import tqdm
14 | from vggface_dataset import getLoader
15 | from basicsr.utils import imwrite, tensor2img
16 | from gfpgan.archs.gfpganv1_arch import GFPGANv1
17 | from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
18 | import platform
19 |
20 | def main_worker(args):
21 | arch = args.arch
22 |
23 | with torch.no_grad():
24 | # initialize the GFP-GAN
25 | if arch == 'clean':
26 | gfpgan = GFPGANv1Clean(
27 | out_size=512,
28 | num_style_feat=512,
29 | channel_multiplier=args.channel,
30 | decoder_load_path=None,
31 | fix_decoder=False,
32 | num_mlp=8,
33 | input_is_latent=True,
34 | different_w=True,
35 | narrow=1,
36 | sft_half=True)
37 | else:
38 | gfpgan = GFPGANv1(
39 | out_size=512,
40 | num_style_feat=512,
41 | channel_multiplier=args.channel,
42 | decoder_load_path=None,
43 | fix_decoder=True,
44 | num_mlp=8,
45 | input_is_latent=True,
46 | different_w=True,
47 | narrow=1,
48 | sft_half=True)
49 |
50 | loadnet = torch.load(args.model_path)
51 | if 'params_ema' in loadnet:
52 | keyname = 'params_ema'
53 | else:
54 | keyname = 'params'
55 | gfpgan.load_state_dict(loadnet[keyname], strict=True)
56 | gfpgan.eval()
57 | gfpgan.cuda()
58 |
59 | test_dataloader = getLoader(args.input_path, 512, args.batchSize, 8)
60 |
61 | print(len(test_dataloader))
62 |
63 |
64 | for images,filenames in tqdm(test_dataloader):
65 | images = images.cuda()
66 |
67 | output_batch = gfpgan(images, return_rgb=False)[0]
68 |
69 | for tmp_index in range(len(output_batch)):
70 | tmp_filename = filenames[tmp_index]
71 |
72 | split_leave = tmp_filename.split(args.input_path)[-1].split(split_name)
73 | restored_face = output_batch[tmp_index]
74 | restored_face = tensor2img(restored_face, rgb2bgr=True, min_max=(-1, 1))
75 | restored_face = restored_face.astype('uint8')
76 |
77 | sub_dir = os.path.join(args.save_dir, split_leave[-2])
78 | os.makedirs(sub_dir, exist_ok=True)
79 |
80 | save_path_tmp = os.path.join(sub_dir, split_leave[-1])
81 |
82 | imwrite(restored_face, save_path_tmp)
83 |
84 |
85 | if __name__ == '__main__':
86 | parser = argparse.ArgumentParser()
87 |
88 | parser.add_argument('--arch', type=str, default='clean')
89 | parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
90 | parser.add_argument('--input_path', type=str, default='/Data/VGGface2_FFHQalign')
91 | parser.add_argument('--sft_half', default = False, action='store_true')
92 | parser.add_argument('--batchSize', type=int, default = 8)
93 | parser.add_argument('--save_dir', type=str, default = ' ')
94 | parser.add_argument('--channel', type=int, default=2)
95 |
96 | args = parser.parse_args()
97 |
98 | if platform.system().lower() == 'windows':
99 | split_name = '\\'
100 |     else:  # Linux, macOS and other POSIX systems
101 | split_name = '/'
102 | os.makedirs(args.save_dir, exist_ok=True)
103 | main_worker(args)
104 |
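A representative invocation is `python inference_gfpgan_forvggface2.py --input_path /Data/VGGface2_FFHQalign --save_dir /Data/VGGFace2-HQ --batchSize 8 --channel 2 --model_path experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth` (paths are illustrative). Because the script does `from vggface_dataset import getLoader`, launch it from the `scripts/` directory (or put `scripts/` on `PYTHONPATH`); it also assumes a CUDA device, since both the model and the batches are moved with `.cuda()`.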
--------------------------------------------------------------------------------
/scripts/vggface_dataset.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Naiyuan liu
3 | Github: https://github.com/NNNNAI
4 | Date: 2021-11-16 19:32:18
5 | LastEditors: Naiyuan liu
6 | LastEditTime: 2021-11-16 19:35:12
7 | Description:
8 | '''
9 | import os
10 | from PIL import Image
11 | from torch.utils import data
12 | from torchvision import transforms as T
13 | import glob
14 |
15 |
16 |
17 | class TotalDataset(data.Dataset):
18 |
19 | def __init__(self,image_dir,
20 | content_transform):
21 | self.image_dir = image_dir
22 |
23 | self.content_transform= content_transform
24 | self.dataset = []
25 | self.mean = [0.5, 0.5, 0.5]
26 | self.std = [0.5, 0.5, 0.5]
27 | self.preprocess()
28 | self.num_images = len(self.dataset)
29 |
30 | def preprocess(self):
31 | additional_pattern = '*/*'
32 | self.dataset.extend(sorted(glob.glob(os.path.join(self.image_dir, additional_pattern), recursive=False)))
33 |
34 | print('Finished preprocessing the VGGFACE2 dataset...')
35 |
36 |
37 | def __getitem__(self, index):
38 | """Return single image."""
39 | dataset = self.dataset
40 |
41 | src_filename1 = dataset[index]
42 |
43 | src_image1 = self.content_transform(Image.open(src_filename1))
44 |
45 |
46 | return src_image1, src_filename1
47 |
48 |
49 | def __len__(self):
50 | """Return the number of images."""
51 | return self.num_images
52 |
53 | def getLoader(c_image_dir, ResizeSize=512, batch_size=16, num_workers=8):
54 | """Build and return a data loader."""
55 | c_transforms = []
56 |
57 |
58 | c_transforms.append(T.Resize([ResizeSize,ResizeSize]))
59 | c_transforms.append(T.ToTensor())
60 | c_transforms.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
61 |
62 | c_transforms = T.Compose(c_transforms)
63 |
64 | content_dataset = TotalDataset(c_image_dir, c_transforms)
65 |
66 |
67 | sampler = None
68 | content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
69 | drop_last=False,num_workers=num_workers,sampler=sampler,pin_memory=True)
70 | return content_data_loader
71 |
72 | def denorm(x):
73 | out = (x + 1) / 2
74 | return out.clamp_(0, 1)
75 |
76 |
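A short sketch of how `getLoader` and `denorm` might be exercised on their own; the dataset root is illustrative and follows the identity-folder layout expected by the `'*/*'` glob above:

```python
from torchvision.utils import save_image

from vggface_dataset import denorm, getLoader  # assumes scripts/ is the working directory

# Hypothetical dataset root laid out as <root>/<identity>/<image>.jpg.
loader = getLoader('/Data/VGGface2_FFHQalign', ResizeSize=512, batch_size=4, num_workers=2)

images, filenames = next(iter(loader))  # images are normalized to [-1, 1]
save_image(denorm(images), 'batch_preview.png', nrow=2)  # denorm maps back to [0, 1] for saving
```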
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore =
3 | # line break before binary operator (W503)
4 | W503,
5 | # line break after binary operator (W504)
6 | W504,
7 | max-line-length=120
8 |
9 | [yapf]
10 | based_on_style = pep8
11 | column_limit = 120
12 | blank_line_before_nested_class_or_def = true
13 | split_before_expression_after_opening_paren = true
14 |
15 | [isort]
16 | line_length = 120
17 | multi_line_output = 0
18 | known_standard_library = pkg_resources,setuptools
19 | known_first_party = gfpgan
20 | known_third_party = basicsr,cv2,facexlib,numpy,torch,torchvision,tqdm
21 | no_lines_before = STDLIB,LOCALFOLDER
22 | default_section = THIRDPARTY
23 |
24 | [codespell]
25 | skip = .git,./docs/build
26 | count =
27 | quiet-level = 3
28 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import find_packages, setup
4 |
5 | import os
6 | import subprocess
7 | import time
8 |
9 | version_file = 'gfpgan/version.py'
10 |
11 |
12 | def readme():
13 | with open('README.md', encoding='utf-8') as f:
14 | content = f.read()
15 | return content
16 |
17 |
18 | def get_git_hash():
19 |
20 | def _minimal_ext_cmd(cmd):
21 | # construct minimal environment
22 | env = {}
23 | for k in ['SYSTEMROOT', 'PATH', 'HOME']:
24 | v = os.environ.get(k)
25 | if v is not None:
26 | env[k] = v
27 | # LANGUAGE is used on win32
28 | env['LANGUAGE'] = 'C'
29 | env['LANG'] = 'C'
30 | env['LC_ALL'] = 'C'
31 | out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
32 | return out
33 |
34 | try:
35 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
36 | sha = out.strip().decode('ascii')
37 | except OSError:
38 | sha = 'unknown'
39 |
40 | return sha
41 |
42 |
43 | def get_hash():
44 | if os.path.exists('.git'):
45 | sha = get_git_hash()[:7]
46 | else:
47 | sha = 'unknown'
48 |
49 | return sha
50 |
51 |
52 | def write_version_py():
53 | content = """# GENERATED VERSION FILE
54 | # TIME: {}
55 | __version__ = '{}'
56 | __gitsha__ = '{}'
57 | version_info = ({})
58 | """
59 | sha = get_hash()
60 | with open('VERSION', 'r') as f:
61 | SHORT_VERSION = f.read().strip()
62 | VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
63 |
64 | version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
65 | with open(version_file, 'w') as f:
66 | f.write(version_file_str)
67 |
68 |
69 | def get_version():
70 | with open(version_file, 'r') as f:
71 | exec(compile(f.read(), version_file, 'exec'))
72 | return locals()['__version__']
73 |
74 |
75 | def get_requirements(filename='requirements.txt'):
76 | here = os.path.dirname(os.path.realpath(__file__))
77 | with open(os.path.join(here, filename), 'r') as f:
78 | requires = [line.replace('\n', '') for line in f.readlines()]
79 | return requires
80 |
81 |
82 | if __name__ == '__main__':
83 | write_version_py()
84 | setup(
85 | name='gfpgan',
86 | version=get_version(),
87 | description='GFPGAN aims at developing Practical Algorithms for Real-world Face Restoration',
88 | long_description=readme(),
89 | long_description_content_type='text/markdown',
90 | author='Xintao Wang',
91 | author_email='xintao.wang@outlook.com',
92 | keywords='computer vision, pytorch, image restoration, super-resolution, face restoration, gan, gfpgan',
93 | url='https://github.com/TencentARC/GFPGAN',
94 | include_package_data=True,
95 | packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
96 | classifiers=[
97 | 'Development Status :: 4 - Beta',
98 | 'License :: OSI Approved :: Apache Software License',
99 | 'Operating System :: OS Independent',
100 | 'Programming Language :: Python :: 3',
101 | 'Programming Language :: Python :: 3.7',
102 | 'Programming Language :: Python :: 3.8',
103 | ],
104 | license='Apache License Version 2.0',
105 | setup_requires=['cython', 'numpy'],
106 | install_requires=get_requirements(),
107 | zip_safe=False)
108 |
--------------------------------------------------------------------------------