├── .gitignore
├── LICENSE
├── README.md
├── config
│   └── AudioConfig.py
├── data
│   ├── __init__.py
│   ├── base_dataset.py
│   └── voxtest_dataset.py
├── experiments
│   └── demo_vox.sh
├── inference.py
├── misc
│   ├── Audio_Source.zip
│   ├── Input.zip
│   ├── Mouth_Source.zip
│   ├── Pose_Source.zip
│   ├── demo.csv
│   ├── demo.gif
│   ├── demo_id.gif
│   ├── method.png
│   └── output.gif
├── models
│   ├── __init__.py
│   ├── av_model.py
│   └── networks
│       ├── FAN_feature_extractor.py
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── FAN_feature_extractor.cpython-36.pyc
│       │   ├── FAN_feature_extractor.cpython-37.pyc
│       │   ├── Voxceleb_model.cpython-36.pyc
│       │   ├── Voxceleb_model.cpython-37.pyc
│       │   ├── __init__.cpython-36.pyc
│       │   ├── __init__.cpython-37.pyc
│       │   ├── architecture.cpython-36.pyc
│       │   ├── architecture.cpython-37.pyc
│       │   ├── audio_architecture.cpython-36.pyc
│       │   ├── audio_architecture.cpython-37.pyc
│       │   ├── audio_network.cpython-37.pyc
│       │   ├── base_network.cpython-36.pyc
│       │   ├── base_network.cpython-37.pyc
│       │   ├── discriminator.cpython-36.pyc
│       │   ├── discriminator.cpython-37.pyc
│       │   ├── encoder.cpython-36.pyc
│       │   ├── encoder.cpython-37.pyc
│       │   ├── generator.cpython-36.pyc
│       │   ├── generator.cpython-37.pyc
│       │   ├── loss.cpython-36.pyc
│       │   ├── loss.cpython-37.pyc
│       │   ├── normalization.cpython-36.pyc
│       │   ├── normalization.cpython-37.pyc
│       │   ├── stylegan2.cpython-36.pyc
│       │   ├── stylegan2.cpython-37.pyc
│       │   ├── vision_network.cpython-36.pyc
│       │   └── vision_network.cpython-37.pyc
│       ├── architecture.py
│       ├── audio_network.py
│       ├── base_network.py
│       ├── discriminator.py
│       ├── encoder.py
│       ├── generator.py
│       ├── loss.py
│       ├── sync_batchnorm
│       │   ├── __init__.py
│       │   ├── __pycache__
│       │   │   ├── __init__.cpython-36.pyc
│       │   │   ├── __init__.cpython-37.pyc
│       │   │   ├── batchnorm.cpython-36.pyc
│       │   │   ├── batchnorm.cpython-37.pyc
│       │   │   ├── comm.cpython-36.pyc
│       │   │   ├── comm.cpython-37.pyc
│       │   │   ├── replicate.cpython-36.pyc
│       │   │   ├── replicate.cpython-37.pyc
│       │   │   ├── scatter_gather.cpython-36.pyc
│       │   │   └── scatter_gather.cpython-37.pyc
│       │   ├── batchnorm.py
│       │   ├── batchnorm_reimpl.py
│       │   ├── comm.py
│       │   ├── replicate.py
│       │   ├── scatter_gather.py
│       │   └── unittest.py
│       ├── util.py
│       └── vision_network.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── requirements.txt
├── scripts
│   ├── align_68.py
│   └── prepare_testing_files.py
└── util
    ├── __init__.py
    ├── html.py
    ├── iter_counter.py
    ├── util.py
    └── visualizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | **/*.pyc
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
107 | # custom
108 | demo/
109 | checkpoints/
110 | results/
111 | .vscode
112 | .idea
113 | *.pkl
114 | *.pkl.json
115 | *.log.json
116 | work_dirs/
117 | *.avi
118 |
119 | # Pytorch
120 | *.pth
121 | *.tar
122 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More_considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution 4.0 International Public License
58 |
59 | By exercising the Licensed Rights (defined below), You accept and agree
60 | to be bound by the terms and conditions of this Creative Commons
61 | Attribution 4.0 International Public License ("Public License"). To the
62 | extent this Public License may be interpreted as a contract, You are
63 | granted the Licensed Rights in consideration of Your acceptance of
64 | these terms and conditions, and the Licensor grants You such rights in
65 | consideration of benefits the Licensor receives from making the
66 | Licensed Material available under these terms and conditions.
67 |
68 | Section 1 -- Definitions.
69 |
70 | a. Adapted Material means material subject to Copyright and Similar
71 | Rights that is derived from or based upon the Licensed Material
72 | and in which the Licensed Material is translated, altered,
73 | arranged, transformed, or otherwise modified in a manner requiring
74 | permission under the Copyright and Similar Rights held by the
75 | Licensor. For purposes of this Public License, where the Licensed
76 | Material is a musical work, performance, or sound recording,
77 | Adapted Material is always produced where the Licensed Material is
78 | synched in timed relation with a moving image.
79 |
80 | b. Adapter's License means the license You apply to Your Copyright
81 | and Similar Rights in Your contributions to Adapted Material in
82 | accordance with the terms and conditions of this Public License.
83 |
84 | c. Copyright and Similar Rights means copyright and/or similar rights
85 | closely related to copyright including, without limitation,
86 | performance, broadcast, sound recording, and Sui Generis Database
87 | Rights, without regard to how the rights are labeled or
88 | categorized. For purposes of this Public License, the rights
89 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
90 | Rights.
91 |
92 | d. Effective Technological Measures means those measures that, in the
93 | absence of proper authority, may not be circumvented under laws
94 | fulfilling obligations under Article 11 of the WIPO Copyright
95 | Treaty adopted on December 20, 1996, and/or similar international
96 | agreements.
97 |
98 | e. Exceptions and Limitations means fair use, fair dealing, and/or
99 | any other exception or limitation to Copyright and Similar Rights
100 | that applies to Your use of the Licensed Material.
101 |
102 | f. Licensed Material means the artistic or literary work, database,
103 | or other material to which the Licensor applied this Public
104 | License.
105 |
106 | g. Licensed Rights means the rights granted to You subject to the
107 | terms and conditions of this Public License, which are limited to
108 | all Copyright and Similar Rights that apply to Your use of the
109 | Licensed Material and that the Licensor has authority to license.
110 |
111 | h. Licensor means the individual(s) or entity(ies) granting rights
112 | under this Public License.
113 |
114 | i. Share means to provide material to the public by any means or
115 | process that requires permission under the Licensed Rights, such
116 | as reproduction, public display, public performance, distribution,
117 | dissemination, communication, or importation, and to make material
118 | available to the public including in ways that members of the
119 | public may access the material from a place and at a time
120 | individually chosen by them.
121 |
122 | j. Sui Generis Database Rights means rights other than copyright
123 | resulting from Directive 96/9/EC of the European Parliament and of
124 | the Council of 11 March 1996 on the legal protection of databases,
125 | as amended and/or succeeded, as well as other essentially
126 | equivalent rights anywhere in the world.
127 |
128 | k. You means the individual or entity exercising the Licensed Rights
129 | under this Public License. Your has a corresponding meaning.
130 |
131 | Section 2 -- Scope.
132 |
133 | a. License grant.
134 |
135 | 1. Subject to the terms and conditions of this Public License,
136 | the Licensor hereby grants You a worldwide, royalty-free,
137 | non-sublicensable, non-exclusive, irrevocable license to
138 | exercise the Licensed Rights in the Licensed Material to:
139 |
140 | a. reproduce and Share the Licensed Material, in whole or
141 | in part; and
142 |
143 | b. produce, reproduce, and Share Adapted Material.
144 |
145 | 2. Exceptions and Limitations. For the avoidance of doubt, where
146 | Exceptions and Limitations apply to Your use, this Public
147 | License does not apply, and You do not need to comply with
148 | its terms and conditions.
149 |
150 | 3. Term. The term of this Public License is specified in Section
151 | 6(a).
152 |
153 | 4. Media and formats; technical modifications allowed. The
154 | Licensor authorizes You to exercise the Licensed Rights in
155 | all media and formats whether now known or hereafter created,
156 | and to make technical modifications necessary to do so. The
157 | Licensor waives and/or agrees not to assert any right or
158 | authority to forbid You from making technical modifications
159 | necessary to exercise the Licensed Rights, including
160 | technical modifications necessary to circumvent Effective
161 | Technological Measures. For purposes of this Public License,
162 | simply making modifications authorized by this Section 2(a)
163 | (4) never produces Adapted Material.
164 |
165 | 5. Downstream recipients.
166 |
167 | a. Offer from the Licensor -- Licensed Material. Every
168 | recipient of the Licensed Material automatically
169 | receives an offer from the Licensor to exercise the
170 | Licensed Rights under the terms and conditions of this
171 | Public License.
172 |
173 | b. No downstream restrictions. You may not offer or impose
174 | any additional or different terms or conditions on, or
175 | apply any Effective Technological Measures to, the
176 | Licensed Material if doing so restricts exercise of the
177 | Licensed Rights by any recipient of the Licensed
178 | Material.
179 |
180 | 6. No endorsement. Nothing in this Public License constitutes or
181 | may be construed as permission to assert or imply that You
182 | are, or that Your use of the Licensed Material is, connected
183 | with, or sponsored, endorsed, or granted official status by,
184 | the Licensor or others designated to receive attribution as
185 | provided in Section 3(a)(1)(A)(i).
186 |
187 | b. Other rights.
188 |
189 | 1. Moral rights, such as the right of integrity, are not
190 | licensed under this Public License, nor are publicity,
191 | privacy, and/or other similar personality rights; however, to
192 | the extent possible, the Licensor waives and/or agrees not to
193 | assert any such rights held by the Licensor to the limited
194 | extent necessary to allow You to exercise the Licensed
195 | Rights, but not otherwise.
196 |
197 | 2. Patent and trademark rights are not licensed under this
198 | Public License.
199 |
200 | 3. To the extent possible, the Licensor waives any right to
201 | collect royalties from You for the exercise of the Licensed
202 | Rights, whether directly or through a collecting society
203 | under any voluntary or waivable statutory or compulsory
204 | licensing scheme. In all other cases the Licensor expressly
205 | reserves any right to collect such royalties.
206 |
207 | Section 3 -- License Conditions.
208 |
209 | Your exercise of the Licensed Rights is expressly made subject to the
210 | following conditions.
211 |
212 | a. Attribution.
213 |
214 | 1. If You Share the Licensed Material (including in modified
215 | form), You must:
216 |
217 | a. retain the following if it is supplied by the Licensor
218 | with the Licensed Material:
219 |
220 | i. identification of the creator(s) of the Licensed
221 | Material and any others designated to receive
222 | attribution, in any reasonable manner requested by
223 | the Licensor (including by pseudonym if
224 | designated);
225 |
226 | ii. a copyright notice;
227 |
228 | iii. a notice that refers to this Public License;
229 |
230 | iv. a notice that refers to the disclaimer of
231 | warranties;
232 |
233 | v. a URI or hyperlink to the Licensed Material to the
234 | extent reasonably practicable;
235 |
236 | b. indicate if You modified the Licensed Material and
237 | retain an indication of any previous modifications; and
238 |
239 | c. indicate the Licensed Material is licensed under this
240 | Public License, and include the text of, or the URI or
241 | hyperlink to, this Public License.
242 |
243 | 2. You may satisfy the conditions in Section 3(a)(1) in any
244 | reasonable manner based on the medium, means, and context in
245 | which You Share the Licensed Material. For example, it may be
246 | reasonable to satisfy the conditions by providing a URI or
247 | hyperlink to a resource that includes the required
248 | information.
249 |
250 | 3. If requested by the Licensor, You must remove any of the
251 | information required by Section 3(a)(1)(A) to the extent
252 | reasonably practicable.
253 |
254 | 4. If You Share Adapted Material You produce, the Adapter's
255 | License You apply must not prevent recipients of the Adapted
256 | Material from complying with this Public License.
257 |
258 | Section 4 -- Sui Generis Database Rights.
259 |
260 | Where the Licensed Rights include Sui Generis Database Rights that
261 | apply to Your use of the Licensed Material:
262 |
263 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
264 | to extract, reuse, reproduce, and Share all or a substantial
265 | portion of the contents of the database;
266 |
267 | b. if You include all or a substantial portion of the database
268 | contents in a database in which You have Sui Generis Database
269 | Rights, then the database in which You have Sui Generis Database
270 | Rights (but not its individual contents) is Adapted Material; and
271 |
272 | c. You must comply with the conditions in Section 3(a) if You Share
273 | all or a substantial portion of the contents of the database.
274 |
275 | For the avoidance of doubt, this Section 4 supplements and does not
276 | replace Your obligations under this Public License where the Licensed
277 | Rights include other Copyright and Similar Rights.
278 |
279 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
280 |
281 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
282 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
283 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
284 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
285 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
286 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
287 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
288 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
289 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
290 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
291 |
292 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
293 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
294 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
295 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
296 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
297 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
298 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
299 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
300 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
301 |
302 | c. The disclaimer of warranties and limitation of liability provided
303 | above shall be interpreted in a manner that, to the extent
304 | possible, most closely approximates an absolute disclaimer and
305 | waiver of all liability.
306 |
307 | Section 6 -- Term and Termination.
308 |
309 | a. This Public License applies for the term of the Copyright and
310 | Similar Rights licensed here. However, if You fail to comply with
311 | this Public License, then Your rights under this Public License
312 | terminate automatically.
313 |
314 | b. Where Your right to use the Licensed Material has terminated under
315 | Section 6(a), it reinstates:
316 |
317 | 1. automatically as of the date the violation is cured, provided
318 | it is cured within 30 days of Your discovery of the
319 | violation; or
320 |
321 | 2. upon express reinstatement by the Licensor.
322 |
323 | For the avoidance of doubt, this Section 6(b) does not affect any
324 | right the Licensor may have to seek remedies for Your violations
325 | of this Public License.
326 |
327 | c. For the avoidance of doubt, the Licensor may also offer the
328 | Licensed Material under separate terms or conditions or stop
329 | distributing the Licensed Material at any time; however, doing so
330 | will not terminate this Public License.
331 |
332 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
333 | License.
334 |
335 | Section 7 -- Other Terms and Conditions.
336 |
337 | a. The Licensor shall not be bound by any additional or different
338 | terms or conditions communicated by You unless expressly agreed.
339 |
340 | b. Any arrangements, understandings, or agreements regarding the
341 | Licensed Material not stated herein are separate from and
342 | independent of the terms and conditions of this Public License.
343 |
344 | Section 8 -- Interpretation.
345 |
346 | a. For the avoidance of doubt, this Public License does not, and
347 | shall not be interpreted to, reduce, limit, restrict, or impose
348 | conditions on any use of the Licensed Material that could lawfully
349 | be made without permission under this Public License.
350 |
351 | b. To the extent possible, if any provision of this Public License is
352 | deemed unenforceable, it shall be automatically reformed to the
353 | minimum extent necessary to make it enforceable. If the provision
354 | cannot be reformed, it shall be severed from this Public License
355 | without affecting the enforceability of the remaining terms and
356 | conditions.
357 |
358 | c. No term or condition of this Public License will be waived and no
359 | failure to comply consented to unless expressly agreed to by the
360 | Licensor.
361 |
362 | d. Nothing in this Public License constitutes or may be interpreted
363 | as a limitation upon, or waiver of, any privileges and immunities
364 | that apply to the Licensor or You, including from the legal
365 | processes of any jurisdiction or authority.
366 |
367 | =======================================================================
368 |
369 | Creative Commons is not a party to its public licenses.
370 | Notwithstanding, Creative Commons may elect to apply one of its public
371 | licenses to material it publishes and in those instances will be
372 | considered the "Licensor." Except for the limited purpose of indicating
373 | that material is shared under a Creative Commons public license or as
374 | otherwise permitted by the Creative Commons policies published at
375 | creativecommons.org/policies, Creative Commons does not authorize the
376 | use of the trademark "Creative Commons" or any other trademark or logo
377 | of Creative Commons without its prior written consent including,
378 | without limitation, in connection with any unauthorized modifications
379 | to any of its public licenses or any other arrangements,
380 | understandings, or agreements concerning use of licensed material. For
381 | the avoidance of doubt, this paragraph does not form part of the public
382 | licenses.
383 |
384 | Creative Commons may be contacted at creativecommons.org.
385 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Pose-Controllable Talking Face Generation by Implicitly Modularized Audio-Visual Representation (CVPR 2021)
2 |
3 | [Hang Zhou](https://hangz-nju-cuhk.github.io/), Yasheng Sun, [Wayne Wu](https://wywu.github.io/), [Chen Change Loy](http://personal.ie.cuhk.edu.hk/~ccloy/), [Xiaogang Wang](http://www.ee.cuhk.edu.hk/~xgwang/), and [Ziwei Liu](https://liuziwei7.github.io/).
4 |
5 |
6 |
7 |
8 | ### [Project](https://hangz-nju-cuhk.github.io/projects/PC-AVS) | [Paper](https://arxiv.org/abs/2104.11116) | [Demo](https://www.youtube.com/watch?v=lNQQHIggnUg)
9 |
10 |
11 | We propose **Pose-Controllable Audio-Visual System (PC-AVS)**,
12 | which achieves free pose control when driving arbitrary talking faces with audio. Instead of learning pose motions from audio, we leverage another pose source video to compensate only for head motions.
13 | The key is to devise an implicit low-dimension pose code that is free of mouth shape or identity information.
14 | In this way, audio-visual representations are modularized into spaces of three key factors: speech content, head pose, and identity information.
15 |
16 |
17 |
18 | ## Requirements
19 | * Python 3.6 and [PyTorch](https://pytorch.org/) 1.3.0 are used. Basic requirements are listed in ```requirements.txt```.
20 |
21 | ```
22 | pip install -r requirements.txt
23 | ```
24 |
25 |
26 | ## Quick Start: Generate Demo Results
27 | * Download the pre-trained [checkpoints](https://drive.google.com/file/d/1Zehr3JLIpzdg2S5zZrhIbpYPKF-4gKU_/view?usp=sharing).
28 |
29 | * Create the default folder ```./checkpoints``` and
30 | unzip ```demo.zip``` into ```./checkpoints/demo```. It should contain 5 ```.pth``` files (see the combined command sketch after this list).
31 |
32 | * Unzip all ```*.zip``` files within the ```misc``` folder.
33 |
34 | * Run the demo script:
35 | ``` bash
36 | bash experiments/demo_vox.sh
37 | ```
38 |
39 |
40 | * The ```--gen_video``` argument is on by default;
41 | [ffmpeg](https://www.ffmpeg.org/) >= 4.0.0 is required to use this flag on Linux systems.
42 | All frames, along with an ```avconcat.mp4``` video file, will be saved in the ```./results/id_517600055_pose_517600078_audio_681600002``` folder.
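
A minimal sketch of the quick-start setup above, assuming ```demo.zip``` has been downloaded to the repository root (the paths are illustrative):

``` bash
# Unpack the pre-trained checkpoints; expect 5 .pth files in ./checkpoints/demo
# (use -d ./checkpoints instead if the archive already contains a demo/ folder).
mkdir -p ./checkpoints/demo
unzip demo.zip -d ./checkpoints/demo

# Unpack the demo sources shipped in ./misc
for z in misc/*.zip; do unzip -o "$z" -d misc/; done

# Run the demo
bash experiments/demo_vox.sh
```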
43 |
44 |
45 |
46 |
47 | From left to right are the *reference input*, the *generated results*,
48 | the *pose source video* and the *synced original video* with the driving audio.
49 |
50 | ## Prepare Testing Meta Data
51 |
52 | * ### Automatic VoxCeleb2 Data Formulation
53 |
54 | The inference script ```experiments/demo_vox.sh``` refers to ```./misc/demo.csv``` for testing data paths.
55 | On Linux systems, any applicable ```csv``` file can be created automatically by running:
56 |
57 | ```bash
58 | python scripts/prepare_testing_files.py
59 | ```
60 |
61 | Then modify the ```meta_path_vox``` in ```experiments/demo_vox.sh``` to ```'./misc/demo2.csv'``` and run
62 |
63 | ``` bash
64 | bash experiments/demo_vox.sh
65 | ```
66 | An additional result should then be saved.
67 |
68 | * ### Metadata Details
69 |
70 | In detail, ```scripts/prepare_testing_files.py``` provides several flags that offer great flexibility when formulating the metadata:
71 |
72 | 1. ```--src_pose_path``` denotes the driving pose source path.
73 | It can be an ```mp4``` file or a folder containing frames in the form of ```%06d.jpg``` starting from 0.
74 |
75 | 2. ```--src_audio_path``` denotes the audio source's path.
76 | It can be an ```mp3``` audio file or an ```mp4``` video file. If a video is given,
77 | its frames will be automatically saved in ```./misc/Mouth_Source/video_name```, and the ```--src_mouth_frame_path``` flag will be disabled.
78 |
79 | 3. ```--src_mouth_frame_path```. When ```--src_audio_path``` is not a video path,
80 | this flag can provide the folder containing the video frames synced with the source audio.
81 |
82 | 4. ```--src_input_path``` is the path to the input reference image. When the path is a video file, we will convert it to frames.
83 |
85 | 5. ```--csv_path``` denotes the path where the metadata will be saved.
85 |
87 | You can manually modify the metadata ```csv``` file, or add lines to it, following the rules defined in ```scripts/prepare_testing_files.py``` or the dataloader ```data/voxtest_dataset.py```.
87 |
88 | We provide a number of demo choices in the ```misc``` folder, including several used in our [video](https://www.youtube.com/watch?v=lNQQHIggnUg).
89 | Feel free to rearrange them, even across folders, and you are welcome to record audio files yourself.
90 |
91 | * ### Self-Prepared Data Processing
92 | Our model handles only **VoxCeleb2-like** cropped data, so pre-processing is needed for self-prepared data.
93 |
94 | To process self-prepared data, [face-alignment](https://github.com/1adrianb/face-alignment) is needed. It can be installed by running
95 | ```
96 | pip install face-alignment
97 | ```
98 |
99 | Assuming that a video has already been processed into a ```[name]``` folder by the previous steps of ```prepare_testing_files.py```,
100 | you can run
101 | ```
102 | python scripts/align_68.py --folder_path [name]
103 | ```
104 |
105 | The cropped images will be saved in an additional ```[name_cropped]``` folder.
106 | You can then manually edit the ```demo.csv``` file, or change the folder path and run the preprocessing script again (see the example line below).
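For instance, a hypothetical metadata line that uses the cropped folder as the identity reference would follow the same eight-field format as ```misc/demo.csv``` (the ```[name_cropped]``` path is a placeholder for wherever ```align_68.py``` wrote the cropped frames, not a file shipped with the repo):

```
misc/Input/[name_cropped] 1 misc/Pose_Source/517600078 160 misc/Audio_Source/681600002.mp3 misc/Mouth_Source/681600002 363 dummy
```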
107 |
108 | ## Have More Fun
109 | * We also support synthesizing and driving a talking head solely from audio.
110 | * Download the pre-trained [checkpoints](https://drive.google.com/file/d/1K-UZYRn7Oz2VEumrIRMCvr69FhQjQNx8/view?usp=share_link).
111 | * Check out the *speech2talkingface* branch.
112 | ``` bash
113 | git checkout speech2talkingface
114 | ```
115 | * Follow similar steps as in the quick start and run the demo script.
116 | ``` bash
117 | bash experiments/demo_vox.sh
118 | ```
119 |
120 |
121 |
122 | From left to right are the *generated results*,
123 | the *pose source video* and the *synced original video* with the driving audio.
124 |
125 |
126 | ## Train Your Own Model
127 | * Not supported yet.
128 |
129 | ## License and Citation
130 |
131 | This software is released under the [CC-BY-4.0](https://github.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/LICENSE) license.
132 | ```
133 | @InProceedings{zhou2021pose,
134 | author = {Zhou, Hang and Sun, Yasheng and Wu, Wayne and Loy, Chen Change and Wang, Xiaogang and Liu, Ziwei},
135 | title = {Pose-Controllable Talking Face Generation by Implicitly Modularized Audio-Visual Representation},
136 | booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
137 | year = {2021}
138 | }
139 |
140 | @inproceedings{sun2021speech2talking,
141 | title={Speech2Talking-Face: Inferring and Driving a Face with Synchronized Audio-Visual Representation.},
142 | author={Sun, Yasheng and Zhou, Hang and Liu, Ziwei and Koike, Hideki},
143 | booktitle={IJCAI},
144 | volume={2},
145 | pages={4},
146 | year={2021}
147 | }
148 | ```
149 |
150 | ## Acknowledgement
151 | * The structure of this codebase is borrowed from [SPADE](https://github.com/NVlabs/SPADE).
152 | * The generator is borrowed from [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch).
153 | * The audio encoder is borrowed from [voxceleb_trainer](https://github.com/clovaai/voxceleb_trainer).
154 |
--------------------------------------------------------------------------------
/config/AudioConfig.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import librosa.filters
3 | import numpy as np
4 | from scipy import signal
5 | from scipy.io import wavfile
6 | import lws
7 |
8 |
9 | class AudioConfig:
10 | def __init__(self, frame_rate=25,
11 | sample_rate=16000,
12 | num_mels=80,
13 | fft_size=1280,
14 | hop_size=160,
15 | num_frames_per_clip=5,
16 | save_mel=True
17 | ):
18 | self.frame_rate = frame_rate
19 | self.sample_rate = sample_rate
20 | self.num_bins_per_frame = int(sample_rate / hop_size / frame_rate)
21 | self.num_frames_per_clip = num_frames_per_clip
22 | self.silence_threshold = 2
23 | self.num_mels = num_mels
24 | self.save_mel = save_mel
25 | self.fmin = 125
26 | self.fmax = 7600
27 | self.fft_size = fft_size
28 | self.hop_size = hop_size
29 | self.frame_shift_ms = None
30 | self.min_level_db = -100
31 | self.ref_level_db = 20
32 | self.rescaling = True
33 | self.rescaling_max = 0.999
34 | self.allow_clipping_in_normalization = True
35 | self.log_scale_min = -32.23619130191664
36 | self.norm_audio = True
37 | self.with_phase = False
38 |
39 | def load_wav(self, path):
40 | return librosa.core.load(path, sr=self.sample_rate)[0]
41 |
42 | def audio_normalize(self, samples, desired_rms=0.1, eps=1e-4):
43 | rms = np.maximum(eps, np.sqrt(np.mean(samples ** 2)))
44 | samples = samples * (desired_rms / rms)
45 | return samples
46 |
47 | def generate_spectrogram_magphase(self, audio):
48 | spectro = librosa.core.stft(audio, hop_length=self.get_hop_size(), n_fft=self.fft_size, center=True)
49 | spectro_mag, spectro_phase = librosa.core.magphase(spectro)
50 | spectro_mag = np.expand_dims(spectro_mag, axis=0)
51 | if self.with_phase:
52 | spectro_phase = np.expand_dims(np.angle(spectro_phase), axis=0)
53 | return spectro_mag, spectro_phase
54 | else:
55 | return spectro_mag
56 |
57 | def save_wav(self, wav, path):
58 | wav *= 32767 / max(0.01, np.max(np.abs(wav)))
59 | wavfile.write(path, self.sample_rate, wav.astype(np.int16))
60 |
61 | def trim(self, quantized):
62 | start, end = self.start_and_end_indices(quantized, self.silence_threshold)
63 | return quantized[start:end]
64 |
65 | def adjust_time_resolution(self, quantized, mel):
66 | """Adjust time resolution by repeating features
67 |
68 | Args:
69 | quantized (ndarray): (T,)
70 | mel (ndarray): (N, D)
71 |
72 | Returns:
73 | tuple: Tuple of (T,) and (T, D)
74 | """
75 | assert len(quantized.shape) == 1
76 | assert len(mel.shape) == 2
77 |
78 | upsample_factor = quantized.size // mel.shape[0]
79 | mel = np.repeat(mel, upsample_factor, axis=0)
80 | n_pad = quantized.size - mel.shape[0]
81 | if n_pad != 0:
82 | assert n_pad > 0
83 | mel = np.pad(mel, [(0, n_pad), (0, 0)], mode="constant", constant_values=0)
84 |
85 | # trim
86 | start, end = self.start_and_end_indices(quantized, self.silence_threshold)
87 |
88 | return quantized[start:end], mel[start:end, :]
89 |
90 | adjast_time_resolution = adjust_time_resolution # 'adjust' is correct spelling, this is for compatibility
91 |
92 | def start_and_end_indices(self, quantized, silence_threshold=2):
93 | for start in range(quantized.size):
94 | if abs(quantized[start] - 127) > silence_threshold:
95 | break
96 | for end in range(quantized.size - 1, 1, -1):
97 | if abs(quantized[end] - 127) > silence_threshold:
98 | break
99 |
100 | assert abs(quantized[start] - 127) > silence_threshold
101 | assert abs(quantized[end] - 127) > silence_threshold
102 |
103 | return start, end
104 |
105 | def melspectrogram(self, y):
106 | D = self._lws_processor().stft(y).T
107 | S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
108 | if not self.allow_clipping_in_normalization:
109 | assert S.max() <= 0 and S.min() - self.min_level_db >= 0
110 | return self._normalize(S)
111 |
112 | def get_hop_size(self):
113 | hop_size = self.hop_size
114 | if hop_size is None:
115 | assert self.frame_shift_ms is not None
116 | hop_size = int(self.frame_shift_ms / 1000 * self.sample_rate)
117 | return hop_size
118 |
119 | def _lws_processor(self):
120 | return lws.lws(self.fft_size, self.get_hop_size(), mode="speech")
121 |
122 | def lws_num_frames(self, length, fsize, fshift):
123 | """Compute number of time frames of lws spectrogram
124 | """
125 | pad = (fsize - fshift)
126 | if length % fshift == 0:
127 | M = (length + pad * 2 - fsize) // fshift + 1
128 | else:
129 | M = (length + pad * 2 - fsize) // fshift + 2
130 | return M
131 |
132 | def lws_pad_lr(self, x, fsize, fshift):
133 | """Compute left and right padding lws internally uses
134 | """
135 | M = self.lws_num_frames(len(x), fsize, fshift)
136 | pad = (fsize - fshift)
137 | T = len(x) + 2 * pad
138 | r = (M - 1) * fshift + fsize - T
139 | return pad, pad + r
140 |
141 |
142 | def _linear_to_mel(self, spectrogram):
143 | global _mel_basis
144 | _mel_basis = self._build_mel_basis()
145 | return np.dot(_mel_basis, spectrogram)
146 |
147 | def _build_mel_basis(self):
148 | assert self.fmax <= self.sample_rate // 2
149 | return librosa.filters.mel(self.sample_rate, self.fft_size,
150 | fmin=self.fmin, fmax=self.fmax,
151 | n_mels=self.num_mels)
152 |
153 | def _amp_to_db(self, x):
154 | min_level = np.exp(self.min_level_db / 20 * np.log(10))
155 | return 20 * np.log10(np.maximum(min_level, x))
156 |
157 | def _db_to_amp(self, x):
158 | return np.power(10.0, x * 0.05)
159 |
160 | def _normalize(self, S):
161 | return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
162 |
163 | def _denormalize(self, S):
164 | return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
165 |
166 | def read_audio(self, audio_path):
167 | wav = self.load_wav(audio_path)
168 | if self.norm_audio:
169 | wav = self.audio_normalize(wav)
170 | else:
171 | wav = wav / np.abs(wav).max()
172 |
173 | return wav
174 |
175 | def audio_to_spectrogram(self, wav):
176 | if self.save_mel:
177 | spectrogram = self.melspectrogram(wav).astype(np.float32).T
178 | else:
179 | spectrogram = self.generate_spectrogram_magphase(wav)
180 | return spectrogram
181 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch.utils.data
3 | from data.base_dataset import BaseDataset
4 |
5 |
6 | def find_dataset_using_name(dataset_name):
7 | # Given the option --dataset [datasetname],
8 | # the file "datasets/datasetname_dataset.py"
9 | # will be imported.
10 | dataset_filename = "data." + dataset_name + "_dataset"
11 | datasetlib = importlib.import_module(dataset_filename)
12 |
13 | # In the file, the class called DatasetNameDataset() will
14 | # be instantiated. It has to be a subclass of BaseDataset,
15 | # and it is case-insensitive.
16 | dataset = None
17 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
18 | for name, cls in datasetlib.__dict__.items():
19 | if name.lower() == target_dataset_name.lower() \
20 | and issubclass(cls, BaseDataset):
21 | dataset = cls
22 |
23 | if dataset is None:
24 | raise ValueError("In %s.py, there should be a subclass of BaseDataset "
25 | "with class name that matches %s in lowercase." %
26 | (dataset_filename, target_dataset_name))
27 |
28 | return dataset
29 |
30 |
31 | def get_option_setter(dataset_name):
32 | dataset_class = find_dataset_using_name(dataset_name)
33 | return dataset_class.modify_commandline_options
34 |
35 |
36 | def create_dataloader(opt):
37 | dataset_modes = opt.dataset_mode.split(',')
38 | if len(dataset_modes) == 1:
39 | dataset = find_dataset_using_name(opt.dataset_mode)
40 | instance = dataset()
41 | instance.initialize(opt)
42 | print("dataset [%s] of size %d was created" %
43 | (type(instance).__name__, len(instance)))
44 | if not opt.isTrain:
45 | shuffle = False
46 | else:
47 | shuffle = True
48 | dataloader = torch.utils.data.DataLoader(
49 | instance,
50 | batch_size=opt.batchSize,
51 | shuffle=shuffle,
52 | num_workers=int(opt.nThreads),
53 | drop_last=opt.isTrain
54 | )
55 | return dataloader
56 |
57 | else:
58 | dataloader_dict = {}
59 | for dataset_mode in dataset_modes:
60 | dataset = find_dataset_using_name(dataset_mode)
61 | instance = dataset()
62 | instance.initialize(opt)
63 | print("dataset [%s] of size %d was created" %
64 | (type(instance).__name__, len(instance)))
65 | if not opt.isTrain:
66 | shuffle = not opt.defined_driven
67 | else:
68 | shuffle = True
69 | dataloader = torch.utils.data.DataLoader(
70 | instance,
71 | batch_size=opt.batchSize,
72 | shuffle=shuffle,
73 | num_workers=int(opt.nThreads),
74 | drop_last=opt.isTrain
75 | )
76 | dataloader_dict[dataset_mode] = dataloader
77 | return dataloader_dict
78 |
79 |
80 |
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import torch
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import cv2
6 |
7 |
8 | class BaseDataset(data.Dataset):
9 | def __init__(self):
10 | super(BaseDataset, self).__init__()
11 |
12 | @staticmethod
13 | def modify_commandline_options(parser, is_train):
14 | return parser
15 |
16 | def initialize(self, opt):
17 | pass
18 |
19 | def to_Tensor(self, img):
20 | if img.ndim == 3:
21 | wrapped_img = img.transpose(2, 0, 1) / 255.0
22 | elif img.ndim == 4:
23 | wrapped_img = img.transpose(0, 3, 1, 2) / 255.0
24 | else:
25 | wrapped_img = img / 255.0
26 | wrapped_img = torch.from_numpy(wrapped_img).float()
27 |
28 | return wrapped_img * 2 - 1
29 |
30 | def face_augmentation(self, img, crop_size):
31 | img = self._color_transfer(img)
32 | img = self._reshape(img, crop_size)
33 | img = self._blur_and_sharp(img)
34 | return img
35 |
36 | def _blur_and_sharp(self, img):
37 | blur = np.random.randint(0, 2)
38 | img2 = img.copy()
39 | output = []
40 | for i in range(len(img2)):
41 | if blur:
42 | ksize = np.random.choice([3, 5, 7, 9])
43 | output.append(cv2.medianBlur(img2[i], ksize))
44 | else:
45 | kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
46 | output.append(cv2.filter2D(img2[i], -1, kernel))
47 | output = np.stack(output)
48 | return output
49 |
50 | def _color_transfer(self, img):
51 |
52 | transfer_c = np.random.uniform(0.3, 1.6)
53 |
54 | start_channel = np.random.randint(0, 2)
55 | end_channel = np.random.randint(start_channel + 1, 4)
56 |
57 | img2 = img.copy()
58 |
59 | img2[:, :, :, start_channel:end_channel] = np.minimum(np.maximum(img[:, :, :, start_channel:end_channel] * transfer_c, np.zeros(img[:, :, :, start_channel:end_channel].shape)),
60 | np.ones(img[:, :, :, start_channel:end_channel].shape) * 255)
61 | return img2
62 |
63 | def perspective_transform(self, img, crop_size=224, pers_size=10, enlarge_size=-10):
64 | h, w, c = img.shape
65 | dst = np.array([
66 | [-enlarge_size, -enlarge_size],
67 | [-enlarge_size + pers_size, w + enlarge_size],
68 | [h + enlarge_size, -enlarge_size],
69 | [h + enlarge_size - pers_size, w + enlarge_size],], dtype=np.float32)
70 | src = np.array([[-enlarge_size, -enlarge_size], [-enlarge_size, w + enlarge_size],
71 | [h + enlarge_size, -enlarge_size], [h + enlarge_size, w + enlarge_size]]).astype(np.float32())
72 | M = cv2.getPerspectiveTransform(src, dst)
73 | warped = cv2.warpPerspective(img, M, (crop_size, crop_size), borderMode=cv2.BORDER_REPLICATE)
74 | return warped, M
75 |
76 | def _reshape(self, img, crop_size):
77 | reshape = np.random.randint(0, 2)
78 | reshape_size = np.random.randint(15, 25)
79 | extra_padding_size = np.random.randint(0, reshape_size // 2)
80 | pers_size = np.random.randint(20, 30) * pow(-1, np.random.randint(2))
81 |
82 | enlarge_size = np.random.randint(20, 40) * pow(-1, np.random.randint(2))
83 | shape = img[0].shape
84 | img2 = img.copy()
85 | output = []
86 | for i in range(len(img2)):
87 | if reshape:
88 | im = cv2.resize(img2[i], (shape[0] - reshape_size*2, shape[1] + reshape_size*2))
89 | im = cv2.copyMakeBorder(im, 0, 0, reshape_size + extra_padding_size, reshape_size + extra_padding_size, cv2.cv2.BORDER_REFLECT)
90 | im = im[reshape_size - extra_padding_size:shape[0] + reshape_size + extra_padding_size, :, :]
91 | im, _ = self.perspective_transform(im, crop_size=crop_size, pers_size=pers_size, enlarge_size=enlarge_size)
92 | output.append(im)
93 | else:
94 | im = cv2.resize(img2[i], (shape[0] + reshape_size*2, shape[1] - reshape_size*2))
95 | im = cv2.copyMakeBorder(im, reshape_size + extra_padding_size, reshape_size + extra_padding_size, 0, 0, cv2.cv2.BORDER_REFLECT)
96 | im = im[:, reshape_size - extra_padding_size:shape[0] + reshape_size + extra_padding_size, :]
97 | im, _ = self.perspective_transform(im, crop_size=crop_size, pers_size=pers_size, enlarge_size=enlarge_size)
98 | output.append(im)
99 | output = np.stack(output)
100 | return output
--------------------------------------------------------------------------------
/data/voxtest_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import numpy as np
4 | from config import AudioConfig
5 | import shutil
6 | import cv2
7 | import glob
8 | import random
9 | import torch
10 | from data.base_dataset import BaseDataset
11 | import util.util as util
12 |
13 |
14 | class VOXTestDataset(BaseDataset):
15 |
16 | @staticmethod
17 | def modify_commandline_options(parser, is_train):
18 | parser.add_argument('--no_pairing_check', action='store_true',
19 | help='If specified, skip sanity check of correct label-image file pairing')
20 | return parser
21 |
22 | def cv2_loader(self, img_str):
23 | img_array = np.frombuffer(img_str, dtype=np.uint8)
24 | return cv2.imdecode(img_array, cv2.IMREAD_COLOR)
25 |
26 | def load_img(self, image_path, M=None, crop=True, crop_len=16):
27 | img = cv2.imread(image_path)
28 |
29 | if img is None:
30 | raise Exception('None Image')
31 |
32 | if M is not None:
33 | img = cv2.warpAffine(img, M, (self.opt.crop_size, self.opt.crop_size), borderMode=cv2.BORDER_REPLICATE)
34 |
35 | if crop:
36 | img = img[:self.opt.crop_size - crop_len*2, crop_len:self.opt.crop_size - crop_len]
37 | if self.opt.target_crop_len > 0:
38 | img = img[self.opt.target_crop_len:self.opt.crop_size - self.opt.target_crop_len, self.opt.target_crop_len:self.opt.crop_size - self.opt.target_crop_len]
39 | img = cv2.resize(img, (self.opt.crop_size, self.opt.crop_size))
40 |
41 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
42 | return img
43 |
44 | def fill_list(self, tmp_list):
45 | length = len(tmp_list)
46 | if length % self.opt.batchSize != 0:
47 | end = math.ceil(length / self.opt.batchSize) * self.opt.batchSize
48 | tmp_list = tmp_list + tmp_list[-1 * (end - length) :]
49 | return tmp_list
50 |
51 | def frame2audio_indexs(self, frame_inds):
52 | start_frame_ind = frame_inds - self.audio.num_frames_per_clip // 2
53 |
54 | start_audio_inds = start_frame_ind * self.audio.num_bins_per_frame
55 | return start_audio_inds
56 |
57 | def initialize(self, opt):
58 | self.opt = opt
59 | self.path_label = opt.path_label
60 | self.clip_len = opt.clip_len
61 | self.frame_interval = opt.frame_interval
62 | self.num_clips = opt.num_clips
63 | self.frame_rate = opt.frame_rate
64 | self.num_inputs = opt.num_inputs
65 | self.filename_tmpl = opt.filename_tmpl
66 |
67 | self.mouth_num_frames = None
68 | self.mouth_frame_path = None
69 | self.pose_num_frames = None
70 |
71 | self.audio = AudioConfig.AudioConfig(num_frames_per_clip=opt.num_frames_per_clip, hop_size=opt.hop_size)
72 | self.num_audio_bins = self.audio.num_frames_per_clip * self.audio.num_bins_per_frame
73 |
74 |
75 | assert len(opt.path_label.split()) == 8, opt.path_label
76 | id_path, ref_num, \
77 | pose_frame_path, pose_num_frames, \
78 | audio_path, mouth_frame_path, mouth_num_frames, spectrogram_path = opt.path_label.split()
79 |
80 |
81 | id_idx, mouth_idx = id_path.split('/')[-1], audio_path.split('/')[-1].split('.')[0]
82 | if not os.path.isdir(pose_frame_path):
83 | pose_frame_path = id_path
84 | pose_num_frames = 1
85 |
86 | pose_idx = pose_frame_path.split('/')[-1]
87 | id_idx, pose_idx, mouth_idx = str(id_idx), str(pose_idx), str(mouth_idx)
88 |
89 | self.processed_file_savepath = os.path.join('results', 'id_' + id_idx + '_pose_' + pose_idx +
90 | '_audio_' + os.path.basename(audio_path)[:-4])
91 | if not os.path.exists(self.processed_file_savepath): os.makedirs(self.processed_file_savepath)
92 |
93 |
94 | if not os.path.isfile(spectrogram_path):
95 | wav = self.audio.read_audio(audio_path)
96 | self.spectrogram = self.audio.audio_to_spectrogram(wav)
97 |
98 | else:
99 | self.spectrogram = np.load(spectrogram_path)
100 |
101 | if os.path.isdir(mouth_frame_path):
102 | self.mouth_frame_path = mouth_frame_path
103 | self.mouth_num_frames = mouth_num_frames
104 |
105 | self.pose_num_frames = int(pose_num_frames)
106 |
107 | self.target_frame_inds = np.arange(2, len(self.spectrogram) // self.audio.num_bins_per_frame - 2)
108 | self.audio_inds = self.frame2audio_indexs(self.target_frame_inds)
109 |
110 | self.dataset_size = len(self.target_frame_inds)
111 |
112 | id_img_paths = glob.glob(os.path.join(id_path, '*.jpg')) + glob.glob(os.path.join(id_path, '*.png'))
113 | random.shuffle(id_img_paths)
114 | opt.num_inputs = min(len(id_img_paths), opt.num_inputs)
115 | id_img_tensors = []
116 |
117 | for i, image_path in enumerate(id_img_paths):
118 | id_img_tensor = self.to_Tensor(self.load_img(image_path))
119 | id_img_tensors += [id_img_tensor]
120 | shutil.copyfile(image_path, os.path.join(self.processed_file_savepath, 'ref_id_{}.jpg'.format(i)))
121 | if i == (opt.num_inputs - 1):
122 | break
123 | self.id_img_tensor = torch.stack(id_img_tensors)
124 | self.pose_frame_path = pose_frame_path
125 | self.audio_path = audio_path
126 | self.id_path = id_path
127 | self.mouth_frame_path = mouth_frame_path
128 | self.initialized = False
129 |
130 |
131 | def paths_match(self, path1, path2):
132 | filename1_without_ext = os.path.splitext(os.path.basename(path1)[-10:])[0]
133 | filename2_without_ext = os.path.splitext(os.path.basename(path2)[-10:])[0]
134 | return filename1_without_ext == filename2_without_ext
135 |
136 | def load_one_frame(self, frame_ind, video_path, M=None, crop=True):
137 | filepath = os.path.join(video_path, self.filename_tmpl.format(frame_ind))
138 | img = self.load_img(filepath, M=M, crop=crop)
139 | img = self.to_Tensor(img)
140 | return img
141 |
142 | def load_spectrogram(self, audio_ind):
143 | mel_shape = self.spectrogram.shape
144 |
145 | if (audio_ind + self.num_audio_bins) <= mel_shape[0] and audio_ind >= 0:
146 | spectrogram = np.array(self.spectrogram[audio_ind:audio_ind + self.num_audio_bins, :]).astype('float32')
147 | else:
148 | print('(audio_ind {} + opt.num_audio_bins {}) > mel_shape[0] {} '.format(audio_ind, self.num_audio_bins,
149 | mel_shape[0]))
150 | if audio_ind > 0:
151 | spectrogram = np.array(self.spectrogram[audio_ind:audio_ind + self.num_audio_bins, :]).astype('float32')
152 | else:
153 | spectrogram = np.zeros((self.num_audio_bins, mel_shape[1])).astype(np.float16).astype(np.float32)
154 |
155 | spectrogram = torch.from_numpy(spectrogram)
156 | spectrogram = spectrogram.unsqueeze(0)
157 |
158 | spectrogram = spectrogram.transpose(-2, -1)
159 | return spectrogram
160 |
161 | def __getitem__(self, index):
162 |
163 | img_index = self.target_frame_inds[index]
164 | mel_index = self.audio_inds[index]
165 |
166 | pose_index = util.calc_loop_idx(img_index, self.pose_num_frames)
167 |
168 | pose_frame = self.load_one_frame(pose_index, self.pose_frame_path)
169 |
170 | if os.path.isdir(self.mouth_frame_path):
171 | mouth_frame = self.load_one_frame(img_index, self.mouth_frame_path)
172 | else:
173 | mouth_frame = torch.zeros_like(pose_frame)
174 |
175 | spectrograms = self.load_spectrogram(mel_index)
176 |
177 | input_dict = {
178 | 'input': self.id_img_tensor,
179 | 'target': mouth_frame,
180 | 'driving_pose_frames': pose_frame,
181 | 'augmented': pose_frame,
182 | 'label': torch.zeros(1),
183 | }
184 | if self.opt.use_audio:
185 | input_dict['spectrograms'] = spectrograms
186 |
187 | # Give subclasses a chance to modify the final output
188 | self.postprocess(input_dict)
189 |
190 | return input_dict
191 |
192 | def postprocess(self, input_dict):
193 | return input_dict
194 |
195 | def __len__(self):
196 | return self.dataset_size
197 |
198 | def get_processed_file_savepath(self):
199 | return self.processed_file_savepath
200 |
--------------------------------------------------------------------------------
/experiments/demo_vox.sh:
--------------------------------------------------------------------------------
1 | meta_path_vox='./misc/demo.csv'
2 |
3 | python -u inference.py \
4 | --name demo \
5 | --meta_path_vox ${meta_path_vox} \
6 | --dataset_mode voxtest \
7 | --netG modulate \
8 | --netA resseaudio \
9 | --netA_sync ressesync \
10 | --netD multiscale \
11 | --netV resnext \
12 | --netE fan \
13 | --model av \
14 | --gpu_ids 0 \
15 | --clip_len 1 \
16 | --batchSize 16 \
17 | --style_dim 2560 \
18 | --nThreads 4 \
19 | --input_id_feature \
20 | --generate_interval 1 \
21 | --style_feature_loss \
22 | --use_audio 1 \
23 | --noise_pose \
24 | --driving_pose \
25 | --gen_video \
26 | --generate_from_audio_only \
27 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append('..')
4 | from options.test_options import TestOptions
5 | import torch
6 | from models import create_model
7 | import data
8 | import util.util as util
9 | from tqdm import tqdm
10 |
11 |
12 | def video_concat(processed_file_savepath, name, video_names, audio_path):
13 | cmd = ['ffmpeg']
14 | num_inputs = len(video_names)
15 | for video_name in video_names:
16 | cmd += ['-i', '\'' + str(os.path.join(processed_file_savepath, video_name + '.mp4'))+'\'',]
17 |
18 | cmd += ['-filter_complex hstack=inputs=' + str(num_inputs),
19 | '\'' + str(os.path.join(processed_file_savepath, name+'.mp4')) + '\'', '-loglevel error -y']
20 | cmd = ' '.join(cmd)
21 | os.system(cmd)
22 |
23 | video_add_audio(name, audio_path, processed_file_savepath)
24 |
25 |
26 | def video_add_audio(name, audio_path, processed_file_savepath):
27 | os.system('cp {} {}'.format(audio_path, processed_file_savepath))
28 | cmd = ['ffmpeg', '-i', '\'' + os.path.join(processed_file_savepath, name + '.mp4') + '\'',
29 | '-i', audio_path,
30 | '-q:v 0',
31 | '-strict -2',
32 | '\'' + os.path.join(processed_file_savepath, 'av' + name + '.mp4') + '\'',
33 | '-loglevel error -y']
34 | cmd = ' '.join(cmd)
35 | os.system(cmd)
36 |
37 |
38 | def img2video(dst_path, prefix, video_path):
39 | cmd = ['ffmpeg', '-i', '\'' + video_path + '/' + prefix + '%d.jpg'
40 | + '\'', '-q:v 0', '\'' + dst_path + '/' + prefix + '.mp4' + '\'', '-loglevel error -y']
41 | cmd = ' '.join(cmd)
42 | os.system(cmd)
43 |
44 |
45 | def inference_single_audio(opt, path_label, model):
46 | #
47 | opt.path_label = path_label
48 | dataloader = data.create_dataloader(opt)
49 | processed_file_savepath = dataloader.dataset.get_processed_file_savepath()
50 |
51 | idx = 0
52 | if opt.driving_pose:
53 | video_names = ['Input_', 'G_Pose_Driven_', 'Pose_Source_', 'Mouth_Source_']
54 | else:
55 | video_names = ['Input_', 'G_Fix_Pose_', 'Mouth_Source_']
56 | is_mouth_frame = os.path.isdir(dataloader.dataset.mouth_frame_path)
57 | if not is_mouth_frame:
58 | video_names.pop()
59 | save_paths = []
60 | for name in video_names:
61 | save_path = os.path.join(processed_file_savepath, name)
62 | util.mkdir(save_path)
63 | save_paths.append(save_path)
64 | for data_i in tqdm(dataloader):
65 | # print('==============', i, '===============')
66 | fake_image_original_pose_a, fake_image_driven_pose_a = model.forward(data_i, mode='inference')
67 |
68 | for num in range(len(fake_image_driven_pose_a)):
69 | util.save_torch_img(data_i['input'][num], os.path.join(save_paths[0], video_names[0] + str(idx) + '.jpg'))
70 | if opt.driving_pose:
71 | util.save_torch_img(fake_image_driven_pose_a[num],
72 | os.path.join(save_paths[1], video_names[1] + str(idx) + '.jpg'))
73 | util.save_torch_img(data_i['driving_pose_frames'][num],
74 | os.path.join(save_paths[2], video_names[2] + str(idx) + '.jpg'))
75 | else:
76 | util.save_torch_img(fake_image_original_pose_a[num],
77 | os.path.join(save_paths[1], video_names[1] + str(idx) + '.jpg'))
78 | if is_mouth_frame:
79 | util.save_torch_img(data_i['target'][num], os.path.join(save_paths[-1], video_names[-1] + str(idx) + '.jpg'))
80 | idx += 1
81 |
82 | if opt.gen_video:
83 | for i, video_name in enumerate(video_names):
84 | img2video(processed_file_savepath, video_name, save_paths[i])
85 | video_concat(processed_file_savepath, 'concat', video_names, dataloader.dataset.audio_path)
86 |
87 | print('results saved...' + processed_file_savepath)
88 | del dataloader
89 | return
90 |
91 |
92 | def main():
93 |
94 | opt = TestOptions().parse()
95 | opt.isTrain = False
96 | torch.manual_seed(0)
97 | model = create_model(opt).cuda()
98 | model.eval()
99 |
100 | with open(opt.meta_path_vox, 'r') as f:
101 | lines = f.read().splitlines()
102 |
103 | for clip_idx, path_label in enumerate(lines):
104 | try:
105 | assert len(path_label.split()) == 8, path_label
106 |
107 | inference_single_audio(opt, path_label, model)
108 |
109 | except Exception as ex:
110 | import traceback
111 | traceback.print_exc()
112 | print(path_label + '\n')
113 | print(str(ex))
114 |
115 |
116 | if __name__ == '__main__':
117 | main()
118 |
--------------------------------------------------------------------------------
/misc/Audio_Source.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/Audio_Source.zip
--------------------------------------------------------------------------------
/misc/Input.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/Input.zip
--------------------------------------------------------------------------------
/misc/Mouth_Source.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/Mouth_Source.zip
--------------------------------------------------------------------------------
/misc/Pose_Source.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/Pose_Source.zip
--------------------------------------------------------------------------------
/misc/demo.csv:
--------------------------------------------------------------------------------
1 | misc/Input/517600055 1 misc/Pose_Source/517600078 160 misc/Audio_Source/681600002.mp3 misc/Mouth_Source/681600002 363 dummy
2 |
--------------------------------------------------------------------------------
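
The line above is the meta entry consumed by inference.py; main() asserts that every line splits into exactly 8 whitespace-separated fields (apparently an input folder, a pose source, the driving audio file and a mouth source, plus numeric fields and a trailing placeholder whose exact semantics are defined in data/voxtest_dataset.py). A minimal sanity check mirroring that assertion, as a sketch:

    # Mirror the `len(path_label.split()) == 8` assertion from inference.py's main().
    with open('misc/demo.csv') as f:
        for path_label in f.read().splitlines():
            if path_label.strip():
                assert len(path_label.split()) == 8, path_label
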
/misc/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/demo.gif
--------------------------------------------------------------------------------
/misc/demo_id.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/demo_id.gif
--------------------------------------------------------------------------------
/misc/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/method.png
--------------------------------------------------------------------------------
/misc/output.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/output.gif
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | def find_model_using_name(model_name):
4 | # Given the option --model [modelname],
5 | # the file "models/modelname_model.py"
6 | # will be imported.
7 | model_filename = "models." + model_name + "_model"
8 | modellib = importlib.import_module(model_filename)
9 |
10 | # In the file, the class called ModelNameModel() will
11 | # be instantiated. It has to be a subclass of torch.nn.Module,
12 |     # and the name lookup is case-insensitive.
13 | model = None
14 | target_model_name = model_name.replace('_', '') + 'model'
15 | for name, cls in modellib.__dict__.items():
16 | if name.lower() == target_model_name.lower():
17 | model = cls
18 |
19 | if model is None:
20 | print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." % (model_filename, target_model_name))
21 | exit(0)
22 |
23 | return model
24 |
25 |
26 | def get_option_setter(model_name):
27 | model_class = find_model_using_name(model_name)
28 | return model_class.modify_commandline_options
29 |
30 |
31 | def create_model(opt):
32 | model = find_model_using_name(opt.model)
33 | instance = model(opt)
34 | print("model [%s] was created" % (type(instance).__name__))
35 |
36 | return instance
37 |
--------------------------------------------------------------------------------
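
A quick sketch of the lookup convention above, assuming the dependencies in requirements.txt are installed and that models/av_model.py (listed in the repository tree) defines a class whose lowercased name is 'avmodel':

    from models import find_model_using_name

    # --model av  ->  imports models.av_model and returns the class matching 'avmodel'
    model_cls = find_model_using_name('av')
    print(model_cls.__name__)

create_model(opt) then instantiates that class with the full option object produced by options/test_options.py or options/train_options.py.
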
/models/networks/FAN_feature_extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from util import util
4 | import torch.nn.functional as F
5 |
6 |
7 | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
8 | "3x3 convolution with padding"
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3,
10 | stride=strd, padding=padding, bias=bias)
11 |
12 |
13 | class ConvBlock(nn.Module):
14 | def __init__(self, in_planes, out_planes):
15 | super(ConvBlock, self).__init__()
16 | self.bn1 = nn.BatchNorm2d(in_planes)
17 | self.conv1 = conv3x3(in_planes, int(out_planes / 2))
18 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
19 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
20 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
21 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
22 |
23 | if in_planes != out_planes:
24 | self.downsample = nn.Sequential(
25 | nn.BatchNorm2d(in_planes),
26 | nn.ReLU(True),
27 | nn.Conv2d(in_planes, out_planes,
28 | kernel_size=1, stride=1, bias=False),
29 | )
30 | else:
31 | self.downsample = None
32 |
33 | def forward(self, x):
34 | residual = x
35 |
36 | out1 = self.bn1(x)
37 | out1 = F.relu(out1, True)
38 | out1 = self.conv1(out1)
39 |
40 | out2 = self.bn2(out1)
41 | out2 = F.relu(out2, True)
42 | out2 = self.conv2(out2)
43 |
44 | out3 = self.bn3(out2)
45 | out3 = F.relu(out3, True)
46 | out3 = self.conv3(out3)
47 |
48 | out3 = torch.cat((out1, out2, out3), 1)
49 |
50 | if self.downsample is not None:
51 | residual = self.downsample(residual)
52 |
53 | out3 += residual
54 |
55 | return out3
56 |
57 |
58 | class HourGlass(nn.Module):
59 | def __init__(self, num_modules, depth, num_features):
60 | super(HourGlass, self).__init__()
61 | self.num_modules = num_modules
62 | self.depth = depth
63 | self.features = num_features
64 | self.dropout = nn.Dropout(0.5)
65 |
66 | self._generate_network(self.depth)
67 |
68 | def _generate_network(self, level):
69 | self.add_module('b1_' + str(level), ConvBlock(256, 256))
70 |
71 | self.add_module('b2_' + str(level), ConvBlock(256, 256))
72 |
73 | if level > 1:
74 | self._generate_network(level - 1)
75 | else:
76 | self.add_module('b2_plus_' + str(level), ConvBlock(256, 256))
77 |
78 | self.add_module('b3_' + str(level), ConvBlock(256, 256))
79 |
80 | def _forward(self, level, inp):
81 | # Upper branch
82 | up1 = inp
83 | up1 = self._modules['b1_' + str(level)](up1)
84 | up1 = self.dropout(up1)
85 | # Lower branch
86 | low1 = F.max_pool2d(inp, 2, stride=2)
87 | low1 = self._modules['b2_' + str(level)](low1)
88 |
89 | if level > 1:
90 | low2 = self._forward(level - 1, low1)
91 | else:
92 | low2 = low1
93 | low2 = self._modules['b2_plus_' + str(level)](low2)
94 |
95 | low3 = low2
96 | low3 = self._modules['b3_' + str(level)](low3)
97 | up1size = up1.size()
98 | rescale_size = (up1size[2], up1size[3])
99 | up2 = F.upsample(low3, size=rescale_size, mode='bilinear')
100 |
101 | return up1 + up2
102 |
103 | def forward(self, x):
104 | return self._forward(self.depth, x)
105 |
106 |
107 | class FAN_use(nn.Module):
108 | def __init__(self):
109 | super(FAN_use, self).__init__()
110 | self.num_modules = 1
111 |
112 | # Base part
113 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
114 | self.bn1 = nn.BatchNorm2d(64)
115 | self.conv2 = ConvBlock(64, 128)
116 | self.conv3 = ConvBlock(128, 128)
117 | self.conv4 = ConvBlock(128, 256)
118 |
119 | # Stacking part
120 | hg_module = 0
121 | self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
122 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
123 | self.add_module('conv_last' + str(hg_module),
124 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
125 | self.add_module('l' + str(hg_module), nn.Conv2d(256,
126 | 68, kernel_size=1, stride=1, padding=0))
127 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
128 |
129 | if hg_module < self.num_modules - 1:
130 | self.add_module(
131 | 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
132 | self.add_module('al' + str(hg_module), nn.Conv2d(68,
133 | 256, kernel_size=1, stride=1, padding=0))
134 |
135 | self.avgpool = nn.MaxPool2d((2, 2), 2)
136 | self.conv6 = nn.Conv2d(68, 1, 3, 2, 1)
137 | self.fc = nn.Linear(28 * 28, 512)
138 | self.bn5 = nn.BatchNorm2d(68)
139 | self.relu = nn.ReLU(True)
140 |
141 | def forward(self, x):
142 | x = F.relu(self.bn1(self.conv1(x)), True)
143 | x = F.max_pool2d(self.conv2(x), 2)
144 | x = self.conv3(x)
145 | x = self.conv4(x)
146 |
147 | previous = x
148 |
149 | i = 0
150 | hg = self._modules['m' + str(i)](previous)
151 |
152 | ll = hg
153 | ll = self._modules['top_m_' + str(i)](ll)
154 |
155 | ll = self._modules['bn_end' + str(i)](self._modules['conv_last' + str(i)](ll))
156 | tmp_out = self._modules['l' + str(i)](F.relu(ll))
157 |
158 | net = self.relu(self.bn5(tmp_out))
159 | net = self.conv6(net)
160 | net = net.view(-1, net.shape[-2] * net.shape[-1])
161 | net = self.relu(net)
162 | net = self.fc(net)
163 | return net
164 |
--------------------------------------------------------------------------------
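
A shape-check sketch for FAN_use (not part of the repository): conv1 (stride 2), the max-pool after conv2 and the stride-2 conv6 reduce a 224x224 crop to the 28x28 map expected by nn.Linear(28 * 28, 512), so the encoder maps a batch of RGB crops to 512-dimensional features:

    import torch
    from models.networks.FAN_feature_extractor import FAN_use

    net = FAN_use().eval()
    with torch.no_grad():
        feat = net(torch.randn(2, 3, 224, 224))  # two 224x224 RGB crops
    print(feat.shape)  # torch.Size([2, 512])
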
/models/networks/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models.networks.base_network import BaseNetwork
3 | from models.networks.loss import *
4 | from models.networks.discriminator import MultiscaleDiscriminator, ImageDiscriminator
5 | from models.networks.generator import ModulateGenerator
6 | from models.networks.encoder import ResSEAudioEncoder, ResNeXtEncoder, ResSESyncEncoder, FanEncoder
7 | import util.util as util
8 |
9 |
10 | def find_network_using_name(target_network_name, filename):
11 | target_class_name = target_network_name + filename
12 | module_name = 'models.networks.' + filename
13 | network = util.find_class_in_module(target_class_name, module_name)
14 |
15 | assert issubclass(network, BaseNetwork), \
16 | "Class %s should be a subclass of BaseNetwork" % network
17 |
18 | return network
19 |
20 |
21 | def modify_commandline_options(parser, is_train):
22 | opt, _ = parser.parse_known_args()
23 |
24 | netG_cls = find_network_using_name(opt.netG, 'generator')
25 | parser = netG_cls.modify_commandline_options(parser, is_train)
26 | if is_train:
27 | netD_cls = find_network_using_name(opt.netD, 'discriminator')
28 | parser = netD_cls.modify_commandline_options(parser, is_train)
29 | netA_cls = find_network_using_name(opt.netA, 'encoder')
30 | parser = netA_cls.modify_commandline_options(parser, is_train)
31 | # parser = netA_sync_cls.modify_commandline_options(parser, is_train)
32 |
33 | return parser
34 |
35 |
36 | def create_network(cls, opt):
37 | net = cls(opt)
38 | net.print_network()
39 | if len(opt.gpu_ids) > 0:
40 | assert(torch.cuda.is_available())
41 | net.cuda()
42 | net.init_weights(opt.init_type, opt.init_variance)
43 | return net
44 |
45 |
46 | def define_networks(opt, name, type):
47 | netG_cls = find_network_using_name(name, type)
48 | return create_network(netG_cls, opt)
49 |
50 | def define_G(opt):
51 | netG_cls = find_network_using_name(opt.netG, 'generator')
52 | return create_network(netG_cls, opt)
53 |
54 |
55 | def define_D(opt):
56 | netD_cls = find_network_using_name(opt.netD, 'discriminator')
57 | return create_network(netD_cls, opt)
58 |
59 | def define_A(opt):
60 | netA_cls = find_network_using_name(opt.netA, 'encoder')
61 | return create_network(netA_cls, opt)
62 |
63 | def define_A_sync(opt):
64 | netA_cls = find_network_using_name(opt.netA_sync, 'encoder')
65 | return create_network(netA_cls, opt)
66 |
67 |
68 | def define_E(opt):
69 | # there exists only one encoder type
70 | netE_cls = find_network_using_name(opt.netE, 'encoder')
71 | return create_network(netE_cls, opt)
72 |
73 |
74 | def define_V(opt):
75 | # there exists only one encoder type
76 | netV_cls = find_network_using_name(opt.netV, 'encoder')
77 | return create_network(netV_cls, opt)
78 |
79 |
80 | def define_P(opt):
81 | netP_cls = find_network_using_name(opt.netP, 'encoder')
82 | return create_network(netP_cls, opt)
83 |
84 |
85 | def define_F_rec(opt):
86 | netF_rec_cls = find_network_using_name(opt.netF_rec, 'encoder')
87 | return create_network(netF_rec_cls, opt)
88 |
--------------------------------------------------------------------------------
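
The lookup convention here is '<option value> + <filename>', resolved inside models.networks.<filename>; assuming util.find_class_in_module matches class names case-insensitively (like find_model_using_name in models/__init__.py), a small sketch:

    from models.networks import find_network_using_name

    # e.g. --netG modulate  ->  models.networks.generator.ModulateGenerator
    netG_cls = find_network_using_name('modulate', 'generator')
    print(netG_cls.__name__)  # ModulateGenerator
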
/models/networks/__pycache__/FAN_feature_extractor.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/FAN_feature_extractor.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/FAN_feature_extractor.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/FAN_feature_extractor.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/Voxceleb_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/Voxceleb_model.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/Voxceleb_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/Voxceleb_model.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/architecture.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/architecture.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/architecture.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/architecture.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/audio_architecture.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/audio_architecture.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/audio_architecture.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/audio_architecture.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/audio_network.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/audio_network.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/base_network.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/base_network.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/base_network.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/base_network.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/discriminator.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/discriminator.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/discriminator.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/discriminator.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/encoder.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/encoder.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/encoder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/encoder.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/generator.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/generator.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/generator.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/generator.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/loss.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/normalization.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/normalization.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/normalization.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/normalization.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/stylegan2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/stylegan2.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/stylegan2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/stylegan2.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/vision_network.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/vision_network.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/vision_network.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/vision_network.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/architecture.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 | from models.networks.encoder import VGGEncoder
6 | from util import util
7 | from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
8 | import torch.nn.utils.spectral_norm as spectral_norm
9 |
10 |
11 | # VGG architecture, used for the perceptual loss with a pretrained VGG network
12 | class VGG19(torch.nn.Module):
13 | def __init__(self, requires_grad=False):
14 | super(VGG19, self).__init__()
15 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
16 | self.slice1 = torch.nn.Sequential()
17 | self.slice2 = torch.nn.Sequential()
18 | self.slice3 = torch.nn.Sequential()
19 | self.slice4 = torch.nn.Sequential()
20 | self.slice5 = torch.nn.Sequential()
21 | for x in range(2):
22 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
23 | for x in range(2, 7):
24 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
25 | for x in range(7, 12):
26 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
27 | for x in range(12, 21):
28 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
29 | for x in range(21, 30):
30 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
31 | if not requires_grad:
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 | def forward(self, X):
36 | h_relu1 = self.slice1(X)
37 | h_relu2 = self.slice2(h_relu1)
38 | h_relu3 = self.slice3(h_relu2)
39 | h_relu4 = self.slice4(h_relu3)
40 | h_relu5 = self.slice5(h_relu4)
41 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
42 | return out
43 |
44 |
45 | class VGGFace19(torch.nn.Module):
46 | def __init__(self, opt, requires_grad=False):
47 | super(VGGFace19, self).__init__()
48 | self.model = VGGEncoder(opt)
49 | self.opt = opt
50 | ckpt = torch.load(opt.VGGFace_pretrain_path)
51 | print("=> loading checkpoint '{}'".format(opt.VGGFace_pretrain_path))
52 | util.copy_state_dict(ckpt, self.model)
53 | vgg_pretrained_features = self.model.model.features
54 | len_features = len(self.model.model.features)
55 | self.slice1 = torch.nn.Sequential()
56 | self.slice2 = torch.nn.Sequential()
57 | self.slice3 = torch.nn.Sequential()
58 | self.slice4 = torch.nn.Sequential()
59 | self.slice5 = torch.nn.Sequential()
60 | self.slice6 = torch.nn.Sequential()
61 |
62 | for x in range(2):
63 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
64 | for x in range(2, 7):
65 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
66 | for x in range(7, 12):
67 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
68 | for x in range(12, 21):
69 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
70 | for x in range(21, 30):
71 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
72 | for x in range(30, len_features):
73 | self.slice6.add_module(str(x), vgg_pretrained_features[x])
74 | if not requires_grad:
75 | for param in self.parameters():
76 | param.requires_grad = False
77 |
78 | def forward(self, X):
79 | X = X.view(-1, self.opt.output_nc, self.opt.crop_size, self.opt.crop_size)
80 | h_relu1 = self.slice1(X)
81 | h_relu2 = self.slice2(h_relu1)
82 | h_relu3 = self.slice3(h_relu2)
83 | h_relu4 = self.slice4(h_relu3)
84 | h_relu5 = self.slice5(h_relu4)
85 | h_relu6 = self.slice6(h_relu5)
86 | out = [h_relu3, h_relu4, h_relu5, h_relu6, h_relu6]
87 | return out
88 |
89 |
90 | # Returns a function that creates a normalization function
91 | # that does not condition on semantic map
92 | def get_nonspade_norm_layer(opt, norm_type='instance'):
93 | # helper function to get # output channels of the previous layer
94 | def get_out_channel(layer):
95 | if hasattr(layer, 'out_channels'):
96 | return getattr(layer, 'out_channels')
97 | return layer.weight.size(0)
98 |
99 | # this function will be returned
100 | def add_norm_layer(layer):
101 | nonlocal norm_type
102 | if norm_type.startswith('spectral'):
103 | layer = spectral_norm(layer)
104 | subnorm_type = norm_type[len('spectral'):]
105 | else:
106 | subnorm_type = norm_type
107 |
108 | if subnorm_type == 'none' or len(subnorm_type) == 0:
109 | return layer
110 |
111 | # remove bias in the previous layer, which is meaningless
112 | # since it has no effect after normalization
113 | if getattr(layer, 'bias', None) is not None:
114 | delattr(layer, 'bias')
115 | layer.register_parameter('bias', None)
116 |
117 | if subnorm_type == 'batch':
118 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
119 | elif subnorm_type == 'syncbatch':
120 | norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
121 | elif subnorm_type == 'instance':
122 | norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
123 | else:
124 | raise ValueError('normalization layer %s is not recognized' % subnorm_type)
125 |
126 | return nn.Sequential(layer, norm_layer)
127 |
128 | return add_norm_layer
129 |
--------------------------------------------------------------------------------
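
A small usage sketch for get_nonspade_norm_layer (the opt argument is not used by the helper, so None suffices here): a 'spectral' prefix wraps the layer in spectral_norm, and the rest of the string picks the normalization appended after it:

    import torch.nn as nn
    from models.networks.architecture import get_nonspade_norm_layer

    norm_layer = get_nonspade_norm_layer(None, 'spectralinstance')
    block = norm_layer(nn.Conv2d(3, 64, kernel_size=3, padding=1))
    # block: Sequential(spectral_norm-wrapped Conv2d with its bias removed,
    #                   InstanceNorm2d(64, affine=False))
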
/models/networks/audio_network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ResNetSE(nn.Module):
6 | def __init__(self, block, layers, num_filters, nOut, encoder_type='SAP', n_mels=80, n_mel_T=1, log_input=True, **kwargs):
7 | super(ResNetSE, self).__init__()
8 |
9 | print('Embedding size is %d, encoder %s.' % (nOut, encoder_type))
10 |
11 | self.inplanes = num_filters[0]
12 | self.encoder_type = encoder_type
13 | self.n_mels = n_mels
14 | self.log_input = log_input
15 |
16 | self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
17 | self.relu = nn.ReLU(inplace=True)
18 | self.bn1 = nn.BatchNorm2d(num_filters[0])
19 |
20 | self.layer1 = self._make_layer(block, num_filters[0], layers[0])
21 | self.layer2 = self._make_layer(block, num_filters[1], layers[1], stride=(2, 2))
22 | self.layer3 = self._make_layer(block, num_filters[2], layers[2], stride=(2, 2))
23 | self.layer4 = self._make_layer(block, num_filters[3], layers[3], stride=(2, 2))
24 |
25 | self.instancenorm = nn.InstanceNorm1d(n_mels)
26 |
27 | outmap_size = int(self.n_mels * n_mel_T / 8)
28 |
29 | self.attention = nn.Sequential(
30 | nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
31 | nn.ReLU(),
32 | nn.BatchNorm1d(128),
33 | nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
34 | nn.Softmax(dim=2),
35 | )
36 |
37 | if self.encoder_type == "SAP":
38 | out_dim = num_filters[3] * outmap_size
39 | elif self.encoder_type == "ASP":
40 | out_dim = num_filters[3] * outmap_size * 2
41 | else:
42 | raise ValueError('Undefined encoder')
43 |
44 | self.fc = nn.Linear(out_dim, nOut)
45 |
46 | for m in self.modules():
47 | if isinstance(m, nn.Conv2d):
48 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
49 | elif isinstance(m, nn.BatchNorm2d):
50 | nn.init.constant_(m.weight, 1)
51 | nn.init.constant_(m.bias, 0)
52 |
53 | def _make_layer(self, block, planes, blocks, stride=1):
54 | downsample = None
55 | if stride != 1 or self.inplanes != planes * block.expansion:
56 | downsample = nn.Sequential(
57 | nn.Conv2d(self.inplanes, planes * block.expansion,
58 | kernel_size=1, stride=stride, bias=False),
59 | nn.BatchNorm2d(planes * block.expansion),
60 | )
61 |
62 | layers = []
63 | layers.append(block(self.inplanes, planes, stride, downsample))
64 | self.inplanes = planes * block.expansion
65 | for i in range(1, blocks):
66 | layers.append(block(self.inplanes, planes))
67 |
68 | return nn.Sequential(*layers)
69 |
70 | def new_parameter(self, *size):
71 | out = nn.Parameter(torch.FloatTensor(*size))
72 | nn.init.xavier_normal_(out)
73 | return out
74 |
75 | def forward(self, x):
76 |
77 | # with torch.no_grad():
78 | # x = self.torchfb(x) + 1e-6
79 | # if self.log_input: x = x.log()
80 | # x = self.instancenorm(x).unsqueeze(1)
81 |
82 | x = self.conv1(x)
83 | x = self.relu(x)
84 | x = self.bn1(x)
85 |
86 | x = self.layer1(x)
87 | x = self.layer2(x)
88 | x = self.layer3(x)
89 | x = self.layer4(x)
90 |
91 | x = x.reshape(x.size()[0], -1, x.size()[-1])
92 |
93 | w = self.attention(x)
94 |
95 | if self.encoder_type == "SAP":
96 | x = torch.sum(x * w, dim=2)
97 | elif self.encoder_type == "ASP":
98 | mu = torch.sum(x * w, dim=2)
99 | sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
100 | x = torch.cat((mu, sg), 1)
101 |
102 | x = x.view(x.size()[0], -1)
103 | x = self.fc(x)
104 |
105 | return x
106 |
107 |
108 |
109 |
110 | class SEBasicBlock(nn.Module):
111 | expansion = 1
112 |
113 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
114 | super(SEBasicBlock, self).__init__()
115 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
116 | self.bn1 = nn.BatchNorm2d(planes)
117 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
118 | self.bn2 = nn.BatchNorm2d(planes)
119 | self.relu = nn.ReLU(inplace=True)
120 | self.se = SELayer(planes, reduction)
121 | self.downsample = downsample
122 | self.stride = stride
123 |
124 | def forward(self, x):
125 | residual = x
126 |
127 | out = self.conv1(x)
128 | out = self.relu(out)
129 | out = self.bn1(out)
130 |
131 | out = self.conv2(out)
132 | out = self.bn2(out)
133 | out = self.se(out)
134 |
135 | if self.downsample is not None:
136 | residual = self.downsample(x)
137 |
138 | out += residual
139 | out = self.relu(out)
140 | return out
141 |
142 |
143 | class SEBottleneck(nn.Module):
144 | expansion = 4
145 |
146 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
147 | super(SEBottleneck, self).__init__()
148 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
149 | self.bn1 = nn.BatchNorm2d(planes)
150 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
151 | padding=1, bias=False)
152 | self.bn2 = nn.BatchNorm2d(planes)
153 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
154 | self.bn3 = nn.BatchNorm2d(planes * 4)
155 | self.relu = nn.ReLU(inplace=True)
156 | self.se = SELayer(planes * 4, reduction)
157 | self.downsample = downsample
158 | self.stride = stride
159 |
160 | def forward(self, x):
161 | residual = x
162 |
163 | out = self.conv1(x)
164 | out = self.bn1(out)
165 | out = self.relu(out)
166 |
167 | out = self.conv2(out)
168 | out = self.bn2(out)
169 | out = self.relu(out)
170 |
171 | out = self.conv3(out)
172 | out = self.bn3(out)
173 | out = self.se(out)
174 |
175 | if self.downsample is not None:
176 | residual = self.downsample(x)
177 |
178 | out += residual
179 | out = self.relu(out)
180 |
181 | return out
182 |
183 |
184 | class SELayer(nn.Module):
185 | def __init__(self, channel, reduction=8):
186 | super(SELayer, self).__init__()
187 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
188 | self.fc = nn.Sequential(
189 | nn.Linear(channel, channel // reduction),
190 | nn.ReLU(inplace=True),
191 | nn.Linear(channel // reduction, channel),
192 | nn.Sigmoid()
193 | )
194 |
195 | def forward(self, x):
196 | b, c, _, _ = x.size()
197 | y = self.avg_pool(x).view(b, c)
198 | y = self.fc(y).view(b, c, 1, 1)
199 | return x * y
200 |
--------------------------------------------------------------------------------
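
A shape sketch of ResNetSE as it is instantiated by ResSESyncEncoder in encoder.py (SEBasicBlock, layers [3, 4, 6, 3], filters [32, 64, 128, 256], nOut=512, n_mel_T=1): the three stride-2 stages reduce the default 80 mel bins to outmap_size = 10, and the attentive pooling collapses the time axis, so the embedding size is independent of the number of spectrogram frames:

    import torch
    from models.networks.audio_network import ResNetSE, SEBasicBlock

    model = ResNetSE(SEBasicBlock, [3, 4, 6, 3], [32, 64, 128, 256], nOut=512, n_mel_T=1)
    spec = torch.randn(2, 1, 80, 100)  # (batch, 1, n_mels, frames); frame count is arbitrary
    emb = model(spec)
    print(emb.shape)  # torch.Size([2, 512])
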
/models/networks/base_network.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn import init
3 |
4 |
5 | class BaseNetwork(nn.Module):
6 | def __init__(self):
7 | super(BaseNetwork, self).__init__()
8 |
9 | @staticmethod
10 | def modify_commandline_options(parser, is_train):
11 | return parser
12 |
13 | def print_network(self):
14 | if isinstance(self, list):
15 | self = self[0]
16 | num_params = 0
17 | for param in self.parameters():
18 | num_params += param.numel()
19 | print('Network [%s] was created. Total number of parameters: %.1f million. '
20 | 'To see the architecture, do print(network).'
21 | % (type(self).__name__, num_params / 1000000))
22 |
23 | def init_weights(self, init_type='normal', gain=0.02):
24 | def init_func(m):
25 | classname = m.__class__.__name__
26 | if classname.find('BatchNorm2d') != -1:
27 | if hasattr(m, 'weight') and m.weight is not None:
28 | init.normal_(m.weight.data, 1.0, gain)
29 | if hasattr(m, 'bias') and m.bias is not None:
30 | init.constant_(m.bias.data, 0.0)
31 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
32 | if init_type == 'normal':
33 | init.normal_(m.weight.data, 0.0, gain)
34 | elif init_type == 'xavier':
35 | init.xavier_normal_(m.weight.data, gain=gain)
36 | elif init_type == 'xavier_uniform':
37 | init.xavier_uniform_(m.weight.data, gain=1.0)
38 | elif init_type == 'kaiming':
39 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
40 | elif init_type == 'orthogonal':
41 | init.orthogonal_(m.weight.data, gain=gain)
42 | elif init_type == 'none': # uses pytorch's default init method
43 | m.reset_parameters()
44 | else:
45 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
46 | if hasattr(m, 'bias') and m.bias is not None:
47 | init.constant_(m.bias.data, 0.0)
48 |
49 | self.apply(init_func)
50 |
51 | # propagate to children
52 | for m in self.children():
53 | if hasattr(m, 'init_weights'):
54 | m.init_weights(init_type, gain)
55 |
--------------------------------------------------------------------------------
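
A minimal subclass sketch (the network itself is hypothetical) showing how the helpers above are meant to be used:

    import torch.nn as nn
    from models.networks.base_network import BaseNetwork

    class TinyNet(BaseNetwork):
        def __init__(self):
            super(TinyNet, self).__init__()
            self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv(x)

    net = TinyNet()
    net.print_network()                    # reports the parameter count in millions
    net.init_weights('xavier', gain=0.02)  # Xavier-initializes conv/linear weights
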
/models/networks/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | from models.networks.base_network import BaseNetwork
4 | import util.util as util
5 | import torch
6 | from models.networks.architecture import get_nonspade_norm_layer
7 | import torch.nn.functional as F
8 |
9 |
10 | class MultiscaleDiscriminator(BaseNetwork):
11 | @staticmethod
12 | def modify_commandline_options(parser, is_train):
13 | parser.add_argument('--netD_subarch', type=str, default='n_layer',
14 | help='architecture of each discriminator')
15 | parser.add_argument('--num_D', type=int, default=2,
16 | help='number of discriminators to be used in multiscale')
17 | opt, _ = parser.parse_known_args()
18 |
19 | # define properties of each discriminator of the multiscale discriminator
20 | subnetD = util.find_class_in_module(opt.netD_subarch + 'discriminator',
21 | 'models.networks.discriminator')
22 | subnetD.modify_commandline_options(parser, is_train)
23 |
24 | return parser
25 |
26 | def __init__(self, opt):
27 | super(MultiscaleDiscriminator, self).__init__()
28 | self.opt = opt
29 |
30 | for i in range(opt.num_D):
31 | subnetD = self.create_single_discriminator(opt)
32 | self.add_module('discriminator_%d' % i, subnetD)
33 |
34 | def create_single_discriminator(self, opt):
35 | subarch = opt.netD_subarch
36 | if subarch == 'n_layer':
37 | netD = NLayerDiscriminator(opt)
38 | else:
39 | raise ValueError('unrecognized discriminator subarchitecture %s' % subarch)
40 | return netD
41 |
42 | def downsample(self, input):
43 | return F.avg_pool2d(input, kernel_size=3,
44 | stride=2, padding=[1, 1],
45 | count_include_pad=False)
46 |
47 | # Returns list of lists of discriminator outputs.
48 | # The final result is of size opt.num_D x opt.n_layers_D
49 | def forward(self, input):
50 | result = []
51 | get_intermediate_features = not self.opt.no_ganFeat_loss
52 | for name, D in self.named_children():
53 | out = D(input)
54 | if not get_intermediate_features:
55 | out = [out]
56 | result.append(out)
57 | input = self.downsample(input)
58 |
59 | return result
60 |
61 |
62 | # Defines the PatchGAN discriminator with the specified arguments.
63 | class NLayerDiscriminator(BaseNetwork):
64 | @staticmethod
65 | def modify_commandline_options(parser, is_train):
66 | parser.add_argument('--n_layers_D', type=int, default=4,
67 | help='# layers in each discriminator')
68 | return parser
69 |
70 | def __init__(self, opt):
71 |
72 | super(NLayerDiscriminator, self).__init__()
73 | self.opt = opt
74 |
75 | kw = 4
76 | padw = int(np.ceil((kw - 1.0) / 2))
77 | nf = opt.ndf
78 | input_nc = self.compute_D_input_nc(opt)
79 |
80 | norm_layer = get_nonspade_norm_layer(opt, opt.norm_D)
81 | sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
82 | nn.LeakyReLU(0.2, False)]]
83 |
84 | for n in range(1, opt.n_layers_D):
85 | nf_prev = nf
86 | nf = min(nf * 2, 512)
87 | stride = 1 if n == opt.n_layers_D - 1 else 2
88 | sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw,
89 | stride=stride, padding=padw)),
90 | nn.LeakyReLU(0.2, False)
91 | ]]
92 |
93 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
94 |
95 | # We divide the layers into groups to extract intermediate layer outputs
96 | for n in range(len(sequence)):
97 | self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
98 |
99 | def compute_D_input_nc(self, opt):
100 | if opt.D_input == "concat":
101 | input_nc = opt.label_nc + opt.output_nc
102 | if opt.contain_dontcare_label:
103 | input_nc += 1
104 | if not opt.no_instance:
105 | input_nc += 1
106 | else:
107 | input_nc = 3
108 | return input_nc
109 |
110 | def forward(self, input):
111 | results = [input]
112 | for submodel in self.children():
113 |
114 | # intermediate_output = checkpoint(submodel, results[-1])
115 | intermediate_output = submodel(results[-1])
116 | results.append(intermediate_output)
117 |
118 | get_intermediate_features = not self.opt.no_ganFeat_loss
119 | if get_intermediate_features:
120 | return results[0:]
121 | else:
122 | return results[-1]
123 |
124 |
125 | class AudioSubDiscriminator(BaseNetwork):
126 | def __init__(self, opt, nc, audio_nc):
127 | super(AudioSubDiscriminator, self).__init__()
128 | norm_layer = get_nonspade_norm_layer(opt, opt.norm_D)
129 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
130 | sequence = []
131 | sequence += [norm_layer(nn.Conv1d(nc, nc, 3, 2, 1)),
132 | nn.ReLU()
133 | ]
134 | sequence += [norm_layer(nn.Conv1d(nc, audio_nc, 3, 2, 1)),
135 | nn.ReLU()
136 | ]
137 |
138 | self.conv = nn.Sequential(*sequence)
139 | self.cosine = nn.CosineSimilarity()
140 | self.mapping = nn.Linear(audio_nc, audio_nc)
141 |
142 | def forward(self, result, audio):
143 | region = result[result.shape[3] // 2:result.shape[3] - 2, result.shape[4] // 3: 2 * result.shape[4] // 3]
144 | visual = self.avgpool(region)
145 | cos = self.cosine(visual, self.mapping(audio))
146 | return cos
147 |
148 |
149 | class ImageDiscriminator(BaseNetwork):
150 | """Defines a PatchGAN discriminator"""
151 | def modify_commandline_options(parser, is_train):
152 | parser.add_argument('--n_layers_D', type=int, default=4,
153 | help='# layers in each discriminator')
154 | return parser
155 |
156 | def __init__(self, opt, n_layers=3, norm_layer=nn.BatchNorm2d):
157 | """Construct a PatchGAN discriminator
158 | Parameters:
159 | input_nc (int) -- the number of channels in input images
160 | ndf (int) -- the number of filters in the last conv layer
161 | n_layers (int) -- the number of conv layers in the discriminator
162 | norm_layer -- normalization layer
163 | """
164 | super(ImageDiscriminator, self).__init__()
165 | use_bias = norm_layer == nn.InstanceNorm2d
166 | if opt.D_input == "concat":
167 | input_nc = opt.label_nc + opt.output_nc
168 | else:
169 | input_nc = opt.label_nc
170 | ndf = 64
171 | kw = 4
172 | padw = 1
173 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
174 | nf_mult = 1
175 | nf_mult_prev = 1
176 | for n in range(1, n_layers): # gradually increase the number of filters
177 | nf_mult_prev = nf_mult
178 | nf_mult = min(2 ** n, 8)
179 | sequence += [
180 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
181 | norm_layer(ndf * nf_mult),
182 | nn.LeakyReLU(0.2, True)
183 | ]
184 |
185 | nf_mult_prev = nf_mult
186 | nf_mult = min(2 ** n_layers, 8)
187 | sequence += [
188 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
189 | norm_layer(ndf * nf_mult),
190 | nn.LeakyReLU(0.2, True)
191 | ]
192 |
193 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
194 | self.model = nn.Sequential(*sequence)
195 |
196 | def forward(self, input):
197 | """Standard forward."""
198 | return self.model(input)
199 |
200 |
201 | class FeatureDiscriminator(BaseNetwork):
202 | def __init__(self, opt):
203 | super(FeatureDiscriminator, self).__init__()
204 | self.opt = opt
205 | self.fc = nn.Linear(512, opt.num_labels)
206 | self.dropout = nn.Dropout(0.5)
207 |
208 | def forward(self, x):
209 | x0 = x.view(-1, 512)
210 | net = self.dropout(x0)
211 | net = self.fc(net)
212 | return net
213 |
214 |
215 |
--------------------------------------------------------------------------------
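
A rough instantiation sketch for MultiscaleDiscriminator; every option value below is a hypothetical placeholder (the real defaults live under options/), chosen only to satisfy the attributes the constructors and forward passes read:

    import torch
    from argparse import Namespace
    from models.networks.discriminator import MultiscaleDiscriminator

    opt = Namespace(num_D=2, netD_subarch='n_layer', n_layers_D=4, ndf=64,
                    norm_D='spectralinstance', D_input='single',
                    no_ganFeat_loss=False, label_nc=3, output_nc=3,
                    contain_dontcare_label=False, no_instance=True)
    netD = MultiscaleDiscriminator(opt)
    preds = netD(torch.randn(1, 3, 256, 256))
    # preds has length num_D; each entry is the list of activations from one scale
    # (input included), ending with that scale's patch prediction map.
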
/models/networks/encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | import torch.nn.functional as F
4 | from models.networks.base_network import BaseNetwork
5 | import torchvision.models.mobilenet
6 | from util import util
7 | from models.networks.audio_network import ResNetSE, SEBasicBlock
8 | import torch
9 | from models.networks.FAN_feature_extractor import FAN_use
10 | from torchvision.models.vgg import vgg19_bn
11 | from models.networks.vision_network import ResNeXt50
12 |
13 |
14 | class ResSEAudioEncoder(BaseNetwork):
15 | def __init__(self, opt, nOut=2048, n_mel_T=None):
16 | super(ResSEAudioEncoder, self).__init__()
17 | self.nOut = nOut
18 | # Number of filters
19 | num_filters = [32, 64, 128, 256]
20 |         if n_mel_T is None:  # used when the encoder is built for audio identity
21 | n_mel_T = opt.n_mel_T
22 | self.model = ResNetSE(SEBasicBlock, [3, 4, 6, 3], num_filters, self.nOut, n_mel_T=n_mel_T)
23 | self.fc = nn.Linear(self.nOut, opt.num_classes)
24 |
25 | def forward_feature(self, x):
26 |
27 | input_size = x.size()
28 | if len(input_size) == 5:
29 | bz, clip_len, c, f, t = input_size
30 | x = x.view(bz * clip_len, c, f, t)
31 | out = self.model(x)
32 | return out
33 |
34 | def forward(self, x):
35 | out = self.forward_feature(x)
36 | score = self.fc(out)
37 | return out, score
38 |
39 |
40 | class ResSESyncEncoder(ResSEAudioEncoder):
41 | def __init__(self, opt):
42 | super(ResSESyncEncoder, self).__init__(opt, nOut=512, n_mel_T=1)
43 |
44 |
45 | class ResNeXtEncoder(ResNeXt50):
46 | def __init__(self, opt):
47 | super(ResNeXtEncoder, self).__init__(opt)
48 |
49 |
50 | class VGGEncoder(BaseNetwork):
51 | def __init__(self, opt):
52 | super(VGGEncoder, self).__init__()
53 | self.model = vgg19_bn(num_classes=opt.num_classes)
54 |
55 | def forward(self, x):
56 | return self.model(x)
57 |
58 |
59 | class FanEncoder(BaseNetwork):
60 | def __init__(self, opt):
61 | super(FanEncoder, self).__init__()
62 | self.opt = opt
63 | pose_dim = self.opt.pose_dim
64 | self.model = FAN_use()
65 | self.classifier = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, opt.num_classes))
66 |
67 | # mapper to mouth subspace
68 | self.to_mouth = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
69 | self.mouth_embed = nn.Sequential(nn.ReLU(), nn.Linear(512, 512-pose_dim))
70 | self.mouth_fc = nn.Sequential(nn.ReLU(), nn.Linear(512*opt.clip_len, opt.num_classes))
71 |
72 | # mapper to head pose subspace
73 | self.to_headpose = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
74 | self.headpose_embed = nn.Sequential(nn.ReLU(), nn.Linear(512, pose_dim))
75 | self.headpose_fc = nn.Sequential(nn.ReLU(), nn.Linear(pose_dim*opt.clip_len, opt.num_classes))
76 |
77 | def load_pretrain(self):
78 | check_point = torch.load(self.opt.FAN_pretrain_path)
79 | print("=> loading checkpoint '{}'".format(self.opt.FAN_pretrain_path))
80 | util.copy_state_dict(check_point, self.model)
81 |
82 | def forward_feature(self, x):
83 | net = self.model(x)
84 | return net
85 |
86 | def forward(self, x):
87 | x0 = x.view(-1, self.opt.output_nc, self.opt.crop_size, self.opt.crop_size)
88 | net = self.forward_feature(x0)
89 | scores = self.classifier(net.view(-1, self.opt.num_clips, 512).mean(1))
90 | return net, scores
91 |
--------------------------------------------------------------------------------
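
A shape sketch for FanEncoder's two heads: the shared FAN_use backbone yields a 512-d feature, which to_mouth/mouth_embed and to_headpose/headpose_embed project into complementary (512 - pose_dim)- and pose_dim-dimensional subspaces. The option values below are hypothetical placeholders for whatever options/*.py actually defines:

    import torch
    from argparse import Namespace
    from models.networks.encoder import FanEncoder

    opt = Namespace(pose_dim=12, num_classes=1000, clip_len=1, num_clips=1,
                    output_nc=3, crop_size=224, FAN_pretrain_path='')
    enc = FanEncoder(opt).eval()
    with torch.no_grad():
        feat = enc.model(torch.randn(2, 3, 224, 224))     # (2, 512) FAN feature
        mouth = enc.mouth_embed(enc.to_mouth(feat))       # (2, 512 - pose_dim)
        pose = enc.headpose_embed(enc.to_headpose(feat))  # (2, pose_dim)
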
/models/networks/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models.networks.architecture import VGG19, VGGFace19
5 |
6 |
7 | # Defines the GAN loss which uses either LSGAN or the regular GAN.
8 | # When LSGAN is used, it is basically same as MSELoss,
9 | # but it abstracts away the need to create the target label tensor
10 | # that has the same size as the input
11 | class GANLoss(nn.Module):
12 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
13 | tensor=torch.FloatTensor, opt=None):
14 | super(GANLoss, self).__init__()
15 | self.real_label = target_real_label
16 | self.fake_label = target_fake_label
17 | self.real_label_tensor = None
18 | self.fake_label_tensor = None
19 | self.zero_tensor = None
20 | self.Tensor = tensor
21 | self.gan_mode = gan_mode
22 | self.opt = opt
23 | if gan_mode == 'ls':
24 | pass
25 | elif gan_mode == 'original':
26 | pass
27 | elif gan_mode == 'w':
28 | pass
29 | elif gan_mode == 'hinge':
30 | pass
31 | else:
32 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
33 |
34 | def get_target_tensor(self, input, target_is_real):
35 | if target_is_real:
36 | if self.real_label_tensor is None:
37 | self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
38 | self.real_label_tensor.requires_grad_(False)
39 | return self.real_label_tensor.expand_as(input)
40 | else:
41 | if self.fake_label_tensor is None:
42 | self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
43 | self.fake_label_tensor.requires_grad_(False)
44 | return self.fake_label_tensor.expand_as(input)
45 |
46 | def get_zero_tensor(self, input):
47 | if self.zero_tensor is None:
48 | self.zero_tensor = self.Tensor(1).fill_(0)
49 | self.zero_tensor.requires_grad_(False)
50 | return self.zero_tensor.expand_as(input)
51 |
52 | def loss(self, input, target_is_real, for_discriminator=True):
53 | if self.gan_mode == 'original': # cross entropy loss
54 | target_tensor = self.get_target_tensor(input, target_is_real)
55 | loss = F.binary_cross_entropy_with_logits(input, target_tensor)
56 | return loss
57 | elif self.gan_mode == 'ls':
58 | target_tensor = self.get_target_tensor(input, target_is_real)
59 | return F.mse_loss(input, target_tensor)
60 | elif self.gan_mode == 'hinge':
61 | if for_discriminator:
62 | if target_is_real:
63 | minval = torch.min(input - 1, self.get_zero_tensor(input))
64 | loss = -torch.mean(minval)
65 | else:
66 | minval = torch.min(-input - 1, self.get_zero_tensor(input))
67 | loss = -torch.mean(minval)
68 | else:
69 | assert target_is_real, "The generator's hinge loss must be aiming for real"
70 | loss = -torch.mean(input)
71 | return loss
72 | else:
73 | # wgan
74 | if target_is_real:
75 | return -input.mean()
76 | else:
77 | return input.mean()
78 |
79 | def __call__(self, input, target_is_real, for_discriminator=True):
80 | # computing loss is a bit complicated because |input| may not be
81 | # a tensor, but list of tensors in case of multiscale discriminator
82 | if isinstance(input, list):
83 | loss = 0
84 | for pred_i in input:
85 | if isinstance(pred_i, list):
86 | pred_i = pred_i[-1]
87 | loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
88 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
89 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
90 | loss += new_loss
91 | return loss / len(input)
92 | else:
93 | return self.loss(input, target_is_real, for_discriminator)
94 |
95 |
96 | # Perceptual loss that uses a pretrained VGG network
97 | class VGGLoss(nn.Module):
98 | def __init__(self, opt, vgg=VGG19()):
99 | super(VGGLoss, self).__init__()
100 | self.vgg = vgg.cuda()
101 | self.criterion = nn.L1Loss()
102 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
103 |
104 | def forward(self, x, y, layer=0):
105 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
106 | loss = 0
107 | for i in range(len(x_vgg)):
108 | if i >= layer:
109 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
110 | return loss
111 |
112 |
113 | # KL Divergence loss used in VAE with an image encoder
114 | class KLDLoss(nn.Module):
115 | def forward(self, mu, logvar):
116 | return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
117 |
118 |
119 | class CrossEntropyLoss(nn.Module):
120 | """Cross Entropy Loss
121 |
122 | It will calculate cross_entropy loss given cls_score and label.
123 | """
124 |
125 | def forward(self, cls_score, label):
126 | loss_cls = F.cross_entropy(cls_score, label)
127 | return loss_cls
128 |
129 |
130 | class SumLogSoftmaxLoss(nn.Module):
131 |
132 | def forward(self, x):
133 | out = F.log_softmax(x, dim=1)
134 | loss = - torch.mean(out) + torch.mean(F.log_softmax(torch.ones_like(out), dim=1) )
135 | return loss
136 |
137 |
138 | class L2SoftmaxLoss(nn.Module):
139 | def __init__(self):
140 | super(L2SoftmaxLoss, self).__init__()
141 |         self.softmax = nn.Softmax(dim=1)
142 | self.L2loss = nn.MSELoss()
143 | self.label = None
144 |
145 | def forward(self, x):
146 | out = self.softmax(x)
147 | self.label = (torch.ones(out.size()).float() * (1 / x.size(1))).cuda()
148 | loss = self.L2loss(out, self.label)
149 | return loss
150 |
151 |
152 | class SoftmaxContrastiveLoss(nn.Module):
153 | def __init__(self):
154 | super(SoftmaxContrastiveLoss, self).__init__()
155 | self.cross_ent = nn.CrossEntropyLoss()
156 |
157 | def l2_norm(self, x):
158 | x_norm = F.normalize(x, p=2, dim=1)
159 | return x_norm
160 |
161 | def l2_sim(self, feature1, feature2):
162 | Feature = feature1.expand(feature1.size(0), feature1.size(0), feature1.size(1)).transpose(0, 1)
163 | return torch.norm(Feature - feature2, p=2, dim=2)
164 |
165 | @torch.no_grad()
166 | def evaluate(self, face_feat, audio_feat, mode='max'):
167 |         assert mode in ('max', 'confusion'), '{} must be in max or confusion'.format(mode)
168 | face_feat = self.l2_norm(face_feat)
169 | audio_feat = self.l2_norm(audio_feat)
170 | cross_dist = 1.0 / self.l2_sim(face_feat, audio_feat)
171 |
172 | print(cross_dist)
173 | if mode == 'max':
174 | label = torch.arange(face_feat.size(0)).to(cross_dist.device)
175 | max_idx = torch.argmax(cross_dist, dim=1)
176 | # print(max_idx, label)
177 | acc = torch.sum(label == max_idx) * 1.0 / label.size(0)
178 | else:
179 | raise ValueError
180 |
181 | return acc
182 |
183 | def forward(self, face_feat, audio_feat, mode='max'):
184 |         assert mode in ('max', 'confusion'), '{} must be in max or confusion'.format(mode)
185 |
186 | face_feat = self.l2_norm(face_feat)
187 | audio_feat = self.l2_norm(audio_feat)
188 |
189 | cross_dist = 1.0 / self.l2_sim(face_feat, audio_feat)
190 |
191 | if mode == 'max':
192 | label = torch.arange(face_feat.size(0)).to(cross_dist.device)
193 | loss = F.cross_entropy(cross_dist, label)
194 | else:
195 | raise ValueError
196 | return loss
197 |
--------------------------------------------------------------------------------
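
A minimal CPU sketch of GANLoss in hinge mode; pred stands in for one scale of the discriminator output (the __call__ wrapper also accepts the list-of-lists structure returned by MultiscaleDiscriminator):

    import torch
    from models.networks.loss import GANLoss

    criterion = GANLoss('hinge', tensor=torch.FloatTensor)
    pred = torch.randn(4, 1, 30, 30)  # a patch prediction map
    d_fake = criterion(pred, target_is_real=False, for_discriminator=True)
    d_real = criterion(pred, target_is_real=True, for_discriminator=True)
    g_loss = criterion(pred, target_is_real=True, for_discriminator=False)
    print(d_fake.item(), d_real.item(), g_loss.item())
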
/models/networks/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
2 | from .batchnorm import patch_sync_batchnorm, convert_model
3 | from .replicate import DataParallelWithCallback, patch_replication_callback
4 |
--------------------------------------------------------------------------------
/models/networks/sync_batchnorm/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/sync_batchnorm/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/sync_batchnorm/__pycache__/comm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/comm.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/sync_batchnorm/__pycache__/comm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/comm.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/sync_batchnorm/__pycache__/replicate.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/replicate.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/sync_batchnorm/__pycache__/replicate.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/replicate.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/sync_batchnorm/__pycache__/scatter_gather.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/scatter_gather.cpython-36.pyc
--------------------------------------------------------------------------------
/models/networks/sync_batchnorm/__pycache__/scatter_gather.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/scatter_gather.cpython-37.pyc
--------------------------------------------------------------------------------
/models/networks/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import contextlib
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from torch.nn.modules.batchnorm import _BatchNorm
8 |
9 | try:
10 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
11 | except ImportError:
12 | ReduceAddCoalesced = Broadcast = None
13 |
14 | try:
15 | from jactorch.parallel.comm import SyncMaster
16 | from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
17 | except ImportError:
18 | from .comm import SyncMaster
19 | from .replicate import DataParallelWithCallback
20 |
21 | __all__ = [
22 | 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
23 | 'patch_sync_batchnorm', 'convert_model'
24 | ]
25 |
26 |
27 | def _sum_ft(tensor):
28 |     """sum over the first and last dimension"""
29 | return tensor.sum(dim=0).sum(dim=-1)
30 |
31 |
32 | def _unsqueeze_ft(tensor):
33 | """add new dimensions at the front and the tail"""
34 | return tensor.unsqueeze(0).unsqueeze(-1)
35 |
36 |
37 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
38 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
39 |
40 |
41 | class _SynchronizedBatchNorm(_BatchNorm):
42 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
43 | assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
44 |
45 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
46 |
47 | self._sync_master = SyncMaster(self._data_parallel_master)
48 |
49 | self._is_parallel = False
50 | self._parallel_id = None
51 | self._slave_pipe = None
52 |
53 | def forward(self, input):
54 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
55 | if not (self._is_parallel and self.training):
56 | return F.batch_norm(
57 | input, self.running_mean, self.running_var, self.weight, self.bias,
58 | self.training, self.momentum, self.eps)
59 |
60 | # Resize the input to (B, C, -1).
61 | input_shape = input.size()
62 | input = input.view(input.size(0), self.num_features, -1)
63 |
64 | # Compute the sum and square-sum.
65 | sum_size = input.size(0) * input.size(2)
66 | input_sum = _sum_ft(input)
67 | input_ssum = _sum_ft(input ** 2)
68 |
69 | # Reduce-and-broadcast the statistics.
70 | if self._parallel_id == 0:
71 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
72 | else:
73 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
74 |
75 | # Compute the output.
76 | if self.affine:
77 | # MJY:: Fuse the multiplication for speed.
78 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
79 | else:
80 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
81 |
82 | # Reshape it.
83 | return output.view(input_shape)
84 |
85 | def __data_parallel_replicate__(self, ctx, copy_id):
86 | self._is_parallel = True
87 | self._parallel_id = copy_id
88 |
89 | # parallel_id == 0 means master device.
90 | if self._parallel_id == 0:
91 | ctx.sync_master = self._sync_master
92 | else:
93 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
94 |
95 | def _data_parallel_master(self, intermediates):
96 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
97 |
98 | # Always using same "device order" makes the ReduceAdd operation faster.
99 | # Thanks to:: Tete Xiao (http://tetexiao.com/)
100 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
101 |
102 | to_reduce = [i[1][:2] for i in intermediates]
103 | to_reduce = [j for i in to_reduce for j in i] # flatten
104 | target_gpus = [i[1].sum.get_device() for i in intermediates]
105 |
106 | sum_size = sum([i[1].sum_size for i in intermediates])
107 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
108 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
109 |
110 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
111 |
112 | outputs = []
113 | for i, rec in enumerate(intermediates):
114 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
115 |
116 | return outputs
117 |
118 | def _compute_mean_std(self, sum_, ssum, size):
119 | """Compute the mean and standard-deviation with sum and square-sum. This method
120 | also maintains the moving average on the master device."""
121 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
122 | mean = sum_ / size
123 | sumvar = ssum - sum_ * mean
124 | unbias_var = sumvar / (size - 1)
125 | bias_var = sumvar / size
126 |
127 | if hasattr(torch, 'no_grad'):
128 | with torch.no_grad():
129 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
130 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
131 | else:
132 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
133 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
134 |
135 | return mean, bias_var.clamp(self.eps) ** -0.5
136 |
137 |
138 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
139 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
140 | mini-batch.
141 |
142 | .. math::
143 |
144 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
145 |
146 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
147 | standard-deviation are reduced across all devices during training.
148 |
149 | For example, when one uses `nn.DataParallel` to wrap the network during
150 |     training, PyTorch's implementation normalizes the tensor on each device using
151 |     the statistics only on that device, which accelerates the computation and
152 | is also easy to implement, but the statistics might be inaccurate.
153 | Instead, in this synchronized version, the statistics will be computed
154 | over all training samples distributed on multiple devices.
155 |
156 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
157 | as the built-in PyTorch implementation.
158 |
159 | The mean and standard-deviation are calculated per-dimension over
160 | the mini-batches and gamma and beta are learnable parameter vectors
161 | of size C (where C is the input size).
162 |
163 | During training, this layer keeps a running estimate of its computed mean
164 | and variance. The running sum is kept with a default momentum of 0.1.
165 |
166 | During evaluation, this running mean/variance is used for normalization.
167 |
168 | Because the BatchNorm is done over the `C` dimension, computing statistics
169 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
170 |
171 | Args:
172 | num_features: num_features from an expected input of size
173 | `batch_size x num_features [x width]`
174 | eps: a value added to the denominator for numerical stability.
175 | Default: 1e-5
176 | momentum: the value used for the running_mean and running_var
177 | computation. Default: 0.1
178 | affine: a boolean value that when set to ``True``, gives the layer learnable
179 | affine parameters. Default: ``True``
180 |
181 | Shape::
182 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
183 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
184 |
185 | Examples:
186 | >>> # With Learnable Parameters
187 | >>> m = SynchronizedBatchNorm1d(100)
188 | >>> # Without Learnable Parameters
189 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
190 | >>> input = torch.autograd.Variable(torch.randn(20, 100))
191 | >>> output = m(input)
192 | """
193 |
194 | def _check_input_dim(self, input):
195 | if input.dim() != 2 and input.dim() != 3:
196 | raise ValueError('expected 2D or 3D input (got {}D input)'
197 | .format(input.dim()))
198 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
199 |
200 |
201 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
202 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
203 | of 3d inputs
204 |
205 | .. math::
206 |
207 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
208 |
209 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
210 | standard-deviation are reduced across all devices during training.
211 |
212 | For example, when one uses `nn.DataParallel` to wrap the network during
213 |     training, PyTorch's implementation normalizes the tensor on each device using
214 |     the statistics only on that device, which accelerates the computation and
215 | is also easy to implement, but the statistics might be inaccurate.
216 | Instead, in this synchronized version, the statistics will be computed
217 | over all training samples distributed on multiple devices.
218 |
219 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
220 | as the built-in PyTorch implementation.
221 |
222 | The mean and standard-deviation are calculated per-dimension over
223 | the mini-batches and gamma and beta are learnable parameter vectors
224 | of size C (where C is the input size).
225 |
226 | During training, this layer keeps a running estimate of its computed mean
227 | and variance. The running sum is kept with a default momentum of 0.1.
228 |
229 | During evaluation, this running mean/variance is used for normalization.
230 |
231 | Because the BatchNorm is done over the `C` dimension, computing statistics
232 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
233 |
234 | Args:
235 | num_features: num_features from an expected input of
236 | size batch_size x num_features x height x width
237 | eps: a value added to the denominator for numerical stability.
238 | Default: 1e-5
239 | momentum: the value used for the running_mean and running_var
240 | computation. Default: 0.1
241 | affine: a boolean value that when set to ``True``, gives the layer learnable
242 | affine parameters. Default: ``True``
243 |
244 | Shape::
245 | - Input: :math:`(N, C, H, W)`
246 | - Output: :math:`(N, C, H, W)` (same shape as input)
247 |
248 | Examples:
249 | >>> # With Learnable Parameters
250 | >>> m = SynchronizedBatchNorm2d(100)
251 | >>> # Without Learnable Parameters
252 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
253 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
254 | >>> output = m(input)
255 | """
256 |
257 | def _check_input_dim(self, input):
258 | if input.dim() != 4:
259 | raise ValueError('expected 4D input (got {}D input)'
260 | .format(input.dim()))
261 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
262 |
263 |
264 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
265 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
266 | of 4d inputs
267 |
268 | .. math::
269 |
270 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
271 |
272 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
273 | standard-deviation are reduced across all devices during training.
274 |
275 | For example, when one uses `nn.DataParallel` to wrap the network during
276 |     training, PyTorch's implementation normalizes the tensor on each device using
277 |     the statistics only on that device, which accelerates the computation and
278 | is also easy to implement, but the statistics might be inaccurate.
279 | Instead, in this synchronized version, the statistics will be computed
280 | over all training samples distributed on multiple devices.
281 |
282 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
283 | as the built-in PyTorch implementation.
284 |
285 | The mean and standard-deviation are calculated per-dimension over
286 | the mini-batches and gamma and beta are learnable parameter vectors
287 | of size C (where C is the input size).
288 |
289 | During training, this layer keeps a running estimate of its computed mean
290 | and variance. The running sum is kept with a default momentum of 0.1.
291 |
292 | During evaluation, this running mean/variance is used for normalization.
293 |
294 | Because the BatchNorm is done over the `C` dimension, computing statistics
295 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
296 | or Spatio-temporal BatchNorm
297 |
298 | Args:
299 | num_features: num_features from an expected input of
300 | size batch_size x num_features x depth x height x width
301 | eps: a value added to the denominator for numerical stability.
302 | Default: 1e-5
303 | momentum: the value used for the running_mean and running_var
304 | computation. Default: 0.1
305 | affine: a boolean value that when set to ``True``, gives the layer learnable
306 | affine parameters. Default: ``True``
307 |
308 | Shape::
309 | - Input: :math:`(N, C, D, H, W)`
310 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
311 |
312 | Examples:
313 | >>> # With Learnable Parameters
314 | >>> m = SynchronizedBatchNorm3d(100)
315 | >>> # Without Learnable Parameters
316 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
317 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
318 | >>> output = m(input)
319 | """
320 |
321 | def _check_input_dim(self, input):
322 | if input.dim() != 5:
323 | raise ValueError('expected 5D input (got {}D input)'
324 | .format(input.dim()))
325 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
326 |
327 |
328 | @contextlib.contextmanager
329 | def patch_sync_batchnorm():
330 | import torch.nn as nn
331 |
332 | backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
333 |
334 | nn.BatchNorm1d = SynchronizedBatchNorm1d
335 | nn.BatchNorm2d = SynchronizedBatchNorm2d
336 | nn.BatchNorm3d = SynchronizedBatchNorm3d
337 |
338 | yield
339 |
340 | nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
341 |
342 |
343 | def convert_model(module):
344 | """Traverse the input module and its child recursively
345 | and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
346 | to SynchronizedBatchNorm*N*d
347 |
348 | Args:
349 | module: the input module needs to be convert to SyncBN model
350 |
351 | Examples:
352 | >>> import torch.nn as nn
353 | >>> import torchvision
354 | >>> # m is a standard pytorch model
355 | >>> m = torchvision.models.resnet18(True)
356 | >>> m = nn.DataParallel(m)
357 | >>> # after convert, m is using SyncBN
358 | >>> m = convert_model(m)
359 | """
360 | if isinstance(module, torch.nn.DataParallel):
361 | mod = module.module
362 | mod = convert_model(mod)
363 | mod = DataParallelWithCallback(mod)
364 | return mod
365 |
366 | mod = module
367 | for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
368 | torch.nn.modules.batchnorm.BatchNorm2d,
369 | torch.nn.modules.batchnorm.BatchNorm3d],
370 | [SynchronizedBatchNorm1d,
371 | SynchronizedBatchNorm2d,
372 | SynchronizedBatchNorm3d]):
373 | if isinstance(module, pth_module):
374 | mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
375 | mod.running_mean = module.running_mean
376 | mod.running_var = module.running_var
377 | if module.affine:
378 | mod.weight.data = module.weight.data.clone().detach()
379 | mod.bias.data = module.bias.data.clone().detach()
380 |
381 | for name, child in module.named_children():
382 | mod.add_module(name, convert_model(child))
383 |
384 | return mod
385 |
--------------------------------------------------------------------------------
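`convert_model` walks a module tree, swaps every `torch.nn` BatchNorm layer for its synchronized counterpart, and copies the running statistics and affine parameters over. A rough usage sketch (run from the repository root, assuming the parent packages import cleanly; single-process execution is fine because `_SynchronizedBatchNorm` falls back to `F.batch_norm` when it is not replicated):

    import torch
    import torch.nn as nn
    from models.networks.sync_batchnorm import convert_model, SynchronizedBatchNorm2d

    # an arbitrary toy model, not part of this repository
    model = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
    )
    model = convert_model(model)            # BatchNorm2d -> SynchronizedBatchNorm2d
    assert isinstance(model[1], SynchronizedBatchNorm2d)

    out = model(torch.randn(4, 3, 32, 32))  # non-parallel forward uses the F.batch_norm fallback
    print(out.shape)

For multi-GPU training the converted model would additionally be wrapped in `DataParallelWithCallback` (or `patch_replication_callback` applied to an existing `DataParallel`), as the docstrings above describe.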
/models/networks/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 |
5 | __all__ = ['BatchNorm2dReimpl']
6 |
7 |
8 | class BatchNorm2dReimpl(nn.Module):
9 | """
10 | A re-implementation of batch normalization, used for testing the numerical
11 | stability.
12 |
13 | Author: acgtyrant
14 | See also:
15 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
16 | """
17 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
18 | super().__init__()
19 |
20 | self.num_features = num_features
21 | self.eps = eps
22 | self.momentum = momentum
23 | self.weight = nn.Parameter(torch.empty(num_features))
24 | self.bias = nn.Parameter(torch.empty(num_features))
25 | self.register_buffer('running_mean', torch.zeros(num_features))
26 | self.register_buffer('running_var', torch.ones(num_features))
27 | self.reset_parameters()
28 |
29 | def reset_running_stats(self):
30 | self.running_mean.zero_()
31 | self.running_var.fill_(1)
32 |
33 | def reset_parameters(self):
34 | self.reset_running_stats()
35 | init.uniform_(self.weight)
36 | init.zeros_(self.bias)
37 |
38 | def forward(self, input_):
39 | batchsize, channels, height, width = input_.size()
40 | numel = batchsize * height * width
41 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
42 | sum_ = input_.sum(1)
43 | sum_of_square = input_.pow(2).sum(1)
44 | mean = sum_ / numel
45 | sumvar = sum_of_square - sum_ * mean
46 |
47 | self.running_mean = (
48 | (1 - self.momentum) * self.running_mean
49 | + self.momentum * mean.detach()
50 | )
51 | unbias_var = sumvar / (numel - 1)
52 | self.running_var = (
53 | (1 - self.momentum) * self.running_var
54 | + self.momentum * unbias_var.detach()
55 | )
56 |
57 | bias_var = sumvar / numel
58 | inv_std = 1 / (bias_var + self.eps).pow(0.5)
59 | output = (
60 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
61 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
62 |
63 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
64 |
65 |
--------------------------------------------------------------------------------
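`BatchNorm2dReimpl` always computes training-mode batch statistics, so it can be compared directly against `torch.nn.BatchNorm2d` in training mode. A small sanity-check sketch (the affine parameters are copied so both layers start identically; the tolerance is illustrative):

    import torch
    import torch.nn as nn
    from models.networks.sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl

    torch.manual_seed(0)
    x = torch.randn(8, 4, 5, 5)

    ref = nn.BatchNorm2d(4)
    reimpl = BatchNorm2dReimpl(4)
    reimpl.weight.data.copy_(ref.weight.data)  # start from identical affine parameters
    reimpl.bias.data.copy_(ref.bias.data)

    y_ref = ref(x)      # nn.BatchNorm2d in training mode
    y_re = reimpl(x)    # the re-implementation above
    print(torch.allclose(y_ref, y_re, atol=1e-5))  # expected to print True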
/models/networks/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | import queue
2 | import collections
3 | import threading
4 |
5 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
6 |
7 |
8 | class FutureResult(object):
9 | """A thread-safe future implementation. Used only as one-to-one pipe."""
10 |
11 | def __init__(self):
12 | self._result = None
13 | self._lock = threading.Lock()
14 | self._cond = threading.Condition(self._lock)
15 |
16 | def put(self, result):
17 | with self._lock:
18 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
19 | self._result = result
20 | self._cond.notify()
21 |
22 | def get(self):
23 | with self._lock:
24 | if self._result is None:
25 | self._cond.wait()
26 |
27 | res = self._result
28 | self._result = None
29 | return res
30 |
31 |
32 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
33 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
34 |
35 |
36 | class SlavePipe(_SlavePipeBase):
37 | """Pipe for master-slave communication."""
38 |
39 | def run_slave(self, msg):
40 | self.queue.put((self.identifier, msg))
41 | ret = self.result.get()
42 | self.queue.put(True)
43 | return ret
44 |
45 |
46 | class SyncMaster(object):
47 | """An abstract `SyncMaster` object.
48 |
49 |     - During the replication, as the data parallel will trigger a callback on each module, all slave devices should
50 |     call `register(id)` and obtain a `SlavePipe` to communicate with the master.
51 |     - During the forward pass, the master device invokes `run_master`, all messages from slave devices will be collected,
52 |     and passed to a registered callback.
53 |     - After receiving the messages, the master device should gather the information and determine the message to be passed
54 |     back to each slave device.
55 | """
56 |
57 | def __init__(self, master_callback):
58 | """
59 |
60 | Args:
61 | master_callback: a callback to be invoked after having collected messages from slave devices.
62 | """
63 | self._master_callback = master_callback
64 | self._queue = queue.Queue()
65 | self._registry = collections.OrderedDict()
66 | self._activated = False
67 |
68 | def __getstate__(self):
69 | return {'master_callback': self._master_callback}
70 |
71 | def __setstate__(self, state):
72 | self.__init__(state['master_callback'])
73 |
74 | def register_slave(self, identifier):
75 | """
76 |         Register a slave device.
77 |
78 | Args:
79 |             identifier: an identifier, usually the device id.
80 |
81 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
82 |
83 | """
84 | if self._activated:
85 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
86 | self._activated = False
87 | self._registry.clear()
88 | future = FutureResult()
89 | self._registry[identifier] = _MasterRegistry(future)
90 | return SlavePipe(identifier, self._queue, future)
91 |
92 | def run_master(self, master_msg):
93 | """
94 | Main entry for the master device in each forward pass.
95 |         The messages are first collected from each device (including the master device), and then
96 |         a callback will be invoked to compute the message to be sent back to each device
97 |         (including the master device).
98 |
99 | Args:
100 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
101 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
102 |
103 | Returns: the message to be sent back to the master device.
104 |
105 | """
106 | self._activated = True
107 |
108 | intermediates = [(0, master_msg)]
109 | for i in range(self.nr_slaves):
110 | intermediates.append(self._queue.get())
111 |
112 | results = self._master_callback(intermediates)
113 |         assert results[0][0] == 0, 'The first result should belong to the master.'
114 |
115 | for i, res in results:
116 | if i == 0:
117 | continue
118 | self._registry[i].result.put(res)
119 |
120 | for i in range(self.nr_slaves):
121 | assert self._queue.get() is True
122 |
123 | return results[0][1]
124 |
125 | @property
126 | def nr_slaves(self):
127 | return len(self._registry)
128 |
--------------------------------------------------------------------------------
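`SyncMaster` implements a small rendezvous protocol: slaves push their messages into a shared queue, the master collects exactly `nr_slaves` of them, runs the callback, and replies through each slave's `FutureResult`. A minimal two-thread sketch of that handshake (the callback here just sums the messages and broadcasts the total; it is an illustration, not how `_SynchronizedBatchNorm` uses it):

    import threading
    from models.networks.sync_batchnorm.comm import SyncMaster

    def master_callback(intermediates):
        # intermediates: [(identifier, message), ...] with the master's own message first
        total = sum(msg for _, msg in intermediates)
        return [(i, total) for i, _ in intermediates]

    master = SyncMaster(master_callback)
    pipe = master.register_slave(1)            # one slave with identifier 1
    results = {}

    def slave():
        results['slave'] = pipe.run_slave(10)  # blocks until the master replies

    t = threading.Thread(target=slave)
    t.start()
    results['master'] = master.run_master(5)   # collects the slave's message, broadcasts the sum
    t.join()
    print(results)                             # {'slave': 15, 'master': 15}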
/models/networks/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch
3 |
4 | from torch.nn.parallel.data_parallel import DataParallel
5 | from .scatter_gather import scatter_kwargs
6 |
7 | __all__ = [
8 | 'CallbackContext',
9 | 'execute_replication_callbacks',
10 | 'DataParallelWithCallback',
11 | 'patch_replication_callback'
12 | ]
13 |
14 |
15 | class CallbackContext(object):
16 | pass
17 |
18 |
19 | def execute_replication_callbacks(modules):
20 | """
21 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
22 |
23 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
24 |
25 |     Note that, as all modules are isomorphic, we assign each sub-module a context
26 | (shared among multiple copies of this module on different devices).
27 | Through this context, different copies can share some information.
28 |
29 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
30 | of any slave copies.
31 | """
32 | master_copy = modules[0]
33 | nr_modules = len(list(master_copy.modules()))
34 | ctxs = [CallbackContext() for _ in range(nr_modules)]
35 |
36 | for i, module in enumerate(modules):
37 | for j, m in enumerate(module.modules()):
38 | if hasattr(m, '__data_parallel_replicate__'):
39 | m.__data_parallel_replicate__(ctxs[j], i)
40 |
41 |
42 | class DataParallelWithCallback(DataParallel):
43 | """
44 | Data Parallel with a replication callback.
45 |
46 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by the
47 |     original `replicate` function.
48 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
49 |
50 | Examples:
51 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
52 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
53 | # sync_bn.__data_parallel_replicate__ will be invoked.
54 | """
55 | def __init__(self, module, device_ids=None, output_device=None, dim=0, chunk_size=None):
56 | super(DataParallelWithCallback, self).__init__(module)
57 |
58 | if not torch.cuda.is_available():
59 | self.module = module
60 | self.device_ids = []
61 | return
62 |
63 | if device_ids is None:
64 | device_ids = list(range(torch.cuda.device_count()))
65 | if output_device is None:
66 | output_device = device_ids[0]
67 | self.dim = dim
68 | self.module = module
69 | self.device_ids = device_ids
70 | self.output_device = output_device
71 | self.chunk_size = chunk_size
72 |
73 | if len(self.device_ids) == 1:
74 | self.module.cuda(device_ids[0])
75 |
76 | def forward(self, *inputs, **kwargs):
77 | if not self.device_ids:
78 | return self.module(*inputs, **kwargs)
79 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids, self.chunk_size)
80 | if len(self.device_ids) == 1:
81 | return self.module(*inputs[0], **kwargs[0])
82 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
83 | outputs = self.parallel_apply(replicas, inputs, kwargs)
84 | return self.gather(outputs, self.output_device)
85 |
86 | def scatter(self, inputs, kwargs, device_ids, chunk_size):
87 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim, chunk_size=self.chunk_size)
88 |
89 | def replicate(self, module, device_ids):
90 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 |
95 |
96 | def patch_replication_callback(data_parallel):
97 | """
98 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
99 | Useful when you have customized `DataParallel` implementation.
100 |
101 | Examples:
102 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
103 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
104 | > patch_replication_callback(sync_bn)
105 | # this is equivalent to
106 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
107 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
108 | """
109 |
110 | assert isinstance(data_parallel, DataParallel)
111 |
112 | old_replicate = data_parallel.replicate
113 |
114 | @functools.wraps(old_replicate)
115 | def new_replicate(module, device_ids):
116 | modules = old_replicate(module, device_ids)
117 | execute_replication_callbacks(modules)
118 | return modules
119 |
120 | data_parallel.replicate = new_replicate
121 |
--------------------------------------------------------------------------------
/models/networks/sync_batchnorm/scatter_gather.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.parallel._functions import Scatter, Gather
3 |
4 |
5 | def scatter(inputs, target_gpus, dim=0, chunk_size=None):
6 | r"""
7 | Slices tensors into approximately equal chunks and
8 | distributes them across given GPUs. Duplicates
9 | references to objects that are not tensors.
10 | """
11 | def scatter_map(obj):
12 | if isinstance(obj, torch.Tensor):
13 | return Scatter.apply(target_gpus, chunk_size, dim, obj)
14 | if isinstance(obj, tuple) and len(obj) > 0:
15 | return list(zip(*map(scatter_map, obj)))
16 | if isinstance(obj, list) and len(obj) > 0:
17 | return list(map(list, zip(*map(scatter_map, obj))))
18 | if isinstance(obj, dict) and len(obj) > 0:
19 | return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
20 |         return [obj for _ in target_gpus]
21 |
22 | # After scatter_map is called, a scatter_map cell will exist. This cell
23 | # has a reference to the actual function scatter_map, which has references
24 | # to a closure that has a reference to the scatter_map cell (because the
25 | # fn is recursive). To avoid this reference cycle, we set the function to
26 | # None, clearing the cell
27 | try:
28 | res = scatter_map(inputs)
29 | finally:
30 | scatter_map = None
31 | return res
32 |
33 |
34 | def scatter_kwargs(inputs, kwargs, target_gpus, dim=0, chunk_size=None):
35 | r"""Scatter with support for kwargs dictionary"""
36 | inputs = scatter(inputs, target_gpus, dim, chunk_size) if inputs else []
37 | kwargs = scatter(kwargs, target_gpus, dim, chunk_size) if kwargs else []
38 | if len(inputs) < len(kwargs):
39 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
40 | elif len(kwargs) < len(inputs):
41 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
42 | inputs = tuple(inputs)
43 | kwargs = tuple(kwargs)
44 | return inputs, kwargs
45 |
--------------------------------------------------------------------------------
/models/networks/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 |
4 |
5 | class TorchTestCase(unittest.TestCase):
6 | def assertTensorClose(self, x, y):
7 | adiff = float((x - y).abs().max())
8 | if (y == 0).all():
9 | rdiff = 'NaN'
10 | else:
11 | rdiff = float((adiff / y).abs().max())
12 |
13 | message = (
14 | 'Tensor close check failed\n'
15 | 'adiff={}\n'
16 | 'rdiff={}\n'
17 | ).format(adiff, rdiff)
18 | self.assertTrue(torch.allclose(x, y), message)
19 |
20 |
--------------------------------------------------------------------------------
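`TorchTestCase` is a thin helper for writing numerical tests; `assertTensorClose` reports the absolute and relative differences when `torch.allclose` fails. A hypothetical test module using it (run from the repository root):

    import unittest
    import torch
    from models.networks.sync_batchnorm.unittest import TorchTestCase

    class TinyTest(TorchTestCase):
        def test_close(self):
            a = torch.ones(3)
            self.assertTensorClose(a, a + 1e-9)  # within allclose tolerance, so the test passes

    if __name__ == '__main__':
        unittest.main()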
/models/networks/util.py:
--------------------------------------------------------------------------------
1 | """This module contains simple helper functions """
2 | from __future__ import print_function
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | import os
7 | from math import asin, atan2, cos, sin
8 |
9 | def P2sRt(P):
10 |     ''' decompose an affine camera matrix P into scale, rotation and translation.
11 |     Args:
12 |         P: (3, 4). Affine Camera Matrix.
13 |     Returns:
14 |         s: scale factor.
15 |         R: (3, 3). rotation matrix.
16 |         t3d: (3,). 3d translation.
17 | '''
18 | t3d = P[:, 3]
19 | R1 = P[0:1, :3]
20 | R2 = P[1:2, :3]
21 | s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0
22 | r1 = R1 / np.linalg.norm(R1)
23 | r2 = R2 / np.linalg.norm(R2)
24 | r3 = np.cross(r1, r2)
25 |
26 | R = np.concatenate((r1, r2, r3), 0)
27 | return s, R, t3d
28 |
29 | def matrix2angle(R):
30 | ''' compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf
31 | Args:
32 | R: (3,3). rotation matrix
33 | Returns:
34 | x: yaw
35 | y: pitch
36 | z: roll
37 | '''
38 | # assert(isRotationMatrix(R))
39 |
40 | if R[2, 0] != 1 and R[2, 0] != -1:
41 | x = -asin(max(-1, min(R[2, 0], 1)))
42 | y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x))
43 | z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x))
44 |
45 | else: # Gimbal lock
46 | z = 0 # can be anything
47 | if R[2, 0] == -1:
48 | x = np.pi / 2
49 | y = z + atan2(R[0, 1], R[0, 2])
50 | else:
51 | x = -np.pi / 2
52 | y = -z + atan2(-R[0, 1], -R[0, 2])
53 |
54 | return [x, y, z]
55 |
56 | def angle2matrix(angles):
57 | ''' get rotation matrix from three rotation angles(radian). The same as in 3DDFA.
58 | Args:
59 | angles: [3,]. x, y, z angles
60 | x: yaw.
61 | y: pitch.
62 | z: roll.
63 | Returns:
64 | R: 3x3. rotation matrix.
65 | '''
66 | # x, y, z = np.deg2rad(angles[0]), np.deg2rad(angles[1]), np.deg2rad(angles[2])
67 | # x, y, z = angles[0], angles[1], angles[2]
68 | y, x, z = angles[0], angles[1], angles[2]
69 |
70 | # x
71 | Rx=np.array([[1, 0, 0],
72 | [0, cos(x), -sin(x)],
73 | [0, sin(x), cos(x)]])
74 | # y
75 | Ry=np.array([[ cos(y), 0, sin(y)],
76 | [ 0, 1, 0],
77 | [-sin(y), 0, cos(y)]])
78 | # z
79 | Rz=np.array([[cos(z), -sin(z), 0],
80 | [sin(z), cos(z), 0],
81 | [ 0, 0, 1]])
82 | R = Rz.dot(Ry).dot(Rx)
83 | return R.astype(np.float32)
84 |
85 | def tensor2im(input_image, imtype=np.uint8):
86 |     """Converts a Tensor array into a numpy image array.
87 |
88 | Parameters:
89 | input_image (tensor) -- the input image tensor array
90 | imtype (type) -- the desired type of the converted numpy array
91 | """
92 | if not isinstance(input_image, np.ndarray):
93 | if isinstance(input_image, torch.Tensor): # get the data from a variable
94 | image_tensor = input_image.data
95 | else:
96 | return input_image
97 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
98 | if image_numpy.shape[0] == 1: # grayscale to RGB
99 | image_numpy = np.tile(image_numpy, (3, 1, 1))
100 |         image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
101 | else: # if it is a numpy array, do nothing
102 | image_numpy = input_image
103 | return image_numpy.astype(imtype)
104 |
105 |
106 | def diagnose_network(net, name='network'):
107 |     """Calculate and print the mean of the average absolute gradients
108 |
109 | Parameters:
110 | net (torch network) -- Torch network
111 | name (str) -- the name of the network
112 | """
113 | mean = 0.0
114 | count = 0
115 | for param in net.parameters():
116 | if param.grad is not None:
117 | mean += torch.mean(torch.abs(param.grad.data))
118 | count += 1
119 | if count > 0:
120 | mean = mean / count
121 | print(name)
122 | print(mean)
123 |
124 |
125 | def save_image(image_numpy, image_path):
126 | """Save a numpy image to the disk
127 |
128 | Parameters:
129 | image_numpy (numpy array) -- input numpy array
130 | image_path (str) -- the path of the image
131 | """
132 | image_pil = Image.fromarray(image_numpy)
133 | image_pil.save(image_path)
134 |
135 |
136 | def print_numpy(x, val=True, shp=False):
137 | """Print the mean, min, max, median, std, and size of a numpy array
138 |
139 | Parameters:
140 | val (bool) -- if print the values of the numpy array
141 | shp (bool) -- if print the shape of the numpy array
142 | """
143 | x = x.astype(np.float64)
144 | if shp:
145 | print('shape,', x.shape)
146 | if val:
147 | x = x.flatten()
148 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
149 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
150 |
151 |
152 | def mkdirs(paths):
153 | """create empty directories if they don't exist
154 |
155 | Parameters:
156 | paths (str list) -- a list of directory paths
157 | """
158 | if isinstance(paths, list) and not isinstance(paths, str):
159 | for path in paths:
160 | mkdir(path)
161 | else:
162 | mkdir(paths)
163 |
164 |
165 | def mkdir(path):
166 |     """create a single empty directory if it doesn't exist
167 |
168 | Parameters:
169 | path (str) -- a single directory path
170 | """
171 | if not os.path.exists(path):
172 | os.makedirs(path)
173 |
--------------------------------------------------------------------------------
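`P2sRt` assumes an affine camera matrix of the form P = [s*R | t]; it recovers the scale from the norms of the first two rows and rebuilds the third rotation row as their cross product. A quick round-trip sketch using the helpers above (angles in radians, following `angle2matrix`'s own ordering):

    import numpy as np
    from models.networks.util import P2sRt, angle2matrix, matrix2angle

    angles = [0.3, -0.2, 0.1]
    R = angle2matrix(angles)                       # (3, 3) rotation, float32
    s_true, t_true = 2.5, np.array([4.0, -1.0, 0.5])
    P = np.concatenate([s_true * R, t_true.reshape(3, 1)], axis=1)  # (3, 4) affine camera matrix

    s, R_rec, t = P2sRt(P)
    print(np.isclose(s, s_true), np.allclose(R_rec, R, atol=1e-5), np.allclose(t, t_true))
    print(matrix2angle(R_rec))                     # approximately recovers the input angles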
/models/networks/vision_network.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from models.networks.base_network import BaseNetwork
4 | from torchvision.models.resnet import ResNet, Bottleneck
5 | from util import util
6 | import torch
7 |
8 | model_urls = {
9 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
10 | }
11 |
12 |
13 | class ResNeXt50(BaseNetwork):
14 | def __init__(self, opt):
15 | super(ResNeXt50, self).__init__()
16 | self.model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4)
17 | self.opt = opt
18 | # self.reduced_id_dim = opt.reduced_id_dim
19 | self.conv1x1 = nn.Conv2d(512 * Bottleneck.expansion, 512, kernel_size=1, padding=0)
20 | self.fc = nn.Linear(512 * Bottleneck.expansion, opt.num_classes)
21 | # self.fc_pre = nn.Sequential(nn.Linear(512 * Bottleneck.expansion, self.reduced_id_dim), nn.ReLU())
22 |
23 |
24 | def load_pretrain(self):
25 |         check_point = torch.hub.load_state_dict_from_url(model_urls['resnext50_32x4d'])  # torch.load cannot read a URL directly
26 | util.copy_state_dict(check_point, self.model)
27 |
28 | def forward_feature(self, input):
29 | x = self.model.conv1(input)
30 | x = self.model.bn1(x)
31 | x = self.model.relu(x)
32 | x = self.model.maxpool(x)
33 |
34 | x = self.model.layer1(x)
35 | x = self.model.layer2(x)
36 | x = self.model.layer3(x)
37 | x = self.model.layer4(x)
38 | net = self.model.avgpool(x)
39 | net = torch.flatten(net, 1)
40 | x = self.conv1x1(x)
41 | # x = self.fc_pre(x)
42 | return net, x
43 |
44 | def forward(self, input):
45 | input_batch = input.view(-1, self.opt.output_nc, self.opt.crop_size, self.opt.crop_size)
46 | net, x = self.forward_feature(input_batch)
47 | net = net.view(-1, self.opt.num_inputs, 512 * Bottleneck.expansion)
48 | x = F.adaptive_avg_pool2d(x, (7, 7))
49 | x = x.view(-1, self.opt.num_inputs, 512, 7, 7)
50 | net = torch.mean(net, 1)
51 | x = torch.mean(x, 1)
52 | cls_scores = self.fc(net)
53 |
54 | return [net, x], cls_scores
55 |
--------------------------------------------------------------------------------
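`ResNeXt50.forward` reshapes its input into `(batch * num_inputs, output_nc, crop_size, crop_size)`, runs the ResNeXt trunk, and averages over the `num_inputs` reference frames before classification. A rough shape check with a stand-in option object (only the fields this class reads; assumes the repository packages import cleanly from the repo root):

    import argparse
    import torch
    from models.networks.vision_network import ResNeXt50

    # stand-in for the parsed options; values mirror the defaults in options/base_options.py
    opt = argparse.Namespace(output_nc=3, crop_size=224, num_inputs=1, num_classes=5830)

    net = ResNeXt50(opt)
    imgs = torch.randn(2, 3, 224, 224)       # batch of 2, one reference frame each
    (id_feat, spatial_feat), cls_scores = net(imgs)
    print(id_feat.shape)        # torch.Size([2, 2048])   pooled identity feature
    print(spatial_feat.shape)   # torch.Size([2, 512, 7, 7])
    print(cls_scores.shape)     # torch.Size([2, 5830])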
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import math
4 | import os
5 | from util import util
6 | import torch
7 | import models
8 | import data
9 | import pickle
10 |
11 |
12 | class BaseOptions():
13 | def __init__(self):
14 | self.initialized = False
15 |
16 | def initialize(self, parser):
17 | # experiment specifics
18 | parser.add_argument('--name', type=str, default='demo', help='name of the experiment. It decides where to store samples and models')
19 |         parser.add_argument('--filename_tmpl', type=str, default='{:06}.jpg', help='filename template used to index frame images')
20 | parser.add_argument('--data_path', type=str, default='/home/SENSETIME/zhouhang1/Downloads/VoxCeleb2/voxceleb2_train.csv', help='where to load voxceleb train data')
21 | parser.add_argument('--lrw_data_path', type=str,
22 | default='/home/SENSETIME/zhouhang1/Downloads/VoxCeleb2/voxceleb2_train.csv',
23 | help='where to load lrw train data')
24 |
25 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids')
26 | parser.add_argument('--num_classes', type=int, default=5830, help='num classes')
27 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
28 | parser.add_argument('--model', type=str, default='av', help='which model to use, rotate|rotatespade')
29 | parser.add_argument('--trainer', type=str, default='audio', help='which trainer to use, rotate|rotatespade')
30 | parser.add_argument('--norm_G', type=str, default='spectralinstance', help='instance normalization or batch normalization')
31 | parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization')
32 | parser.add_argument('--norm_E', type=str, default='spectralinstance', help='instance normalization or batch normalization')
33 | parser.add_argument('--norm_A', type=str, default='spectralinstance', help='instance normalization or batch normalization')
34 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
35 | # input/output sizes
36 | parser.add_argument('--batchSize', type=int, default=2, help='input batch size')
37 | parser.add_argument('--preprocess_mode', type=str, default='resize_and_crop', help='scaling and cropping of images at load time.', choices=("resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none"))
38 | parser.add_argument('--crop_size', type=int, default=224, help='Crop to the width of crop_size (after initially scaling the images to load_size.)')
39 | parser.add_argument('--crop_len', type=int, default=16, help='Crop len')
40 | parser.add_argument('--target_crop_len', type=int, default=0, help='Crop len')
41 | parser.add_argument('--crop', action='store_true', help='whether to crop the image')
42 | parser.add_argument('--clip_len', type=int, default=1, help='num of imgs to process')
43 | parser.add_argument('--pose_dim', type=int, default=12, help='num of imgs to process')
44 |         parser.add_argument('--frame_interval', type=int, default=1, help='the interval of frames')
45 | parser.add_argument('--num_clips', type=int, default=1, help='num of clips to process')
46 | parser.add_argument('--num_inputs', type=int, default=1, help='num of inputs to the network')
47 | parser.add_argument('--feature_encoded_dim', type=int, default=2560, help='dim of reduced id feature')
48 |
49 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='The ratio width/height. The final height of the load image will be crop_size/aspect_ratio')
50 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
51 | parser.add_argument('--audio_nc', type=int, default=256, help='# of output audio channels')
52 | parser.add_argument('--frame_rate', type=int, default=25, help='fps')
53 | parser.add_argument('--num_frames_per_clip', type=int, default=5, help='num of frames one audio bin')
54 | parser.add_argument('--hop_size', type=int, default=160, help='audio hop size')
55 | parser.add_argument('--generate_interval', type=int, default=1, help='select frames to generate')
56 | parser.add_argument('--dis_feat_rec', action='store_true', help='select frames to generate')
57 |
58 | parser.add_argument('--train_recognition', action='store_true', help='train recognition only')
59 | parser.add_argument('--train_sync', action='store_true', help='train sync only')
60 | parser.add_argument('--train_word', action='store_true', help='train word only')
61 | parser.add_argument('--train_dis_pose', action='store_true', help='train dis pose')
62 | parser.add_argument('--generate_from_audio_only', action='store_true', help='if specified, generate only from audio features')
63 | parser.add_argument('--noise_pose', action='store_true', help='noise pose to generate a talking face')
64 | parser.add_argument('--style_feature_loss', action='store_true', help='style_feature_loss')
65 |
66 |         # for setting inputs
67 | parser.add_argument('--dataset_mode', type=str, default='voxtest')
68 |         parser.add_argument('--landmark_align', action='store_true', help='whether there is landmark_align')
69 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
70 |         parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
71 | parser.add_argument('--nThreads', default=1, type=int, help='# threads for loading data')
72 | parser.add_argument('--n_mel_T', default=4, type=int, help='# threads for loading data')
73 |         parser.add_argument('--num_bins_per_frame', type=int, default=4, help='number of audio spectrogram bins per video frame')
74 |
75 | parser.add_argument('--max_dataset_size', type=int, default=sys.maxsize, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
76 | parser.add_argument('--load_from_opt_file', action='store_true', help='load the options from checkpoints and use that as default')
77 | parser.add_argument('--use_audio', type=int, default=1, help='use audio as driven input')
78 | parser.add_argument('--use_audio_id', type=int, default=0, help='use audio id')
79 | parser.add_argument('--augment_target', action='store_true', help='whether to use checkpoint')
80 |         parser.add_argument('--verbose', action='store_true', help='if specified, print more information')
81 |
82 | parser.add_argument('--display_winsize', type=int, default=224, help='display window size')
83 |
84 | # for generator
85 | parser.add_argument('--netG', type=str, default='modulate', help='selects model to use for netG (modulate)')
86 | parser.add_argument('--netA', type=str, default='resseaudio', help='selects model to use for netA (audio | spade)')
87 | parser.add_argument('--netA_sync', type=str, default='ressesync', help='selects model to use for netA (audio | spade)')
88 | parser.add_argument('--netV', type=str, default='resnext', help='selects model to use for netV (mobile | id)')
89 |         parser.add_argument('--netE', type=str, default='fan', help='selects model to use for netE (mobile | fan)')
90 | parser.add_argument('--netD', type=str, default='multiscale', help='(n_layers|multiscale|image|projection)')
91 | parser.add_argument('--D_input', type=str, default='single', help='(concat|single|hinge)')
92 |         parser.add_argument('--driven_type', type=str, default='face', help='type of driving input (heatmap | face)')
93 | parser.add_argument('--landmark_type', type=str, default='min', help='selects model to use for netV (mobile | fan)')
94 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
95 | parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')
96 | parser.add_argument('--feature_fusion', type=str, default='concat', help='style fusion method')
97 | parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
98 |
99 | # for instance-wise features
100 | parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input')
101 | parser.add_argument('--input_id_feature', action='store_true', help='if specified, use id feature as style gan input')
102 | parser.add_argument('--load_landmark', action='store_true', help='if specified, load landmarks')
103 | parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
104 |         parser.add_argument('--style_dim', type=int, default=2580, help='dimension of the style vector')
105 |
106 | ####################### weight settings ###################################################################
107 |
108 | parser.add_argument('--vgg_face', action='store_true', help='if specified, use VGG feature matching loss')
109 |
110 | parser.add_argument('--VGGFace_pretrain_path', type=str, default='', help='VGGFace pretrain path')
111 | parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
112 | parser.add_argument('--lambda_image', type=float, default=1.0, help='weight for image reconstruction')
113 | parser.add_argument('--lambda_vgg', type=float, default=10.0, help='weight for vgg loss')
114 | parser.add_argument('--lambda_vggface', type=float, default=5.0, help='weight for vggface loss')
115 |         parser.add_argument('--lambda_rotate_D', type=float, default=0.1,
116 |                             help='rotated D loss weight')
117 | parser.add_argument('--lambda_D', type=float, default=1,
118 | help='D loss weight')
119 | parser.add_argument('--lambda_softmax', type=float, default=1000000, help='weight for softmax loss')
120 |         parser.add_argument('--lambda_crossmodal', type=float, default=1, help='weight for cross-modal loss')
121 |
122 | parser.add_argument('--lambda_contrastive', type=float, default=100, help='if specified, use contrastive loss for img and audio embed')
123 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
124 |
125 | parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
126 | parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
127 | parser.add_argument('--no_id_loss', action='store_true', help='if specified, do *not* use cls loss')
128 | parser.add_argument('--word_loss', action='store_true', help='if specified, do *not* use cls loss')
129 | parser.add_argument('--no_spectrogram', action='store_true', help='if specified, do *not* use mel spectrogram, use mfcc')
130 |
131 | parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)')
132 |         parser.add_argument('--no_TTUR', action='store_true', help='if specified, do *not* use the TTUR training scheme')
133 |
134 | ############################## optimizer #############################
135 | parser.add_argument('--optimizer', type=str, default='adam')
136 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
137 | parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
138 | parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam')
139 |
140 | parser.add_argument('--no_gaussian_landmark', action='store_true', help='whether to use no_gaussian_landmark (1.0 landmark) for rotatespade model')
141 | parser.add_argument('--label_mask', action='store_true', help='whether to use face mask')
142 | parser.add_argument('--positional_encode', action='store_true', help='whether to use positional encode')
143 | parser.add_argument('--use_transformer', action='store_true', help='whether to use transformer')
144 | parser.add_argument('--has_mask', action='store_true', help='whether to use mask in transformer')
145 | parser.add_argument('--heatmap_size', type=float, default=3, help='the size of the heatmap, used in rotatespade model')
146 |
147 | self.initialized = True
148 | return parser
149 |
150 | def gather_options(self):
151 | # initialize parser with basic options
152 | if not self.initialized:
153 | parser = argparse.ArgumentParser(
154 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
155 | parser = self.initialize(parser)
156 |
157 | # get the basic options
158 | opt, unknown = parser.parse_known_args()
159 |
160 | # modify model-related parser options
161 | model_name = opt.model
162 | model_option_setter = models.get_option_setter(model_name)
163 | parser = model_option_setter(parser, self.isTrain)
164 |
165 | # modify dataset-related parser options
166 | dataset_mode = opt.dataset_mode
167 | dataset_modes = opt.dataset_mode.split(',')
168 |
169 | if len(dataset_modes) == 1:
170 | dataset_option_setter = data.get_option_setter(dataset_mode)
171 | parser = dataset_option_setter(parser, self.isTrain)
172 | else:
173 | for dm in dataset_modes:
174 | dataset_option_setter = data.get_option_setter(dm)
175 | parser = dataset_option_setter(parser, self.isTrain)
176 |
177 | opt, unknown = parser.parse_known_args()
178 |
179 |         # if there is an opt file, load it.
180 |         # The previously set default options will be overwritten
181 | if opt.load_from_opt_file:
182 | parser = self.update_options_from_file(parser, opt)
183 |
184 | opt = parser.parse_args()
185 | self.parser = parser
186 | return opt
187 |
188 | def print_options(self, opt):
189 | message = ''
190 | message += '----------------- Options ---------------\n'
191 | for k, v in sorted(vars(opt).items()):
192 | comment = ''
193 | default = self.parser.get_default(k)
194 | if v != default:
195 | comment = '\t[default: %s]' % str(default)
196 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
197 | message += '----------------- End -------------------'
198 | print(message)
199 |
200 | def option_file_path(self, opt, makedir=False):
201 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
202 | if makedir:
203 | util.mkdirs(expr_dir)
204 | file_name = os.path.join(expr_dir, 'opt')
205 | return file_name
206 |
207 | def save_options(self, opt):
208 | file_name = self.option_file_path(opt, makedir=True)
209 | with open(file_name + '.txt', 'wt') as opt_file:
210 | for k, v in sorted(vars(opt).items()):
211 | comment = ''
212 | default = self.parser.get_default(k)
213 | if v != default:
214 | comment = '\t[default: %s]' % str(default)
215 | opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
216 |
217 | with open(file_name + '.pkl', 'wb') as opt_file:
218 | pickle.dump(opt, opt_file)
219 |
220 | def update_options_from_file(self, parser, opt):
221 | new_opt = self.load_options(opt)
222 | for k, v in sorted(vars(opt).items()):
223 | if hasattr(new_opt, k) and v != getattr(new_opt, k):
224 | new_val = getattr(new_opt, k)
225 | parser.set_defaults(**{k: new_val})
226 | return parser
227 |
228 | def load_options(self, opt):
229 | file_name = self.option_file_path(opt, makedir=False)
230 | new_opt = pickle.load(open(file_name + '.pkl', 'rb'))
231 | return new_opt
232 |
233 | def parse(self, save=False):
234 |
235 | opt = self.gather_options()
236 | opt.isTrain = self.isTrain # train or test
237 |
238 | self.print_options(opt)
239 | if opt.isTrain:
240 | self.save_options(opt)
241 | # Set semantic_nc based on the option.
242 | # This will be convenient in many places
243 | # set gpu ids
244 | str_ids = opt.gpu_ids.split(',')
245 | opt.gpu_ids = []
246 | for str_id in str_ids:
247 | id = int(str_id)
248 | if id >= 0:
249 | opt.gpu_ids.append(id)
250 | if len(opt.gpu_ids) > 0:
251 | torch.cuda.set_device(opt.gpu_ids[0])
252 |
253 |
254 | self.opt = opt
255 | return self.opt
256 |
--------------------------------------------------------------------------------
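`gather_options` parses in two passes: a first `parse_known_args` call reads the base options, the model- and dataset-specific setters then extend the parser, and `update_options_from_file` can restore defaults from a previously saved run via `parser.set_defaults` before the final `parse_args`. A standalone sketch of that defaults-override pattern (hypothetical options, not the real BaseOptions pipeline):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str, default='demo')
    parser.add_argument('--batchSize', type=int, default=2)

    # first pass: tolerate options that other setters have not added yet
    opt, unknown = parser.parse_known_args([])

    # emulate update_options_from_file: a saved run overrides the defaults
    saved = {'batchSize': 16}
    for k, v in saved.items():
        parser.set_defaults(**{k: v})

    # second pass: explicit command-line values would still win over restored defaults
    opt = parser.parse_args([])
    print(opt.name, opt.batchSize)   # demo 16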
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | def initialize(self, parser):
6 | BaseOptions.initialize(self, parser)
7 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
8 | parser.add_argument('--input_path', type=str, default='./checkpoints/results/input_path', help='defined input path.')
9 | parser.add_argument('--meta_path_vox', type=str, default='./misc/demo.csv', help='the meta data path')
10 | parser.add_argument('--driving_pose', action='store_true', help='driven pose to generate a talking face')
11 | parser.add_argument('--list_num', type=int, default=0, help='list num')
12 |         parser.add_argument('--fitting_iterations', type=int, default=10, help='the iterations for fitting at test time')
13 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
14 | parser.add_argument('--how_many', type=int, default=float("inf"), help='how many test images to run')
15 | parser.add_argument('--start_ind', type=int, default=0, help='the start id for defined driven')
16 | parser.add_argument('--list_start', type=int, default=0, help='which num in the list to start')
17 |         parser.add_argument('--list_end', type=int, default=float("inf"), help='which num in the list to end')
18 | parser.add_argument('--save_path', type=str, default='./results/', help='where to save data')
19 | parser.add_argument('--multi_gpu', action='store_true', help='whether to use multi gpus')
20 | parser.add_argument('--defined_driven', action='store_true', help='whether to use defined driven')
21 | parser.add_argument('--gen_video', action='store_true', help='whether to generate videos')
22 | parser.add_argument('--onnx', action='store_true', help='for tddfa')
23 | parser.add_argument('--mode', type=str, default='cpu', help='gpu or cpu mode')
24 |
25 | # parser.set_defaults(preprocess_mode='scale_width_and_crop', crop_size=256, load_size=256, display_winsize=256)
26 | # parser.set_defaults(serial_batches=True)
27 | parser.set_defaults(no_flip=True)
28 | parser.set_defaults(phase='test')
29 | self.isTrain = False
30 | return parser
31 |
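
The flags above are consumed through the parse() flow defined in base_options.py. A minimal usage sketch (not part of the repository), assuming the remaining base options all have defaults and the repository root is on PYTHONPATH:

from options.test_options import TestOptions

opt = TestOptions().parse()        # gathers base + test flags; sets opt.isTrain = False, phase = 'test'
print(opt.meta_path_vox)           # ./misc/demo.csv unless overridden on the command line
print(opt.driving_pose, opt.gen_video, opt.results_dir)

Because --driving_pose, --gen_video and the other store_true flags default to False, they only take effect when passed explicitly.
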
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | def initialize(self, parser):
6 | BaseOptions.initialize(self, parser)
7 | # for displays
8 | parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
9 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
10 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
11 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
12 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
13 | parser.add_argument('--debug', action='store_true', help='only run one epoch and display at each iteration')
14 | parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
15 | parser.add_argument('--tensorboard', default=True, help='use torch.utils.tensorboard logging (enabled by default). Requires tensorboard installed')
16 | parser.add_argument('--load_pretrain', type=str, default='',
17 | help='load the pretrained model from the specified location')
18 |
19 | # for training
20 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
21 | parser.add_argument('--recognition', action='store_true', help='train only recognition')
22 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
23 | parser.add_argument('--noload_D', action='store_true', help='whether to load D when continue training')
24 | parser.add_argument('--pose_noise', action='store_true', help='whether to use pose noise training')
25 | parser.add_argument('--load_separately', action='store_true', help='whether to continue train by loading separate models')
26 | parser.add_argument('--niter', type=int, default=10, help='# of iter at starting learning rate. This is NOT the total #epochs. Total #epochs is niter + niter_decay')
27 | parser.add_argument('--niter_decay', type=int, default=1000, help='# of iter to linearly decay learning rate to zero')
28 | parser.add_argument('--D_steps_per_G', type=int, default=1, help='number of discriminator iterations per generator iterations.')
29 |
30 | parser.add_argument('--G_pretrain_path', type=str, default='./checkpoints/100_net_G.pth', help='G pretrain path')
31 | parser.add_argument('--D_pretrain_path', type=str, default='', help='D pretrain path')
32 | parser.add_argument('--E_pretrain_path', type=str, default='', help='E pretrain path')
33 | parser.add_argument('--V_pretrain_path', type=str, default='', help='V pretrain path')
34 | parser.add_argument('--A_pretrain_path', type=str, default='', help='A pretrain path')
35 | parser.add_argument('--A_sync_pretrain_path', type=str, default='', help='A_sync pretrain path')
36 | parser.add_argument('--netE_pretrain_path', type=str, default='', help='netE pretrain path')
37 |
38 | parser.add_argument('--fix_netV', action='store_true', help='if specified, fix net V')
39 | parser.add_argument('--fix_netE', action='store_true', help='if specified, fix net E')
40 | parser.add_argument('--fix_netE_mouth', action='store_true', help='if specified, fix the net E mouth mapper and fc')
41 | parser.add_argument('--fix_netE_mouth_embed', action='store_true', help='if specified, fix the net E mouth embedding')
42 | parser.add_argument('--fix_netE_headpose', action='store_true', help='if specified, fix net E headpose')
43 | parser.add_argument('--fix_netA_sync', action='store_true', help='if specified, fix net A_sync')
44 | parser.add_argument('--fix_netG', action='store_true', help='if specified, fix net G')
45 | parser.add_argument('--fix_netD', action='store_true', help='if specified, fix net D')
46 | parser.add_argument('--no_cross_modal', action='store_true', help='if specified, do *not* use cls loss')
47 | parser.add_argument('--softmax_contrastive', action='store_true', help='if specified, use contrastive loss for img and audio embed')
48 | # for discriminators
49 |
50 | parser.add_argument('--baseline_sync', action='store_true', help='train baseline sync')
51 | parser.add_argument('--style_feature_loss', action='store_true', help='to use style feature loss')
52 | # parser.add_argument('--vggface_checkpoint', type=str, default='', help='pth to vggface ckpt')
53 | parser.add_argument('--pretrain', action='store_true', help='use an external pretrained model')
54 | parser.add_argument('--disentangle', action='store_true', help='whether to use disentangle loss')
55 | self.isTrain = True
56 | return parser
57 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.2.0
2 | torchvision
3 | dominate>=2.3.1
4 | dill
5 | scikit-image
6 | numpy>=1.15.4
7 | scipy>=1.1.0
8 | matplotlib
9 | opencv-python>=3.4.3.18
10 | tensorboard==1.14.0
11 | tqdm
12 | librosa
13 | face-alignment
--------------------------------------------------------------------------------
/scripts/align_68.py:
--------------------------------------------------------------------------------
1 | import face_alignment
2 | import os
3 | import cv2
4 | import skimage.transform as trans
5 | import argparse
6 | import torch
7 | import numpy as np
8 |
9 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
10 |
11 |
12 | def get_affine(src):
13 | dst = np.array([[87, 59],
14 | [137, 59],
15 | [112, 120]], dtype=np.float32)
16 | tform = trans.SimilarityTransform()
17 | tform.estimate(src, dst)
18 | M = tform.params[0:2, :]
19 | return M
20 |
21 |
22 | def affine_align_img(img, M, crop_size=224):
23 | warped = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0)
24 | return warped
25 |
26 |
27 | def affine_align_3landmarks(landmarks, M):
28 | new_landmarks = np.concatenate([landmarks, np.ones((3, 1))], 1)
29 | affined_landmarks = np.matmul(new_landmarks, M.transpose())
30 | return affined_landmarks
31 |
32 |
33 | def get_eyes_mouths(landmark):
34 | three_points = np.zeros((3, 2))
35 | three_points[0] = landmark[36:42].mean(0)
36 | three_points[1] = landmark[42:48].mean(0)
37 | three_points[2] = landmark[60:68].mean(0)
38 | return three_points
39 |
40 |
41 | def get_mouth_bias(three_points):
42 | bias = np.array([112, 120]) - three_points[2]
43 | return bias
44 |
45 |
46 | def align_folder(folder_path, folder_save_path):
47 |
48 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device)
49 | preds = fa.get_landmarks_from_directory(folder_path)
50 |
51 | sumpoints = 0
52 | three_points_list = []
53 |
54 | for img in preds.keys():
55 | pred_points = np.array(preds[img])
56 | if pred_points is None or len(pred_points.shape) != 3:
57 | print('preprocessing failed')
58 | return False
59 | else:
60 | num_faces, size, _ = pred_points.shape
61 | if num_faces == 1 and size == 68:
62 |
63 | three_points = get_eyes_mouths(pred_points[0])
64 | sumpoints += three_points
65 | three_points_list.append(three_points)
66 | else:
67 |
68 | print('preprocessing failed')
69 | return False
70 | avg_points = sumpoints / len(preds)
71 | M = get_affine(avg_points)
72 | p_bias = None
73 | for i, img_pth in enumerate(preds.keys()):
74 | three_points = three_points_list[i]
75 | affined_3landmarks = affine_align_3landmarks(three_points, M)
76 | bias = get_mouth_bias(affined_3landmarks)
77 | if p_bias is None:
78 | pass  # first frame: keep the raw mouth-center bias
79 | else:
80 | bias = p_bias * 0.2 + bias * 0.8  # smooth the bias across consecutive frames
81 | p_bias = bias
82 | M_i = M.copy()
83 | M_i[:, 2] = M[:, 2] + bias
84 | img = cv2.imread(img_pth)
85 | warped = affine_align_img(img, M_i)
86 | img_save_path = os.path.join(folder_save_path, img_pth.split('/')[-1])
87 | cv2.imwrite(img_save_path, warped)
88 | print('cropped files saved at {}'.format(folder_save_path))
89 |
90 |
91 | def main():
92 | parser = argparse.ArgumentParser()
93 | parser.add_argument('--folder_path', help='the folder which needs processing')
94 | args = parser.parse_args()
95 |
96 | if os.path.isdir(args.folder_path):
97 | home_path = '/'.join(args.folder_path.split('/')[:-1])
98 | save_img_path = os.path.join(home_path, args.folder_path.split('/')[-1] + '_cropped')
99 | os.makedirs(save_img_path, exist_ok=True)
100 |
101 | align_folder(args.folder_path, save_img_path)
102 |
103 |
104 | if __name__ == '__main__':
105 | main()
106 |
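
For a single image, the helpers above compose as follows. A minimal sketch (not part of the repository), assuming it runs with this module's imports and helpers in scope, and that 'example.jpg' is a placeholder path containing exactly one face:

fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device)

img = cv2.imread('example.jpg')                    # placeholder path, BGR image
landmarks = fa.get_landmarks(img[:, :, ::-1])[0]   # 68x2 landmarks (detector expects RGB)
three_points = get_eyes_mouths(landmarks)          # eye centers + inner-mouth center
M = get_affine(three_points.astype(np.float32))    # similarity transform onto the fixed template
aligned = affine_align_img(img, M)                 # warped 224x224 crop
cv2.imwrite('example_aligned.jpg', aligned)

align_folder() above additionally averages the three points over all frames and smooths the per-frame mouth bias, so a whole folder stays temporally stable.
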
--------------------------------------------------------------------------------
/scripts/prepare_testing_files.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.dirname(os.path.dirname(__file__)))
4 | import argparse
5 | import glob
6 | import csv
7 | import numpy as np
8 | from config.AudioConfig import AudioConfig
9 |
10 |
11 | def mkdir(path):
12 | if not os.path.exists(path):
13 | os.makedirs(path)
14 |
15 |
16 | def proc_frames(src_path, dst_path):
17 | cmd = 'ffmpeg -i \"{}\" -start_number 0 -qscale:v 2 \"{}\"/%06d.jpg -loglevel error -y'.format(src_path, dst_path)
18 | os.system(cmd)
19 | frames = glob.glob(os.path.join(dst_path, '*.jpg'))
20 | return len(frames)
21 |
22 |
23 | def proc_audio(src_mouth_path, dst_audio_path):
24 | audio_command = 'ffmpeg -i \"{}\" -loglevel error -y -f wav -acodec pcm_s16le ' \
25 | '-ar 16000 \"{}\"'.format(src_mouth_path, dst_audio_path)
26 | os.system(audio_command)
27 |
28 |
29 | if __name__ == "__main__":
30 | parser = argparse.ArgumentParser()
31 | # parser.add_argument('--dst_dir_path', default='/mnt/lustre/DATAshare3/VoxCeleb2',
32 | # help="dst file position")
33 | parser.add_argument('--dir_path', default='./misc',
34 | help="dst file position")
35 | parser.add_argument('--src_pose_path', default='./misc/Pose_Source/00473.mp4',
36 | help="pose source file position, this could be an mp4 or a folder")
37 | parser.add_argument('--src_audio_path', default='./misc/Audio_Source/00015.mp4',
38 | help="audio source file position, it could be an mp3 file or an mp4 video with audio")
39 | parser.add_argument('--src_mouth_frame_path', default=None,
40 | help="mouth frame file position, the video frames synced with audios")
41 | parser.add_argument('--src_input_path', default='./misc/Input/00098.mp4',
42 | help="input file position, it could be a folder with frames, a jpg or an mp4")
43 | parser.add_argument('--csv_path', default='./misc/demo2.csv',
44 | help="path to output index files")
45 | parser.add_argument('--convert_spectrogram', action='store_true', help='whether to convert audio to spectrogram')
46 |
47 | args = parser.parse_args()
48 | dir_path = args.dir_path
49 | mkdir(dir_path)
50 |
51 | # ===================== process input =======================================================
52 | input_save_path = os.path.join(dir_path, 'Input')
53 | mkdir(input_save_path)
54 | input_name = args.src_input_path.split('/')[-1].split('.')[0]
55 | num_inputs = 1
56 | dst_input_path = os.path.join(input_save_path, input_name)
57 | mkdir(dst_input_path)
58 | if args.src_input_path.split('/')[-1].split('.')[-1] == 'mp4':
59 | num_inputs = proc_frames(args.src_input_path, dst_input_path)
60 | elif os.path.isdir(args.src_input_path):
61 | dst_input_path = args.src_input_path
62 | else:
63 | os.system('cp {} {}'.format(args.src_input_path, os.path.join(dst_input_path, args.src_input_path.split('/')[-1])))
64 |
65 |
66 | # ===================== process audio =======================================================
67 | audio_source_save_path = os.path.join(dir_path, 'Audio_Source')
68 | mkdir(audio_source_save_path)
69 | audio_name = args.src_audio_path.split('/')[-1].split('.')[0]
70 | spec_dir = 'None'
71 | dst_audio_path = os.path.join(audio_source_save_path, audio_name + '.mp3')
72 |
73 | if args.src_audio_path.split('/')[-1].split('.')[-1] == 'mp3':
74 | os.system('cp {} {}'.format(args.src_audio_path, dst_audio_path))
75 | if args.src_mouth_frame_path and os.path.isdir(args.src_mouth_frame_path):
76 | dst_mouth_frame_path = args.src_mouth_frame_path
77 | num_mouth_frames = len(glob.glob(os.path.join(args.src_mouth_frame_path, '*.jpg')) + glob.glob(os.path.join(args.src_mouth_frame_path, '*.png')))
78 | else:
79 | dst_mouth_frame_path = 'None'
80 | num_mouth_frames = 0
81 | else:
82 | mouth_source_save_path = os.path.join(dir_path, 'Mouth_Source')
83 | mkdir(mouth_source_save_path)
84 | dst_mouth_frame_path = os.path.join(mouth_source_save_path, audio_name)
85 | mkdir(dst_mouth_frame_path)
86 | proc_audio(args.src_audio_path, dst_audio_path)
87 | num_mouth_frames = proc_frames(args.src_audio_path, dst_mouth_frame_path)
88 |
89 | if args.convert_spectrogram:
90 | audio = AudioConfig(fft_size=1280, hop_size=160)
91 | wav = audio.read_audio(dst_audio_path)
92 | spectrogram = audio.audio_to_spectrogram(wav)
93 | spec_dir = os.path.join(audio_source_save_path, audio_name + '.npy')
94 | np.save(spec_dir,
95 | spectrogram.astype(np.float32), allow_pickle=False)
96 |
97 | # ===================== process pose =======================================================
98 | if os.path.isdir(args.src_pose_path):
99 | num_pose_frames = len(glob.glob(os.path.join(args.src_pose_path, '*.jpg')) + glob.glob(os.path.join(args.src_pose_path, '*.png')))
100 | dst_pose_frame_path = args.src_pose_path
101 | else:
102 | pose_source_save_path = os.path.join(dir_path, 'Pose_Source')
103 | mkdir(pose_source_save_path)
104 | pose_name = args.src_pose_path.split('/')[-1].split('.')[0]
105 | dst_pose_frame_path = os.path.join(pose_source_save_path, pose_name)
106 | mkdir(dst_pose_frame_path)
107 | num_pose_frames = proc_frames(args.src_pose_path, dst_pose_frame_path)
108 |
109 | # ===================== form csv =======================================================
110 |
111 | with open(args.csv_path, 'w', newline='') as csvfile:
112 | writer = csv.writer(csvfile, delimiter=' ', quoting=csv.QUOTE_MINIMAL)
113 | writer.writerows([[dst_input_path, str(num_inputs), dst_pose_frame_path, str(num_pose_frames),
114 | dst_audio_path, dst_mouth_frame_path, str(num_mouth_frames), spec_dir]])
115 | print('meta-info saved at ' + args.csv_path)
116 |
117 | csvfile.close()
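
Each run writes a single space-delimited row with eight fields: input frame dir, input count, pose frame dir, pose frame count, audio path, mouth frame dir, mouth frame count, and spectrogram path (or 'None'). A minimal read-back sketch (not part of the repository), mirroring the writer settings above:

import csv

with open('./misc/demo2.csv') as f:
    for row in csv.reader(f, delimiter=' '):
        (input_dir, num_inputs, pose_dir, num_pose,
         audio_path, mouth_dir, num_mouth, spec_path) = row
        print(input_dir, num_inputs, audio_path, spec_path)
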
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import dominate
3 | from dominate.tags import *
4 | import os
5 |
6 |
7 | class HTML:
8 | def __init__(self, web_dir, title, refresh=0):
9 | if web_dir.endswith('.html'):
10 | web_dir, html_name = os.path.split(web_dir)
11 | else:
12 | web_dir, html_name = web_dir, 'index.html'
13 | self.title = title
14 | self.web_dir = web_dir
15 | self.html_name = html_name
16 | self.img_dir = os.path.join(self.web_dir, 'images')
17 | if len(self.web_dir) > 0 and not os.path.exists(self.web_dir):
18 | os.makedirs(self.web_dir)
19 | if len(self.web_dir) > 0 and not os.path.exists(self.img_dir):
20 | os.makedirs(self.img_dir)
21 |
22 | self.doc = dominate.document(title=title)
23 | with self.doc:
24 | h1(datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))
25 | if refresh > 0:
26 | with self.doc.head:
27 | meta(http_equiv="refresh", content=str(refresh))
28 |
29 | def get_image_dir(self):
30 | return self.img_dir
31 |
32 | def add_header(self, text):
33 | with self.doc:
34 | h3(text)
35 |
36 | def add_table(self, border=1):
37 | self.t = table(border=border, style="table-layout: fixed;")
38 | self.doc.add(self.t)
39 |
40 | def add_images(self, ims, txts, links, width=512):
41 | self.add_table()
42 | with self.t:
43 | with tr():
44 | for im, txt, link in zip(ims, txts, links):
45 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
46 | with p():
47 | with a(href=os.path.join('images', link)):
48 | img(style="width:%dpx" % (width), src=os.path.join('images', im))
49 | br()
50 | p(txt)  # dominate expects str; encoding to bytes would render b'...' under Python 3
51 |
52 | def save(self):
53 | html_file = os.path.join(self.web_dir, self.html_name)
54 | f = open(html_file, 'wt')
55 | f.write(self.doc.render())
56 | f.close()
57 |
58 |
59 | if __name__ == '__main__':
60 | html = HTML('web/', 'test_html')
61 | html.add_header('hello world')
62 |
63 | ims = []
64 | txts = []
65 | links = []
66 | for n in range(4):
67 | ims.append('image_%d.jpg' % n)
68 | txts.append('text_%d' % n)
69 | links.append('image_%d.jpg' % n)
70 | html.add_images(ims, txts, links)
71 | html.save()
72 |
--------------------------------------------------------------------------------
/util/iter_counter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 |
5 |
6 | # Helper class that keeps track of training iterations
7 | class IterationCounter():
8 | def __init__(self, opt, dataset_size):
9 | self.opt = opt
10 | self.dataset_size = dataset_size
11 |
12 | self.first_epoch = 1
13 | self.total_epochs = opt.niter + opt.niter_decay if opt.isTrain else 1
14 | self.epoch_iter = 0 # iter number within each epoch
15 | self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, 'iter.txt')
16 | if opt.isTrain and opt.continue_train:
17 | try:
18 | self.first_epoch, self.epoch_iter = np.loadtxt(
19 | self.iter_record_path, delimiter=',', dtype=int)
20 | print('Resuming from epoch %d at iteration %d' % (self.first_epoch, self.epoch_iter))
21 | except Exception:
22 | print('Could not load iteration record at %s. Starting from beginning.' %
23 | self.iter_record_path)
24 |
25 | self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter
26 |
27 | # return the iterator of epochs for the training
28 | def training_epochs(self):
29 | return range(self.first_epoch, self.total_epochs + 1)
30 |
31 | def record_epoch_start(self, epoch):
32 | self.epoch_start_time = time.time()
33 | self.epoch_iter = 0
34 | self.last_iter_time = time.time()
35 | self.current_epoch = epoch
36 |
37 | def record_one_iteration(self):
38 | current_time = time.time()
39 |
40 | # the last remaining batch is dropped (see data/__init__.py),
41 | # so we can assume batch size is always opt.batchSize
42 | self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize
43 | self.last_iter_time = current_time
44 | self.total_steps_so_far += self.opt.batchSize
45 | self.epoch_iter += self.opt.batchSize
46 |
47 | def record_epoch_end(self):
48 | current_time = time.time()
49 | self.time_per_epoch = current_time - self.epoch_start_time
50 | print('End of epoch %d / %d \t Time Taken: %d sec' %
51 | (self.current_epoch, self.total_epochs, self.time_per_epoch))
52 | if self.current_epoch % self.opt.save_epoch_freq == 0:
53 | np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0),
54 | delimiter=',', fmt='%d')
55 | print('Saved current iteration count at %s.' % self.iter_record_path)
56 |
57 | def record_current_iter(self):
58 | np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter),
59 | delimiter=',', fmt='%d')
60 | print('Saved current iteration count at %s.' % self.iter_record_path)
61 |
62 | def needs_saving(self):
63 | return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize
64 |
65 | def needs_printing(self):
66 | return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize
67 |
68 | def needs_displaying(self):
69 | return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize
70 |
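
A minimal sketch of the training-loop shape this counter expects (not part of the repository); the Namespace below only lists the options the class reads, not the project's full option set:

import os
from argparse import Namespace
from util.iter_counter import IterationCounter

opt = Namespace(isTrain=True, continue_train=False, niter=10, niter_decay=10,
                batchSize=2, save_latest_freq=5000, save_epoch_freq=1,
                print_freq=100, display_freq=100,
                checkpoints_dir='./checkpoints', name='demo')
os.makedirs(os.path.join(opt.checkpoints_dir, opt.name), exist_ok=True)

dataset_size = 100                                   # placeholder sample count
iter_counter = IterationCounter(opt, dataset_size)

for epoch in iter_counter.training_epochs():
    iter_counter.record_epoch_start(epoch)
    for _ in range(dataset_size // opt.batchSize):   # placeholder for a real DataLoader
        iter_counter.record_one_iteration()
        if iter_counter.needs_printing():
            pass                                     # print / plot losses here
        if iter_counter.needs_saving():
            iter_counter.record_current_iter()       # persist epoch + iteration counters
    iter_counter.record_epoch_end()
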
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | import re
2 | import importlib
3 | import torch
4 | from argparse import Namespace
5 | import numpy as np
6 | from PIL import Image
7 | import os
8 | import argparse
9 | import dill as pickle
10 | import skimage.transform as trans
11 | import cv2
12 |
13 |
14 | def save_obj(obj, name):
15 | with open(name, 'wb') as f:
16 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
17 |
18 |
19 | def load_obj(name):
20 | with open(name, 'rb') as f:
21 | return pickle.load(f)
22 |
23 | # returns a configuration for creating a generator
24 | # |default_opt| should be the opt of the current experiment
25 | # |**kwargs|: if any configuration should be overriden, it can be specified here
26 |
27 |
28 | def copyconf(default_opt, **kwargs):
29 | conf = argparse.Namespace(**vars(default_opt))
30 | for key in kwargs:
31 | print(key, kwargs[key])
32 | setattr(conf, key, kwargs[key])
33 | return conf
34 |
35 |
36 | def tile_images(imgs, picturesPerRow=4):
37 | """ Code borrowed from
38 | https://stackoverflow.com/questions/26521365/cleanly-tile-numpy-array-of-images-stored-in-a-flattened-1d-format/26521997
39 | """
40 |
41 | # Padding
42 | if imgs.shape[0] % picturesPerRow == 0:
43 | rowPadding = 0
44 | else:
45 | rowPadding = picturesPerRow - imgs.shape[0] % picturesPerRow
46 | if rowPadding > 0:
47 | imgs = np.concatenate([imgs, np.zeros((rowPadding, *imgs.shape[1:]), dtype=imgs.dtype)], axis=0)
48 |
49 | # Tiling Loop (The conditionals are not necessary anymore)
50 | tiled = []
51 | for i in range(0, imgs.shape[0], picturesPerRow):
52 | tiled.append(np.concatenate([imgs[j] for j in range(i, i + picturesPerRow)], axis=1))
53 |
54 | tiled = np.concatenate(tiled, axis=0)
55 | return tiled
56 |
57 |
58 | # Converts a Tensor into a Numpy array
59 | # |imtype|: the desired type of the converted numpy array
60 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=True):
61 | if isinstance(image_tensor, list):
62 | image_numpy = []
63 | for i in range(len(image_tensor)):
64 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
65 | return image_numpy
66 |
67 | if image_tensor.dim() == 4:
68 | # transform each image in the batch
69 | images_np = []
70 | for b in range(image_tensor.size(0)):
71 | one_image = image_tensor[b]
72 | one_image_np = tensor2im(one_image)
73 | images_np.append(one_image_np.reshape(1, *one_image_np.shape))
74 | images_np = np.concatenate(images_np, axis=0)
75 | if tile:
76 | images_tiled = tile_images(images_np)
77 | return images_tiled
78 | else:
79 | if len(images_np.shape) == 4 and images_np.shape[0] == 1:
80 | images_np = images_np[0]
81 | return images_np
82 |
83 | if image_tensor.dim() == 2:
84 | image_tensor = image_tensor.unsqueeze(0)
85 | image_numpy = image_tensor.detach().cpu().float().numpy()
86 | if normalize:
87 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
88 | else:
89 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
90 | image_numpy = np.clip(image_numpy, 0, 255)
91 | if image_numpy.shape[2] == 1:
92 | image_numpy = image_numpy[:, :, 0]
93 | return image_numpy.astype(imtype)
94 |
95 |
96 |
97 | def save_image(image_numpy, image_path, create_dir=False):
98 | if create_dir:
99 | os.makedirs(os.path.dirname(image_path), exist_ok=True)
100 | if len(image_numpy.shape) == 4:
101 | image_numpy = image_numpy[0]
102 | if len(image_numpy.shape) == 2:
103 | image_numpy = np.expand_dims(image_numpy, axis=2)
104 | if image_numpy.shape[2] == 1:
105 | image_numpy = np.repeat(image_numpy, 3, 2)
106 | image_pil = Image.fromarray(image_numpy)
107 |
108 | # save to png
109 | image_pil.save(image_path)
110 | # image_pil.save(image_path.replace('.jpg', '.png'))
111 |
112 |
113 | def save_torch_img(img, save_path):
114 | image_numpy = tensor2im(img,tile=False)
115 | save_image(image_numpy, save_path, create_dir=True)
116 | return image_numpy
117 |
118 |
119 |
120 | def mkdirs(paths):
121 | if isinstance(paths, list) and not isinstance(paths, str):
122 | for path in paths:
123 | mkdir(path)
124 | else:
125 | mkdir(paths)
126 |
127 |
128 | def mkdir(path):
129 | if not os.path.exists(path):
130 | os.makedirs(path)
131 |
132 |
133 | def atoi(text):
134 | return int(text) if text.isdigit() else text
135 |
136 |
137 | def natural_keys(text):
138 | '''
139 | alist.sort(key=natural_keys) sorts in human order
140 | http://nedbatchelder.com/blog/200712/human_sorting.html
141 | (See Toothy's implementation in the comments)
142 | '''
143 | return [atoi(c) for c in re.split(r'(\d+)', text)]
144 |
145 |
146 | def natural_sort(items):
147 | items.sort(key=natural_keys)
148 |
149 |
150 | def str2bool(v):
151 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
152 | return True
153 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
154 | return False
155 | else:
156 | raise argparse.ArgumentTypeError('Boolean value expected.')
157 |
158 |
159 | def find_class_in_module(target_cls_name, module):
160 | target_cls_name = target_cls_name.replace('_', '').lower()
161 | clslib = importlib.import_module(module)
162 | cls = None
163 | for name, clsobj in clslib.__dict__.items():
164 | if name.lower() == target_cls_name:
165 | cls = clsobj
166 |
167 | if cls is None:
168 | print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
169 | exit(0)
170 |
171 | return cls
172 |
173 |
174 | def save_network(net, label, epoch, opt):
175 | save_filename = '%s_net_%s.pth' % (epoch, label)
176 | save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
177 | torch.save(net.cpu().state_dict(), save_path)
178 | if len(opt.gpu_ids) and torch.cuda.is_available():
179 | net.cuda()
180 |
181 |
182 | def load_network(net, label, epoch, opt):
183 | save_filename = '%s_net_%s.pth' % (epoch, label)
184 | save_dir = os.path.join(opt.checkpoints_dir, opt.name)
185 | save_path = os.path.join(save_dir, save_filename)
186 | weights = torch.load(save_path)
187 | net.load_state_dict(weights)
188 | return net
189 |
190 |
191 | def copy_state_dict(state_dict, model, strip=None, replace=None):
192 | tgt_state = model.state_dict()
193 | copied_names = set()
194 | for name, param in state_dict.items():
195 | if strip is not None and replace is None and name.startswith(strip):
196 | name = name[len(strip):]
197 | if strip is not None and replace is not None:
198 | name = name.replace(strip, replace)
199 | if name not in tgt_state:
200 | continue
201 | if isinstance(param, torch.nn.Parameter):
202 | param = param.data
203 | if param.size() != tgt_state[name].size():
204 | print('mismatch:', name, param.size(), tgt_state[name].size())
205 | continue
206 | tgt_state[name].copy_(param)
207 | copied_names.add(name)
208 |
209 | missing = set(tgt_state.keys()) - copied_names
210 | if len(missing) > 0:
211 | print("missing keys in state_dict:", missing)
212 |
213 |
214 |
215 | def freeze_model(net):
216 | for param in net.parameters():
217 | param.requires_grad = False
218 | ###############################################################################
219 | # Code from
220 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py
221 | # Modified so it complies with the Cityscapes label map colors
222 | ###############################################################################
223 | def uint82bin(n, count=8):
224 | """returns the binary of integer n, count refers to amount of bits"""
225 | return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])
226 |
227 | def build_landmark_dict(ldmk_path):
228 | with open(ldmk_path) as f:
229 | lines = f.readlines()
230 | ldmk_dict = {}
231 | paths = []
232 | for line in lines:
233 | info = line.strip().split()
234 | key = info[-1]
235 | if "/" in key:
236 | key = key.split("/")[-1]
237 | # key = int(key.split(".")[0])
238 | value = info[:-1]
239 | paths.append(key)
240 | value = [float(it) for it in value]
241 | if len(info) == 106 * 2 + 1: # landmark+name
242 | value = [float(it) for it in info[:106 * 2]]
243 | elif len(info) == 106 * 2 + 1 + 6: # affmat+landmark+name
244 | value = [float(it) for it in info[6:106 * 2 + 6]]
245 | elif len(info) == 20 * 2 + 2: # mouth landmark+name
246 | value = [float(it) for it in info[:-1]]
247 | ldmk_dict[key] = value
248 | return ldmk_dict, paths
249 |
250 |
251 | def get_affine(src, dst):
252 | tform = trans.SimilarityTransform()
253 | tform.estimate(src, dst)
254 | M = tform.params[0:2, :]
255 | return M
256 |
257 | def affine_align_img(img, M, crop_size=224):
258 | warped = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0)
259 | return warped
260 |
261 | def calc_loop_idx(idx, loop_num):
262 | flag = -1 * ((idx // loop_num % 2) * 2 - 1)
263 | new_idx = -flag * (flag - 1) // 2 + flag * (idx % loop_num)
264 | return (new_idx + loop_num) % loop_num
265 |
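
calc_loop_idx above turns a monotonically increasing frame index into a ping-pong index, presumably so a short source clip can be looped back and forth smoothly instead of jumping back to frame 0. A quick check (assuming the repository root is on PYTHONPATH):

from util.util import calc_loop_idx

print([calc_loop_idx(i, 4) for i in range(10)])
# -> [0, 1, 2, 3, 3, 2, 1, 0, 0, 1]
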
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ntpath
3 | import time
4 | from . import util
5 | from . import html
6 | import scipy.misc
7 | import torch
8 | import torchvision.utils as vutils
9 | from torch.utils.tensorboard import SummaryWriter
10 | try:
11 | from StringIO import StringIO # Python 2.7
12 | except ImportError:
13 | from io import BytesIO # Python 3.x
14 |
15 | class Visualizer():
16 | def __init__(self, opt):
17 | self.opt = opt
18 | self.tf_log = opt.isTrain and opt.tf_log
19 | self.tensorboard = opt.isTrain and opt.tensorboard
20 | self.use_html = opt.isTrain and not opt.no_html
21 | self.win_size = opt.display_winsize
22 | self.name = opt.name
23 | if self.tf_log:
24 | import tensorflow as tf
25 | self.tf = tf
26 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
27 | self.writer = tf.summary.FileWriter(self.log_dir)
28 |
29 | if self.tensorboard:
30 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
31 | self.writer = SummaryWriter(self.log_dir, comment=opt.name)
32 |
33 | if self.use_html:
34 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
35 | self.img_dir = os.path.join(self.web_dir, 'images')
36 | print('create web directory %s...' % self.web_dir)
37 | util.mkdirs([self.web_dir, self.img_dir])
38 | if opt.isTrain:
39 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
40 | with open(self.log_name, "a") as log_file:
41 | now = time.strftime("%c")
42 | log_file.write('================ Training Loss (%s) ================\n' % now)
43 |
44 | # |visuals|: dictionary of images to display or save
45 | def display_current_results(self, visuals, epoch, step):
46 |
47 | ## convert tensors to numpy arrays
48 |
49 |
50 | if self.tf_log: # show images in tensorboard output
51 | img_summaries = []
52 | visuals = self.convert_visuals_to_numpy(visuals)
53 | for label, image_numpy in visuals.items():
54 | # Write the image to a string
55 | try:
56 | s = StringIO()
57 | except:
58 | s = BytesIO()
59 | if len(image_numpy.shape) >= 4:
60 | image_numpy = image_numpy[0]
61 | scipy.misc.toimage(image_numpy).save(s, format="jpeg")
62 | # Create an Image object
63 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1])
64 | # Create a Summary value
65 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))
66 |
67 | # Create and write Summary
68 | summary = self.tf.Summary(value=img_summaries)
69 | self.writer.add_summary(summary, step)
70 |
71 | if self.tensorboard: # show images in tensorboard output
72 | img_summaries = []
73 | for label, image_tensor in visuals.items():
74 | # In this branch the visuals are still torch tensors, so they can be passed
75 | # straight to torchvision's make_grid; the StringIO / scipy conversion is
76 | # only needed by the legacy tf_log writer above.
77 | # if len(image_numpy.shape) >= 4:
78 | #     image_numpy = image_numpy[0]
79 | #     scipy.misc.toimage(image_numpy).save(s, format="jpeg")
80 | # Create an Image object
81 | # self.writer.add_image(tag=label, img_tensor=image_numpy, global_step=step, dataformats='HWC')
82 | # Create a Summary value
83 | batch_size = image_tensor.size(0)
84 | # grid at most 16 images per label and write them to tensorboard
85 | x = vutils.make_grid(image_tensor[:min(batch_size, 16)], normalize=True, scale_each=True)
86 | self.writer.add_image(label, x, step)
87 |
88 |
89 |
90 | if self.use_html: # save images to a html file
91 | for label, image_numpy in visuals.items():
92 | if isinstance(image_numpy, list):
93 | for i in range(len(image_numpy)):
94 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s_%d.png' % (epoch, step, label, i))
95 | util.save_image(image_numpy[i], img_path)
96 | else:
97 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s.png' % (epoch, step, label))
98 | if len(image_numpy.shape) >= 4:
99 | image_numpy = image_numpy[0]
100 | util.save_image(image_numpy, img_path)
101 |
102 | # update website
103 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=5)
104 | for n in range(epoch, 0, -1):
105 | webpage.add_header('epoch [%d]' % n)
106 | ims = []
107 | txts = []
108 | links = []
109 |
110 | for label, image_numpy in visuals.items():
111 | if isinstance(image_numpy, list):
112 | for i in range(len(image_numpy)):
113 | img_path = 'epoch%.3d_iter%.3d_%s_%d.png' % (n, step, label, i)
114 | ims.append(img_path)
115 | txts.append(label+str(i))
116 | links.append(img_path)
117 | else:
118 | img_path = 'epoch%.3d_iter%.3d_%s.png' % (n, step, label)
119 | ims.append(img_path)
120 | txts.append(label)
121 | links.append(img_path)
122 | if len(ims) < 10:
123 | webpage.add_images(ims, txts, links, width=self.win_size)
124 | else:
125 | num = int(round(len(ims)/2.0))
126 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
127 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
128 | webpage.save()
129 |
130 | # errors: dictionary of error labels and values
131 | def plot_current_errors(self, errors, step):
132 | if self.tf_log:
133 | for tag, value in errors.items():
134 | value = value.mean().float()
135 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
136 | self.writer.add_summary(summary, step)
137 |
138 | if self.tensorboard:
139 | for tag, value in errors.items():
140 | value = value.mean().float()
141 | self.writer.add_scalar(tag=tag, scalar_value=value, global_step=step)
142 |
143 | # errors: same format as |errors| of plotCurrentErrors
144 | def print_current_errors(self, opt, epoch, i, errors, t):
145 | message = opt.name + ' (epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
146 | for k, v in errors.items():
147 | #print(v)
148 | #if v != 0:
149 | v = v.mean().float()
150 | message += '%s: %.3f ' % (k, v)
151 |
152 | print(message)
153 | with open(self.log_name, "a") as log_file:
154 | log_file.write('%s\n' % message)
155 |
156 | def convert_visuals_to_numpy(self, visuals):
157 | for key, t in visuals.items():
158 | tile = self.opt.batchSize > 8
159 | if 'input_label' == key:
160 | t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile)
161 | else:
162 | t = util.tensor2im(t, tile=tile)
163 | visuals[key] = t
164 | return visuals
165 |
166 | # save image to the disk
167 | def save_images(self, webpage, visuals, image_path):
168 | visuals = self.convert_visuals_to_numpy(visuals)
169 |
170 | image_dir = webpage.get_image_dir()
171 | short_path = ntpath.basename(image_path[0])
172 | name = os.path.splitext(short_path)[0]
173 |
174 | webpage.add_header(name)
175 | ims = []
176 | txts = []
177 | links = []
178 |
179 | for label, image_numpy in visuals.items():
180 | image_name = os.path.join(label, '%s.png' % (name))
181 | save_path = os.path.join(image_dir, image_name)
182 | util.save_image(image_numpy, save_path, create_dir=True)
183 |
184 | ims.append(image_name)
185 | txts.append(label)
186 | links.append(image_name)
187 | webpage.add_images(ims, txts, links, width=self.win_size)
188 |
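
A minimal logging sketch for the scalar paths above (not part of the repository); the Namespace only carries the fields Visualizer reads here and is illustrative, not the project's full option set:

import os
import torch
from argparse import Namespace
from util.visualizer import Visualizer

opt = Namespace(isTrain=True, tf_log=False, tensorboard=True, no_html=True,
                display_winsize=256, name='demo', checkpoints_dir='./checkpoints',
                batchSize=2)
os.makedirs(os.path.join(opt.checkpoints_dir, opt.name), exist_ok=True)

visualizer = Visualizer(opt)
losses = {'G_GAN': torch.tensor(0.8), 'D_real': torch.tensor(0.4)}
visualizer.plot_current_errors(losses, step=100)                              # scalars -> tensorboard
visualizer.print_current_errors(opt, epoch=1, i=100, errors=losses, t=0.25)   # console + loss_log.txt
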
--------------------------------------------------------------------------------