├── .gitignore ├── LICENSE ├── README.md ├── config └── AudioConfig.py ├── data ├── __init__.py ├── base_dataset.py └── voxtest_dataset.py ├── experiments └── demo_vox.sh ├── inference.py ├── misc ├── Audio_Source.zip ├── Input.zip ├── Mouth_Source.zip ├── Pose_Source.zip ├── demo.csv ├── demo.gif ├── demo_id.gif ├── method.png └── output.gif ├── models ├── __init__.py ├── av_model.py └── networks │ ├── FAN_feature_extractor.py │ ├── __init__.py │ ├── __pycache__ │ ├── FAN_feature_extractor.cpython-36.pyc │ ├── FAN_feature_extractor.cpython-37.pyc │ ├── Voxceleb_model.cpython-36.pyc │ ├── Voxceleb_model.cpython-37.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── architecture.cpython-36.pyc │ ├── architecture.cpython-37.pyc │ ├── audio_architecture.cpython-36.pyc │ ├── audio_architecture.cpython-37.pyc │ ├── audio_network.cpython-37.pyc │ ├── base_network.cpython-36.pyc │ ├── base_network.cpython-37.pyc │ ├── discriminator.cpython-36.pyc │ ├── discriminator.cpython-37.pyc │ ├── encoder.cpython-36.pyc │ ├── encoder.cpython-37.pyc │ ├── generator.cpython-36.pyc │ ├── generator.cpython-37.pyc │ ├── loss.cpython-36.pyc │ ├── loss.cpython-37.pyc │ ├── normalization.cpython-36.pyc │ ├── normalization.cpython-37.pyc │ ├── stylegan2.cpython-36.pyc │ ├── stylegan2.cpython-37.pyc │ ├── vision_network.cpython-36.pyc │ └── vision_network.cpython-37.pyc │ ├── architecture.py │ ├── audio_network.py │ ├── base_network.py │ ├── discriminator.py │ ├── encoder.py │ ├── generator.py │ ├── loss.py │ ├── sync_batchnorm │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── batchnorm.cpython-36.pyc │ │ ├── batchnorm.cpython-37.pyc │ │ ├── comm.cpython-36.pyc │ │ ├── comm.cpython-37.pyc │ │ ├── replicate.cpython-36.pyc │ │ ├── replicate.cpython-37.pyc │ │ ├── scatter_gather.cpython-36.pyc │ │ └── scatter_gather.cpython-37.pyc │ ├── batchnorm.py │ ├── batchnorm_reimpl.py │ ├── comm.py │ ├── replicate.py │ ├── 
scatter_gather.py │ └── unittest.py │ ├── util.py │ └── vision_network.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── requirements.txt ├── scripts ├── align_68.py └── prepare_testing_files.py └── util ├── __init__.py ├── html.py ├── iter_counter.py ├── util.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | **/*.pyc 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 
| 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | # custom 108 | demo/ 109 | checkpoints/ 110 | results/ 111 | .vscode 112 | .idea 113 | *.pkl 114 | *.pkl.json 115 | *.log.json 116 | work_dirs/ 117 | *.avi 118 | 119 | # Pytorch 120 | *.pth 121 | *.tar 122 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. 
Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution 4.0 International Public License 58 | 59 | By exercising the Licensed Rights (defined below), You accept and agree 60 | to be bound by the terms and conditions of this Creative Commons 61 | Attribution 4.0 International Public License ("Public License"). 
To the 62 | extent this Public License may be interpreted as a contract, You are 63 | granted the Licensed Rights in consideration of Your acceptance of 64 | these terms and conditions, and the Licensor grants You such rights in 65 | consideration of benefits the Licensor receives from making the 66 | Licensed Material available under these terms and conditions. 67 | 68 | Section 1 -- Definitions. 69 | 70 | a. Adapted Material means material subject to Copyright and Similar 71 | Rights that is derived from or based upon the Licensed Material 72 | and in which the Licensed Material is translated, altered, 73 | arranged, transformed, or otherwise modified in a manner requiring 74 | permission under the Copyright and Similar Rights held by the 75 | Licensor. For purposes of this Public License, where the Licensed 76 | Material is a musical work, performance, or sound recording, 77 | Adapted Material is always produced where the Licensed Material is 78 | synched in timed relation with a moving image. 79 | 80 | b. Adapter's License means the license You apply to Your Copyright 81 | and Similar Rights in Your contributions to Adapted Material in 82 | accordance with the terms and conditions of this Public License. 83 | 84 | c. Copyright and Similar Rights means copyright and/or similar rights 85 | closely related to copyright including, without limitation, 86 | performance, broadcast, sound recording, and Sui Generis Database 87 | Rights, without regard to how the rights are labeled or 88 | categorized. For purposes of this Public License, the rights 89 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 90 | Rights. 91 | 92 | d. Effective Technological Measures means those measures that, in the 93 | absence of proper authority, may not be circumvented under laws 94 | fulfilling obligations under Article 11 of the WIPO Copyright 95 | Treaty adopted on December 20, 1996, and/or similar international 96 | agreements. 97 | 98 | e. 
Exceptions and Limitations means fair use, fair dealing, and/or 99 | any other exception or limitation to Copyright and Similar Rights 100 | that applies to Your use of the Licensed Material. 101 | 102 | f. Licensed Material means the artistic or literary work, database, 103 | or other material to which the Licensor applied this Public 104 | License. 105 | 106 | g. Licensed Rights means the rights granted to You subject to the 107 | terms and conditions of this Public License, which are limited to 108 | all Copyright and Similar Rights that apply to Your use of the 109 | Licensed Material and that the Licensor has authority to license. 110 | 111 | h. Licensor means the individual(s) or entity(ies) granting rights 112 | under this Public License. 113 | 114 | i. Share means to provide material to the public by any means or 115 | process that requires permission under the Licensed Rights, such 116 | as reproduction, public display, public performance, distribution, 117 | dissemination, communication, or importation, and to make material 118 | available to the public including in ways that members of the 119 | public may access the material from a place and at a time 120 | individually chosen by them. 121 | 122 | j. Sui Generis Database Rights means rights other than copyright 123 | resulting from Directive 96/9/EC of the European Parliament and of 124 | the Council of 11 March 1996 on the legal protection of databases, 125 | as amended and/or succeeded, as well as other essentially 126 | equivalent rights anywhere in the world. 127 | 128 | k. You means the individual or entity exercising the Licensed Rights 129 | under this Public License. Your has a corresponding meaning. 130 | 131 | Section 2 -- Scope. 132 | 133 | a. License grant. 134 | 135 | 1. 
Subject to the terms and conditions of this Public License, 136 | the Licensor hereby grants You a worldwide, royalty-free, 137 | non-sublicensable, non-exclusive, irrevocable license to 138 | exercise the Licensed Rights in the Licensed Material to: 139 | 140 | a. reproduce and Share the Licensed Material, in whole or 141 | in part; and 142 | 143 | b. produce, reproduce, and Share Adapted Material. 144 | 145 | 2. Exceptions and Limitations. For the avoidance of doubt, where 146 | Exceptions and Limitations apply to Your use, this Public 147 | License does not apply, and You do not need to comply with 148 | its terms and conditions. 149 | 150 | 3. Term. The term of this Public License is specified in Section 151 | 6(a). 152 | 153 | 4. Media and formats; technical modifications allowed. The 154 | Licensor authorizes You to exercise the Licensed Rights in 155 | all media and formats whether now known or hereafter created, 156 | and to make technical modifications necessary to do so. The 157 | Licensor waives and/or agrees not to assert any right or 158 | authority to forbid You from making technical modifications 159 | necessary to exercise the Licensed Rights, including 160 | technical modifications necessary to circumvent Effective 161 | Technological Measures. For purposes of this Public License, 162 | simply making modifications authorized by this Section 2(a) 163 | (4) never produces Adapted Material. 164 | 165 | 5. Downstream recipients. 166 | 167 | a. Offer from the Licensor -- Licensed Material. Every 168 | recipient of the Licensed Material automatically 169 | receives an offer from the Licensor to exercise the 170 | Licensed Rights under the terms and conditions of this 171 | Public License. 172 | 173 | b. No downstream restrictions. 
You may not offer or impose 174 | any additional or different terms or conditions on, or 175 | apply any Effective Technological Measures to, the 176 | Licensed Material if doing so restricts exercise of the 177 | Licensed Rights by any recipient of the Licensed 178 | Material. 179 | 180 | 6. No endorsement. Nothing in this Public License constitutes or 181 | may be construed as permission to assert or imply that You 182 | are, or that Your use of the Licensed Material is, connected 183 | with, or sponsored, endorsed, or granted official status by, 184 | the Licensor or others designated to receive attribution as 185 | provided in Section 3(a)(1)(A)(i). 186 | 187 | b. Other rights. 188 | 189 | 1. Moral rights, such as the right of integrity, are not 190 | licensed under this Public License, nor are publicity, 191 | privacy, and/or other similar personality rights; however, to 192 | the extent possible, the Licensor waives and/or agrees not to 193 | assert any such rights held by the Licensor to the limited 194 | extent necessary to allow You to exercise the Licensed 195 | Rights, but not otherwise. 196 | 197 | 2. Patent and trademark rights are not licensed under this 198 | Public License. 199 | 200 | 3. To the extent possible, the Licensor waives any right to 201 | collect royalties from You for the exercise of the Licensed 202 | Rights, whether directly or through a collecting society 203 | under any voluntary or waivable statutory or compulsory 204 | licensing scheme. In all other cases the Licensor expressly 205 | reserves any right to collect such royalties. 206 | 207 | Section 3 -- License Conditions. 208 | 209 | Your exercise of the Licensed Rights is expressly made subject to the 210 | following conditions. 211 | 212 | a. Attribution. 213 | 214 | 1. If You Share the Licensed Material (including in modified 215 | form), You must: 216 | 217 | a. retain the following if it is supplied by the Licensor 218 | with the Licensed Material: 219 | 220 | i. 
identification of the creator(s) of the Licensed 221 | Material and any others designated to receive 222 | attribution, in any reasonable manner requested by 223 | the Licensor (including by pseudonym if 224 | designated); 225 | 226 | ii. a copyright notice; 227 | 228 | iii. a notice that refers to this Public License; 229 | 230 | iv. a notice that refers to the disclaimer of 231 | warranties; 232 | 233 | v. a URI or hyperlink to the Licensed Material to the 234 | extent reasonably practicable; 235 | 236 | b. indicate if You modified the Licensed Material and 237 | retain an indication of any previous modifications; and 238 | 239 | c. indicate the Licensed Material is licensed under this 240 | Public License, and include the text of, or the URI or 241 | hyperlink to, this Public License. 242 | 243 | 2. You may satisfy the conditions in Section 3(a)(1) in any 244 | reasonable manner based on the medium, means, and context in 245 | which You Share the Licensed Material. For example, it may be 246 | reasonable to satisfy the conditions by providing a URI or 247 | hyperlink to a resource that includes the required 248 | information. 249 | 250 | 3. If requested by the Licensor, You must remove any of the 251 | information required by Section 3(a)(1)(A) to the extent 252 | reasonably practicable. 253 | 254 | 4. If You Share Adapted Material You produce, the Adapter's 255 | License You apply must not prevent recipients of the Adapted 256 | Material from complying with this Public License. 257 | 258 | Section 4 -- Sui Generis Database Rights. 259 | 260 | Where the Licensed Rights include Sui Generis Database Rights that 261 | apply to Your use of the Licensed Material: 262 | 263 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 264 | to extract, reuse, reproduce, and Share all or a substantial 265 | portion of the contents of the database; 266 | 267 | b. 
if You include all or a substantial portion of the database 268 | contents in a database in which You have Sui Generis Database 269 | Rights, then the database in which You have Sui Generis Database 270 | Rights (but not its individual contents) is Adapted Material; and 271 | 272 | c. You must comply with the conditions in Section 3(a) if You Share 273 | all or a substantial portion of the contents of the database. 274 | 275 | For the avoidance of doubt, this Section 4 supplements and does not 276 | replace Your obligations under this Public License where the Licensed 277 | Rights include other Copyright and Similar Rights. 278 | 279 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 280 | 281 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 282 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 283 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 284 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 285 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 286 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 287 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 288 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 289 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 290 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 291 | 292 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 293 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 294 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 295 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 296 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 297 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 298 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 299 | DAMAGES. 
WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 300 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 301 | 302 | c. The disclaimer of warranties and limitation of liability provided 303 | above shall be interpreted in a manner that, to the extent 304 | possible, most closely approximates an absolute disclaimer and 305 | waiver of all liability. 306 | 307 | Section 6 -- Term and Termination. 308 | 309 | a. This Public License applies for the term of the Copyright and 310 | Similar Rights licensed here. However, if You fail to comply with 311 | this Public License, then Your rights under this Public License 312 | terminate automatically. 313 | 314 | b. Where Your right to use the Licensed Material has terminated under 315 | Section 6(a), it reinstates: 316 | 317 | 1. automatically as of the date the violation is cured, provided 318 | it is cured within 30 days of Your discovery of the 319 | violation; or 320 | 321 | 2. upon express reinstatement by the Licensor. 322 | 323 | For the avoidance of doubt, this Section 6(b) does not affect any 324 | right the Licensor may have to seek remedies for Your violations 325 | of this Public License. 326 | 327 | c. For the avoidance of doubt, the Licensor may also offer the 328 | Licensed Material under separate terms or conditions or stop 329 | distributing the Licensed Material at any time; however, doing so 330 | will not terminate this Public License. 331 | 332 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 333 | License. 334 | 335 | Section 7 -- Other Terms and Conditions. 336 | 337 | a. The Licensor shall not be bound by any additional or different 338 | terms or conditions communicated by You unless expressly agreed. 339 | 340 | b. Any arrangements, understandings, or agreements regarding the 341 | Licensed Material not stated herein are separate from and 342 | independent of the terms and conditions of this Public License. 343 | 344 | Section 8 -- Interpretation. 345 | 346 | a. 
For the avoidance of doubt, this Public License does not, and 347 | shall not be interpreted to, reduce, limit, restrict, or impose 348 | conditions on any use of the Licensed Material that could lawfully 349 | be made without permission under this Public License. 350 | 351 | b. To the extent possible, if any provision of this Public License is 352 | deemed unenforceable, it shall be automatically reformed to the 353 | minimum extent necessary to make it enforceable. If the provision 354 | cannot be reformed, it shall be severed from this Public License 355 | without affecting the enforceability of the remaining terms and 356 | conditions. 357 | 358 | c. No term or condition of this Public License will be waived and no 359 | failure to comply consented to unless expressly agreed to by the 360 | Licensor. 361 | 362 | d. Nothing in this Public License constitutes or may be interpreted 363 | as a limitation upon, or waiver of, any privileges and immunities 364 | that apply to the Licensor or You, including from the legal 365 | processes of any jurisdiction or authority. 366 | 367 | ======================================================================= 368 | 369 | Creative Commons is not a party to its public licenses. 370 | Notwithstanding, Creative Commons may elect to apply one of its public 371 | licenses to material it publishes and in those instances will be 372 | considered the "Licensor." 
Except for the limited purpose of indicating 373 | that material is shared under a Creative Commons public license or as 374 | otherwise permitted by the Creative Commons policies published at 375 | creativecommons.org/policies, Creative Commons does not authorize the 376 | use of the trademark "Creative Commons" or any other trademark or logo 377 | of Creative Commons without its prior written consent including, 378 | without limitation, in connection with any unauthorized modifications 379 | to any of its public licenses or any other arrangements, 380 | understandings, or agreements concerning use of licensed material. For 381 | the avoidance of doubt, this paragraph does not form part of the public 382 | licenses. 383 | 384 | Creative Commons may be contacted at creativecommons.org. 385 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pose-Controllable Talking Face Generation by Implicitly Modularized Audio-Visual Representation (CVPR 2021) 2 | 3 | [Hang Zhou](https://hangz-nju-cuhk.github.io/), Yasheng Sun, [Wayne Wu](https://wywu.github.io/), [Chen Change Loy](http://personal.ie.cuhk.edu.hk/~ccloy/), [Xiaogang Wang](http://www.ee.cuhk.edu.hk/~xgwang/), and [Ziwei Liu](https://liuziwei7.github.io/). 4 | 5 | 6 | 7 | 8 | ### [Project](https://hangz-nju-cuhk.github.io/projects/PC-AVS) | [Paper](https://arxiv.org/abs/2104.11116) | [Demo](https://www.youtube.com/watch?v=lNQQHIggnUg) 9 | 10 | 11 | We propose **Pose-Controllable Audio-Visual System (PC-AVS)**, 12 | which achieves free pose control when driving arbitrary talking faces with audios. Instead of learning pose motions from audios, we leverage another pose source video to compensate only for head motions. 13 | The key is to devise an implicit low-dimension pose code that is free of mouth shape or identity information. 
14 | In this way, audio-visual representations are modularized into spaces of three key factors: speech content, head pose, and identity information. 15 | 16 | 17 | 18 | ## Requirements 19 | * Python 3.6 and [PyTorch](https://pytorch.org/) 1.3.0 are used. Basic requirements are listed in 'requirements.txt'. 20 | 21 | ``` 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | 26 | ## Quick Start: Generate Demo Results 27 | * Download the pre-trained [checkpoints](https://drive.google.com/file/d/1Zehr3JLIpzdg2S5zZrhIbpYPKF-4gKU_/view?usp=sharing). 28 | 29 | * Create the default folder ```./checkpoints``` and 30 | unzip the ```demo.zip``` at ```./checkpoints/demo```. There should be 5 ```pth```s in it. 31 | 32 | * Unzip all ```*.zip``` files within the ```misc``` folder. 33 | 34 | * Run the demo scripts: 35 | ``` bash 36 | bash experiments/demo_vox.sh 37 | ``` 38 | 39 | 40 | * The ```--gen_video``` argument is on by default; 41 | [ffmpeg](https://www.ffmpeg.org/) >= 4.0.0 is required to use this flag on Linux systems. 42 | All frames along with an ```avconcat.mp4``` video file will be saved in the ```./id_517600055_pose_517600078_audio_681600002/results``` folder. 43 | 44 | 45 | 46 | 47 | From left to right are the *reference input*, the *generated results*, 48 | the *pose source video* and the *synced original video* with the driving audio. 49 | 50 | ## Prepare Testing Meta Data 51 | 52 | * ### Automatic VoxCeleb2 Data Formulation 53 | 54 | The inference code ```experiments/demo_vox.sh``` refers to ```./misc/demo.csv``` for testing data paths. 55 | On Linux systems, any applicable ```csv``` file can be created automatically by running: 56 | 57 | ```bash 58 | python scripts/prepare_testing_files.py 59 | ``` 60 | 61 | Then modify the ```meta_path_vox``` in ```experiments/demo_vox.sh``` to ```'./misc/demo2.csv'``` and run 62 | 63 | ``` bash 64 | bash experiments/demo_vox.sh 65 | ``` 66 | An additional result will then be saved.
67 | 68 | * ### Metadata Details 69 | 70 | In detail, ```scripts/prepare_testing_files.py``` provides several flags that offer great flexibility when formulating the metadata: 71 | 72 | 1. ```--src_pose_path``` denotes the driving pose source path. 73 | It can be an ```mp4``` file or a folder containing frames in the form of ```%06d.jpg``` starting from 0. 74 | 75 | 2. ```--src_audio_path``` denotes the audio source's path. 76 | It can be an ```mp3``` audio file or an ```mp4``` video file. If a video is given, 77 | the frames will be automatically saved in ```./misc/Mouth_Source/video_name```, and the ```--src_mouth_frame_path``` flag is disabled. 78 | 79 | 3. ```--src_mouth_frame_path```. When ```--src_audio_path``` is not a video path, 80 | this flag can provide the folder containing the video frames synced with the source audio. 81 | 82 | 4. ```--src_input_path``` is the path to the input reference image. When the path is a video file, we will convert it to frames. 83 | 84 | 5. ```--csv_path``` is the path to the to-be-saved metadata. 85 | 86 | You can manually modify the metadata ```csv``` file or add lines to it according to the rules defined in the ```scripts/prepare_testing_files.py``` file or the dataloader ```data/voxtest_dataset.py```. 87 | 88 | We provide a number of demo choices in the ```misc``` folder, including several used in our [video](https://www.youtube.com/watch?v=lNQQHIggnUg). 89 | Feel free to rearrange them, even across folders. You are also welcome to record your own audio files. 90 | 91 | * ### Self-Prepared Data Processing 92 | Our model handles only **VoxCeleb2-like** cropped data, so pre-processing is needed for self-prepared data. 93 | 94 | To process self-prepared data, [face-alignment](https://github.com/1adrianb/face-alignment) is needed.
It can be installed by running 95 | ``` 96 | pip install face-alignment 97 | ``` 98 | 99 | Assuming that a video has already been processed into a ```[name]``` folder by the previous steps through ```prepare_testing_files.py```, 100 | you can run 101 | ``` 102 | python scripts/align_68.py --folder_path [name] 103 | ``` 104 | 105 | The cropped images will be saved in an additional ```[name_cropped]``` folder. 106 | Then you can manually change the ```demo.csv``` file or alter the directory folder path and run the preprocessing file again. 107 | 108 | ## Have More Fun 109 | * We also support synthesizing and driving a talking head solely from audio. 110 | * Download the pre-trained [checkpoints](https://drive.google.com/file/d/1K-UZYRn7Oz2VEumrIRMCvr69FhQjQNx8/view?usp=share_link). 111 | * Check out the *speech2talkingface* branch. 112 | ``` bash 113 | git checkout speech2talkingface 114 | ``` 115 | * Follow steps similar to the Quick Start and run the demo scripts. 116 | ``` bash 117 | bash experiments/demo_vox.sh 118 | ``` 119 | 120 | 121 | 122 | From left to right are the *generated results*, 123 | the *pose source video* and the *synced original video* with the driving audio. 124 | 125 | 126 | ## Train Your Own Model 127 | * Not supported yet. 128 | 129 | ## License and Citation 130 | 131 | The usage of this software is under [CC-BY-4.0](LICENSE).
132 | ``` 133 | @InProceedings{zhou2021pose, 134 | author = {Zhou, Hang and Sun, Yasheng and Wu, Wayne and Loy, Chen Change and Wang, Xiaogang and Liu, Ziwei}, 135 | title = {Pose-Controllable Talking Face Generation by Implicitly Modularized Audio-Visual Representation}, 136 | booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 137 | year = {2021} 138 | } 139 | 140 | @inproceedings{sun2021speech2talking, 141 | title={Speech2Talking-Face: Inferring and Driving a Face with Synchronized Audio-Visual Representation.}, 142 | author={Sun, Yasheng and Zhou, Hang and Liu, Ziwei and Koike, Hideki}, 143 | booktitle={IJCAI}, 144 | volume={2}, 145 | pages={4}, 146 | year={2021} 147 | } 148 | ``` 149 | 150 | ## Acknowledgement 151 | * The structure of this codebase is borrowed from [SPADE](https://github.com/NVlabs/SPADE). 152 | * The generator is borrowed from [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch). 153 | * The audio encoder is borrowed from [voxceleb_trainer](https://github.com/clovaai/voxceleb_trainer). 
154 | -------------------------------------------------------------------------------- /config/AudioConfig.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import librosa.filters 3 | import numpy as np 4 | from scipy import signal 5 | from scipy.io import wavfile 6 | import lws 7 | 8 | 9 | class AudioConfig: 10 | def __init__(self, frame_rate=25, 11 | sample_rate=16000, 12 | num_mels=80, 13 | fft_size=1280, 14 | hop_size=160, 15 | num_frames_per_clip=5, 16 | save_mel=True 17 | ): 18 | self.frame_rate = frame_rate 19 | self.sample_rate = sample_rate 20 | self.num_bins_per_frame = int(sample_rate / hop_size / frame_rate) 21 | self.num_frames_per_clip = num_frames_per_clip 22 | self.silence_threshold = 2 23 | self.num_mels = num_mels 24 | self.save_mel = save_mel 25 | self.fmin = 125 26 | self.fmax = 7600 27 | self.fft_size = fft_size 28 | self.hop_size = hop_size 29 | self.frame_shift_ms = None 30 | self.min_level_db = -100 31 | self.ref_level_db = 20 32 | self.rescaling = True 33 | self.rescaling_max = 0.999 34 | self.allow_clipping_in_normalization = True 35 | self.log_scale_min = -32.23619130191664 36 | self.norm_audio = True 37 | self.with_phase = False 38 | 39 | def load_wav(self, path): 40 | return librosa.core.load(path, sr=self.sample_rate)[0] 41 | 42 | def audio_normalize(self, samples, desired_rms=0.1, eps=1e-4): 43 | rms = np.maximum(eps, np.sqrt(np.mean(samples ** 2))) 44 | samples = samples * (desired_rms / rms) 45 | return samples 46 | 47 | def generate_spectrogram_magphase(self, audio): 48 | spectro = librosa.core.stft(audio, hop_length=self.get_hop_size(), n_fft=self.fft_size, center=True) 49 | spectro_mag, spectro_phase = librosa.core.magphase(spectro) 50 | spectro_mag = np.expand_dims(spectro_mag, axis=0) 51 | if self.with_phase: 52 | spectro_phase = np.expand_dims(np.angle(spectro_phase), axis=0) 53 | return spectro_mag, spectro_phase 54 | else: 55 | return spectro_mag 56 | 57 | def 
save_wav(self, wav, path): 58 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 59 | wavfile.write(path, self.sample_rate, wav.astype(np.int16)) 60 | 61 | def trim(self, quantized): 62 | start, end = self.start_and_end_indices(quantized, self.silence_threshold) 63 | return quantized[start:end] 64 | 65 | def adjust_time_resolution(self, quantized, mel): 66 | """Adjust time resolution by repeating features 67 | 68 | Args: 69 | quantized (ndarray): (T,) 70 | mel (ndarray): (N, D) 71 | 72 | Returns: 73 | tuple: Tuple of (T,) and (T, D) 74 | """ 75 | assert len(quantized.shape) == 1 76 | assert len(mel.shape) == 2 77 | 78 | upsample_factor = quantized.size // mel.shape[0] 79 | mel = np.repeat(mel, upsample_factor, axis=0) 80 | n_pad = quantized.size - mel.shape[0] 81 | if n_pad != 0: 82 | assert n_pad > 0 83 | mel = np.pad(mel, [(0, n_pad), (0, 0)], mode="constant", constant_values=0) 84 | 85 | # trim 86 | start, end = self.start_and_end_indices(quantized, self.silence_threshold) 87 | 88 | return quantized[start:end], mel[start:end, :] 89 | 90 | adjast_time_resolution = adjust_time_resolution # 'adjust' is correct spelling, this is for compatibility 91 | 92 | def start_and_end_indices(self, quantized, silence_threshold=2): 93 | for start in range(quantized.size): 94 | if abs(quantized[start] - 127) > silence_threshold: 95 | break 96 | for end in range(quantized.size - 1, 1, -1): 97 | if abs(quantized[end] - 127) > silence_threshold: 98 | break 99 | 100 | assert abs(quantized[start] - 127) > silence_threshold 101 | assert abs(quantized[end] - 127) > silence_threshold 102 | 103 | return start, end 104 | 105 | def melspectrogram(self, y): 106 | D = self._lws_processor().stft(y).T 107 | S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db 108 | if not self.allow_clipping_in_normalization: 109 | assert S.max() <= 0 and S.min() - self.min_level_db >= 0 110 | return self._normalize(S) 111 | 112 | def get_hop_size(self): 113 | hop_size = self.hop_size 114 | if 
hop_size is None: 115 | assert self.frame_shift_ms is not None 116 | hop_size = int(self.frame_shift_ms / 1000 * self.sample_rate) 117 | return hop_size 118 | 119 | def _lws_processor(self): 120 | return lws.lws(self.fft_size, self.get_hop_size(), mode="speech") 121 | 122 | def lws_num_frames(self, length, fsize, fshift): 123 | """Compute number of time frames of lws spectrogram 124 | """ 125 | pad = (fsize - fshift) 126 | if length % fshift == 0: 127 | M = (length + pad * 2 - fsize) // fshift + 1 128 | else: 129 | M = (length + pad * 2 - fsize) // fshift + 2 130 | return M 131 | 132 | def lws_pad_lr(self, x, fsize, fshift): 133 | """Compute left and right padding lws internally uses 134 | """ 135 | M = self.lws_num_frames(len(x), fsize, fshift) 136 | pad = (fsize - fshift) 137 | T = len(x) + 2 * pad 138 | r = (M - 1) * fshift + fsize - T 139 | return pad, pad + r 140 | 141 | 142 | def _linear_to_mel(self, spectrogram): 143 | global _mel_basis 144 | _mel_basis = self._build_mel_basis() 145 | return np.dot(_mel_basis, spectrogram) 146 | 147 | def _build_mel_basis(self): 148 | assert self.fmax <= self.sample_rate // 2 149 | return librosa.filters.mel(self.sample_rate, self.fft_size, 150 | fmin=self.fmin, fmax=self.fmax, 151 | n_mels=self.num_mels) 152 | 153 | def _amp_to_db(self, x): 154 | min_level = np.exp(self.min_level_db / 20 * np.log(10)) 155 | return 20 * np.log10(np.maximum(min_level, x)) 156 | 157 | def _db_to_amp(self, x): 158 | return np.power(10.0, x * 0.05) 159 | 160 | def _normalize(self, S): 161 | return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1) 162 | 163 | def _denormalize(self, S): 164 | return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db 165 | 166 | def read_audio(self, audio_path): 167 | wav = self.load_wav(audio_path) 168 | if self.norm_audio: 169 | wav = self.audio_normalize(wav) 170 | else: 171 | wav = wav / np.abs(wav).max() 172 | 173 | return wav 174 | 175 | def audio_to_spectrogram(self, wav): 176 | if 
self.save_mel: 177 | spectrogram = self.melspectrogram(wav).astype(np.float32).T 178 | else: 179 | spectrogram = self.generate_spectrogram_magphase(wav) 180 | return spectrogram 181 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch.utils.data 3 | from data.base_dataset import BaseDataset 4 | 5 | 6 | def find_dataset_using_name(dataset_name): 7 | # Given the option --dataset [datasetname], 8 | # the file "datasets/datasetname_dataset.py" 9 | # will be imported. 10 | dataset_filename = "data." + dataset_name + "_dataset" 11 | datasetlib = importlib.import_module(dataset_filename) 12 | 13 | # In the file, the class called DatasetNameDataset() will 14 | # be instantiated. It has to be a subclass of BaseDataset, 15 | # and it is case-insensitive. 16 | dataset = None 17 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 18 | for name, cls in datasetlib.__dict__.items(): 19 | if name.lower() == target_dataset_name.lower() \ 20 | and issubclass(cls, BaseDataset): 21 | dataset = cls 22 | 23 | if dataset is None: 24 | raise ValueError("In %s.py, there should be a subclass of BaseDataset " 25 | "with class name that matches %s in lowercase." 
% 26 | (dataset_filename, target_dataset_name)) 27 | 28 | return dataset 29 | 30 | 31 | def get_option_setter(dataset_name): 32 | dataset_class = find_dataset_using_name(dataset_name) 33 | return dataset_class.modify_commandline_options 34 | 35 | 36 | def create_dataloader(opt): 37 | dataset_modes = opt.dataset_mode.split(',') 38 | if len(dataset_modes) == 1: 39 | dataset = find_dataset_using_name(opt.dataset_mode) 40 | instance = dataset() 41 | instance.initialize(opt) 42 | print("dataset [%s] of size %d was created" % 43 | (type(instance).__name__, len(instance))) 44 | if not opt.isTrain: 45 | shuffle = False 46 | else: 47 | shuffle = True 48 | dataloader = torch.utils.data.DataLoader( 49 | instance, 50 | batch_size=opt.batchSize, 51 | shuffle=shuffle, 52 | num_workers=int(opt.nThreads), 53 | drop_last=opt.isTrain 54 | ) 55 | return dataloader 56 | 57 | else: 58 | dataloader_dict = {} 59 | for dataset_mode in dataset_modes: 60 | dataset = find_dataset_using_name(dataset_mode) 61 | instance = dataset() 62 | instance.initialize(opt) 63 | print("dataset [%s] of size %d was created" % 64 | (type(instance).__name__, len(instance))) 65 | if not opt.isTrain: 66 | shuffle = not opt.defined_driven 67 | else: 68 | shuffle = True 69 | dataloader = torch.utils.data.DataLoader( 70 | instance, 71 | batch_size=opt.batchSize, 72 | shuffle=shuffle, 73 | num_workers=int(opt.nThreads), 74 | drop_last=opt.isTrain 75 | ) 76 | dataloader_dict[dataset_mode] = dataloader 77 | return dataloader_dict 78 | 79 | 80 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torch 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import cv2 6 | 7 | 8 | class BaseDataset(data.Dataset): 9 | def __init__(self): 10 | super(BaseDataset, self).__init__() 11 | 12 | @staticmethod 13 | def 
modify_commandline_options(parser, is_train): 14 | return parser 15 | 16 | def initialize(self, opt): 17 | pass 18 | 19 | def to_Tensor(self, img): 20 | if img.ndim == 3: 21 | wrapped_img = img.transpose(2, 0, 1) / 255.0 22 | elif img.ndim == 4: 23 | wrapped_img = img.transpose(0, 3, 1, 2) / 255.0 24 | else: 25 | wrapped_img = img / 255.0 26 | wrapped_img = torch.from_numpy(wrapped_img).float() 27 | 28 | return wrapped_img * 2 - 1 29 | 30 | def face_augmentation(self, img, crop_size): 31 | img = self._color_transfer(img) 32 | img = self._reshape(img, crop_size) 33 | img = self._blur_and_sharp(img) 34 | return img 35 | 36 | def _blur_and_sharp(self, img): 37 | blur = np.random.randint(0, 2) 38 | img2 = img.copy() 39 | output = [] 40 | for i in range(len(img2)): 41 | if blur: 42 | ksize = np.random.choice([3, 5, 7, 9]) 43 | output.append(cv2.medianBlur(img2[i], ksize)) 44 | else: 45 | kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]]) 46 | output.append(cv2.filter2D(img2[i], -1, kernel)) 47 | output = np.stack(output) 48 | return output 49 | 50 | def _color_transfer(self, img): 51 | 52 | transfer_c = np.random.uniform(0.3, 1.6) 53 | 54 | start_channel = np.random.randint(0, 2) 55 | end_channel = np.random.randint(start_channel + 1, 4) 56 | 57 | img2 = img.copy() 58 | 59 | img2[:, :, :, start_channel:end_channel] = np.minimum(np.maximum(img[:, :, :, start_channel:end_channel] * transfer_c, np.zeros(img[:, :, :, start_channel:end_channel].shape)), 60 | np.ones(img[:, :, :, start_channel:end_channel].shape) * 255) 61 | return img2 62 | 63 | def perspective_transform(self, img, crop_size=224, pers_size=10, enlarge_size=-10): 64 | h, w, c = img.shape 65 | dst = np.array([ 66 | [-enlarge_size, -enlarge_size], 67 | [-enlarge_size + pers_size, w + enlarge_size], 68 | [h + enlarge_size, -enlarge_size], 69 | [h + enlarge_size - pers_size, w + enlarge_size],], dtype=np.float32) 70 | src = np.array([[-enlarge_size, -enlarge_size], [-enlarge_size, w + enlarge_size], 71 
| [h + enlarge_size, -enlarge_size], [h + enlarge_size, w + enlarge_size]]).astype(np.float32()) 72 | M = cv2.getPerspectiveTransform(src, dst) 73 | warped = cv2.warpPerspective(img, M, (crop_size, crop_size), borderMode=cv2.BORDER_REPLICATE) 74 | return warped, M 75 | 76 | def _reshape(self, img, crop_size): 77 | reshape = np.random.randint(0, 2) 78 | reshape_size = np.random.randint(15, 25) 79 | extra_padding_size = np.random.randint(0, reshape_size // 2) 80 | pers_size = np.random.randint(20, 30) * pow(-1, np.random.randint(2)) 81 | 82 | enlarge_size = np.random.randint(20, 40) * pow(-1, np.random.randint(2)) 83 | shape = img[0].shape 84 | img2 = img.copy() 85 | output = [] 86 | for i in range(len(img2)): 87 | if reshape: 88 | im = cv2.resize(img2[i], (shape[0] - reshape_size*2, shape[1] + reshape_size*2)) 89 | im = cv2.copyMakeBorder(im, 0, 0, reshape_size + extra_padding_size, reshape_size + extra_padding_size, cv2.cv2.BORDER_REFLECT) 90 | im = im[reshape_size - extra_padding_size:shape[0] + reshape_size + extra_padding_size, :, :] 91 | im, _ = self.perspective_transform(im, crop_size=crop_size, pers_size=pers_size, enlarge_size=enlarge_size) 92 | output.append(im) 93 | else: 94 | im = cv2.resize(img2[i], (shape[0] + reshape_size*2, shape[1] - reshape_size*2)) 95 | im = cv2.copyMakeBorder(im, reshape_size + extra_padding_size, reshape_size + extra_padding_size, 0, 0, cv2.cv2.BORDER_REFLECT) 96 | im = im[:, reshape_size - extra_padding_size:shape[0] + reshape_size + extra_padding_size, :] 97 | im, _ = self.perspective_transform(im, crop_size=crop_size, pers_size=pers_size, enlarge_size=enlarge_size) 98 | output.append(im) 99 | output = np.stack(output) 100 | return output -------------------------------------------------------------------------------- /data/voxtest_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | from config import AudioConfig 5 | import shutil 
6 | import cv2 7 | import glob 8 | import random 9 | import torch 10 | from data.base_dataset import BaseDataset 11 | import util.util as util 12 | 13 | 14 | class VOXTestDataset(BaseDataset): 15 | 16 | @staticmethod 17 | def modify_commandline_options(parser, is_train): 18 | parser.add_argument('--no_pairing_check', action='store_true', 19 | help='If specified, skip sanity check of correct label-image file pairing') 20 | return parser 21 | 22 | def cv2_loader(self, img_str): 23 | img_array = np.frombuffer(img_str, dtype=np.uint8) 24 | return cv2.imdecode(img_array, cv2.IMREAD_COLOR) 25 | 26 | def load_img(self, image_path, M=None, crop=True, crop_len=16): 27 | img = cv2.imread(image_path) 28 | 29 | if img is None: 30 | raise Exception('None Image') 31 | 32 | if M is not None: 33 | img = cv2.warpAffine(img, M, (self.opt.crop_size, self.opt.crop_size), borderMode=cv2.BORDER_REPLICATE) 34 | 35 | if crop: 36 | img = img[:self.opt.crop_size - crop_len*2, crop_len:self.opt.crop_size - crop_len] 37 | if self.opt.target_crop_len > 0: 38 | img = img[self.opt.target_crop_len:self.opt.crop_size - self.opt.target_crop_len, self.opt.target_crop_len:self.opt.crop_size - self.opt.target_crop_len] 39 | img = cv2.resize(img, (self.opt.crop_size, self.opt.crop_size)) 40 | 41 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 42 | return img 43 | 44 | def fill_list(self, tmp_list): 45 | length = len(tmp_list) 46 | if length % self.opt.batchSize != 0: 47 | end = math.ceil(length / self.opt.batchSize) * self.opt.batchSize 48 | tmp_list = tmp_list + tmp_list[-1 * (end - length) :] 49 | return tmp_list 50 | 51 | def frame2audio_indexs(self, frame_inds): 52 | start_frame_ind = frame_inds - self.audio.num_frames_per_clip // 2 53 | 54 | start_audio_inds = start_frame_ind * self.audio.num_bins_per_frame 55 | return start_audio_inds 56 | 57 | def initialize(self, opt): 58 | self.opt = opt 59 | self.path_label = opt.path_label 60 | self.clip_len = opt.clip_len 61 | self.frame_interval = 
opt.frame_interval 62 | self.num_clips = opt.num_clips 63 | self.frame_rate = opt.frame_rate 64 | self.num_inputs = opt.num_inputs 65 | self.filename_tmpl = opt.filename_tmpl 66 | 67 | self.mouth_num_frames = None 68 | self.mouth_frame_path = None 69 | self.pose_num_frames = None 70 | 71 | self.audio = AudioConfig.AudioConfig(num_frames_per_clip=opt.num_frames_per_clip, hop_size=opt.hop_size) 72 | self.num_audio_bins = self.audio.num_frames_per_clip * self.audio.num_bins_per_frame 73 | 74 | 75 | assert len(opt.path_label.split()) == 8, opt.path_label 76 | id_path, ref_num, \ 77 | pose_frame_path, pose_num_frames, \ 78 | audio_path, mouth_frame_path, mouth_num_frames, spectrogram_path = opt.path_label.split() 79 | 80 | 81 | id_idx, mouth_idx = id_path.split('/')[-1], audio_path.split('/')[-1].split('.')[0] 82 | if not os.path.isdir(pose_frame_path): 83 | pose_frame_path = id_path 84 | pose_num_frames = 1 85 | 86 | pose_idx = pose_frame_path.split('/')[-1] 87 | id_idx, pose_idx, mouth_idx = str(id_idx), str(pose_idx), str(mouth_idx) 88 | 89 | self.processed_file_savepath = os.path.join('results', 'id_' + id_idx + '_pose_' + pose_idx + 90 | '_audio_' + os.path.basename(audio_path)[:-4]) 91 | if not os.path.exists(self.processed_file_savepath): os.makedirs(self.processed_file_savepath) 92 | 93 | 94 | if not os.path.isfile(spectrogram_path): 95 | wav = self.audio.read_audio(audio_path) 96 | self.spectrogram = self.audio.audio_to_spectrogram(wav) 97 | 98 | else: 99 | self.spectrogram = np.load(spectrogram_path) 100 | 101 | if os.path.isdir(mouth_frame_path): 102 | self.mouth_frame_path = mouth_frame_path 103 | self.mouth_num_frames = mouth_num_frames 104 | 105 | self.pose_num_frames = int(pose_num_frames) 106 | 107 | self.target_frame_inds = np.arange(2, len(self.spectrogram) // self.audio.num_bins_per_frame - 2) 108 | self.audio_inds = self.frame2audio_indexs(self.target_frame_inds) 109 | 110 | self.dataset_size = len(self.target_frame_inds) 111 | 112 | id_img_paths = 
glob.glob(os.path.join(id_path, '*.jpg')) + glob.glob(os.path.join(id_path, '*.png')) 113 | random.shuffle(id_img_paths) 114 | opt.num_inputs = min(len(id_img_paths), opt.num_inputs) 115 | id_img_tensors = [] 116 | 117 | for i, image_path in enumerate(id_img_paths): 118 | id_img_tensor = self.to_Tensor(self.load_img(image_path)) 119 | id_img_tensors += [id_img_tensor] 120 | shutil.copyfile(image_path, os.path.join(self.processed_file_savepath, 'ref_id_{}.jpg'.format(i))) 121 | if i == (opt.num_inputs - 1): 122 | break 123 | self.id_img_tensor = torch.stack(id_img_tensors) 124 | self.pose_frame_path = pose_frame_path 125 | self.audio_path = audio_path 126 | self.id_path = id_path 127 | self.mouth_frame_path = mouth_frame_path 128 | self.initialized = False 129 | 130 | 131 | def paths_match(self, path1, path2): 132 | filename1_without_ext = os.path.splitext(os.path.basename(path1)[-10:])[0] 133 | filename2_without_ext = os.path.splitext(os.path.basename(path2)[-10:])[0] 134 | return filename1_without_ext == filename2_without_ext 135 | 136 | def load_one_frame(self, frame_ind, video_path, M=None, crop=True): 137 | filepath = os.path.join(video_path, self.filename_tmpl.format(frame_ind)) 138 | img = self.load_img(filepath, M=M, crop=crop) 139 | img = self.to_Tensor(img) 140 | return img 141 | 142 | def load_spectrogram(self, audio_ind): 143 | mel_shape = self.spectrogram.shape 144 | 145 | if (audio_ind + self.num_audio_bins) <= mel_shape[0] and audio_ind >= 0: 146 | spectrogram = np.array(self.spectrogram[audio_ind:audio_ind + self.num_audio_bins, :]).astype('float32') 147 | else: 148 | print('(audio_ind {} + opt.num_audio_bins {}) > mel_shape[0] {} '.format(audio_ind, self.num_audio_bins, 149 | mel_shape[0])) 150 | if audio_ind > 0: 151 | spectrogram = np.array(self.spectrogram[audio_ind:audio_ind + self.num_audio_bins, :]).astype('float32') 152 | else: 153 | spectrogram = np.zeros((self.num_audio_bins, mel_shape[1])).astype(np.float16).astype(np.float32) 154 | 155 | 
spectrogram = torch.from_numpy(spectrogram) 156 | spectrogram = spectrogram.unsqueeze(0) 157 | 158 | spectrogram = spectrogram.transpose(-2, -1) 159 | return spectrogram 160 | 161 | def __getitem__(self, index): 162 | 163 | img_index = self.target_frame_inds[index] 164 | mel_index = self.audio_inds[index] 165 | 166 | pose_index = util.calc_loop_idx(img_index, self.pose_num_frames) 167 | 168 | pose_frame = self.load_one_frame(pose_index, self.pose_frame_path) 169 | 170 | if os.path.isdir(self.mouth_frame_path): 171 | mouth_frame = self.load_one_frame(img_index, self.mouth_frame_path) 172 | else: 173 | mouth_frame = torch.zeros_like(pose_frame) 174 | 175 | spectrograms = self.load_spectrogram(mel_index) 176 | 177 | input_dict = { 178 | 'input': self.id_img_tensor, 179 | 'target': mouth_frame, 180 | 'driving_pose_frames': pose_frame, 181 | 'augmented': pose_frame, 182 | 'label': torch.zeros(1), 183 | } 184 | if self.opt.use_audio: 185 | input_dict['spectrograms'] = spectrograms 186 | 187 | # Give subclasses a chance to modify the final output 188 | self.postprocess(input_dict) 189 | 190 | return input_dict 191 | 192 | def postprocess(self, input_dict): 193 | return input_dict 194 | 195 | def __len__(self): 196 | return self.dataset_size 197 | 198 | def get_processed_file_savepath(self): 199 | return self.processed_file_savepath 200 | -------------------------------------------------------------------------------- /experiments/demo_vox.sh: -------------------------------------------------------------------------------- 1 | meta_path_vox='./misc/demo.csv' 2 | 3 | python -u inference.py \ 4 | --name demo \ 5 | --meta_path_vox ${meta_path_vox} \ 6 | --dataset_mode voxtest \ 7 | --netG modulate \ 8 | --netA resseaudio \ 9 | --netA_sync ressesync \ 10 | --netD multiscale \ 11 | --netV resnext \ 12 | --netE fan \ 13 | --model av \ 14 | --gpu_ids 0 \ 15 | --clip_len 1 \ 16 | --batchSize 16 \ 17 | --style_dim 2560 \ 18 | --nThreads 4 \ 19 | --input_id_feature \ 20 | 
--generate_interval 1 \ 21 | --style_feature_loss \ 22 | --use_audio 1 \ 23 | --noise_pose \ 24 | --driving_pose \ 25 | --gen_video \ 26 | --generate_from_audio_only \ 27 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('..') 4 | from options.test_options import TestOptions 5 | import torch 6 | from models import create_model 7 | import data 8 | import util.util as util 9 | from tqdm import tqdm 10 | 11 | 12 | def video_concat(processed_file_savepath, name, video_names, audio_path): 13 | cmd = ['ffmpeg'] 14 | num_inputs = len(video_names) 15 | for video_name in video_names: 16 | cmd += ['-i', '\'' + str(os.path.join(processed_file_savepath, video_name + '.mp4'))+'\'',] 17 | 18 | cmd += ['-filter_complex hstack=inputs=' + str(num_inputs), 19 | '\'' + str(os.path.join(processed_file_savepath, name+'.mp4')) + '\'', '-loglevel error -y'] 20 | cmd = ' '.join(cmd) 21 | os.system(cmd) 22 | 23 | video_add_audio(name, audio_path, processed_file_savepath) 24 | 25 | 26 | def video_add_audio(name, audio_path, processed_file_savepath): 27 | os.system('cp {} {}'.format(audio_path, processed_file_savepath)) 28 | cmd = ['ffmpeg', '-i', '\'' + os.path.join(processed_file_savepath, name + '.mp4') + '\'', 29 | '-i', audio_path, 30 | '-q:v 0', 31 | '-strict -2', 32 | '\'' + os.path.join(processed_file_savepath, 'av' + name + '.mp4') + '\'', 33 | '-loglevel error -y'] 34 | cmd = ' '.join(cmd) 35 | os.system(cmd) 36 | 37 | 38 | def img2video(dst_path, prefix, video_path): 39 | cmd = ['ffmpeg', '-i', '\'' + video_path + '/' + prefix + '%d.jpg' 40 | + '\'', '-q:v 0', '\'' + dst_path + '/' + prefix + '.mp4' + '\'', '-loglevel error -y'] 41 | cmd = ' '.join(cmd) 42 | os.system(cmd) 43 | 44 | 45 | def inference_single_audio(opt, path_label, model): 46 | # 47 | opt.path_label = path_label 48 | dataloader = 
data.create_dataloader(opt) 49 | processed_file_savepath = dataloader.dataset.get_processed_file_savepath() 50 | 51 | idx = 0 52 | if opt.driving_pose: 53 | video_names = ['Input_', 'G_Pose_Driven_', 'Pose_Source_', 'Mouth_Source_'] 54 | else: 55 | video_names = ['Input_', 'G_Fix_Pose_', 'Mouth_Source_'] 56 | is_mouth_frame = os.path.isdir(dataloader.dataset.mouth_frame_path) 57 | if not is_mouth_frame: 58 | video_names.pop() 59 | save_paths = [] 60 | for name in video_names: 61 | save_path = os.path.join(processed_file_savepath, name) 62 | util.mkdir(save_path) 63 | save_paths.append(save_path) 64 | for data_i in tqdm(dataloader): 65 | # print('==============', i, '===============') 66 | fake_image_original_pose_a, fake_image_driven_pose_a = model.forward(data_i, mode='inference') 67 | 68 | for num in range(len(fake_image_driven_pose_a)): 69 | util.save_torch_img(data_i['input'][num], os.path.join(save_paths[0], video_names[0] + str(idx) + '.jpg')) 70 | if opt.driving_pose: 71 | util.save_torch_img(fake_image_driven_pose_a[num], 72 | os.path.join(save_paths[1], video_names[1] + str(idx) + '.jpg')) 73 | util.save_torch_img(data_i['driving_pose_frames'][num], 74 | os.path.join(save_paths[2], video_names[2] + str(idx) + '.jpg')) 75 | else: 76 | util.save_torch_img(fake_image_original_pose_a[num], 77 | os.path.join(save_paths[1], video_names[1] + str(idx) + '.jpg')) 78 | if is_mouth_frame: 79 | util.save_torch_img(data_i['target'][num], os.path.join(save_paths[-1], video_names[-1] + str(idx) + '.jpg')) 80 | idx += 1 81 | 82 | if opt.gen_video: 83 | for i, video_name in enumerate(video_names): 84 | img2video(processed_file_savepath, video_name, save_paths[i]) 85 | video_concat(processed_file_savepath, 'concat', video_names, dataloader.dataset.audio_path) 86 | 87 | print('results saved...' 
+ processed_file_savepath) 88 | del dataloader 89 | return 90 | 91 | 92 | def main(): 93 | 94 | opt = TestOptions().parse() 95 | opt.isTrain = False 96 | torch.manual_seed(0) 97 | model = create_model(opt).cuda() 98 | model.eval() 99 | 100 | with open(opt.meta_path_vox, 'r') as f: 101 | lines = f.read().splitlines() 102 | 103 | for clip_idx, path_label in enumerate(lines): 104 | try: 105 | assert len(path_label.split()) == 8, path_label 106 | 107 | inference_single_audio(opt, path_label, model) 108 | 109 | except Exception as ex: 110 | import traceback 111 | traceback.print_exc() 112 | print(path_label + '\n') 113 | print(str(ex)) 114 | 115 | 116 | if __name__ == '__main__': 117 | main() 118 | -------------------------------------------------------------------------------- /misc/Audio_Source.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/Audio_Source.zip -------------------------------------------------------------------------------- /misc/Input.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/Input.zip -------------------------------------------------------------------------------- /misc/Mouth_Source.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/Mouth_Source.zip -------------------------------------------------------------------------------- /misc/Pose_Source.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/Pose_Source.zip 
-------------------------------------------------------------------------------- /misc/demo.csv: -------------------------------------------------------------------------------- 1 | misc/Input/517600055 1 misc/Pose_Source/517600078 160 misc/Audio_Source/681600002.mp3 misc/Mouth_Source/681600002 363 dummy 2 | -------------------------------------------------------------------------------- /misc/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/demo.gif -------------------------------------------------------------------------------- /misc/demo_id.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/demo_id.gif -------------------------------------------------------------------------------- /misc/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/method.png -------------------------------------------------------------------------------- /misc/output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/misc/output.gif -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | def find_model_using_name(model_name): 4 | # Given the option --model [modelname], 5 | # the file "models/modelname_model.py" 6 | # will be imported. 7 | model_filename = "models." 
+ model_name + "_model" 8 | modellib = importlib.import_module(model_filename) 9 | 10 | # In the file, the class called ModelNameModel() will 11 | # be instantiated. It has to be a subclass of torch.nn.Module, 12 | # and it is case-insensitive. 13 | model = None 14 | target_model_name = model_name.replace('_', '') + 'model' 15 | for name, cls in modellib.__dict__.items(): 16 | if name.lower() == target_model_name.lower(): 17 | model = cls 18 | 19 | if model is None: 20 | print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." % (model_filename, target_model_name)) 21 | exit(0) 22 | 23 | return model 24 | 25 | 26 | def get_option_setter(model_name): 27 | model_class = find_model_using_name(model_name) 28 | return model_class.modify_commandline_options 29 | 30 | 31 | def create_model(opt): 32 | model = find_model_using_name(opt.model) 33 | instance = model(opt) 34 | print("model [%s] was created" % (type(instance).__name__)) 35 | 36 | return instance 37 | -------------------------------------------------------------------------------- /models/networks/FAN_feature_extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from util import util 4 | import torch.nn.functional as F 5 | 6 | 7 | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, 10 | stride=strd, padding=padding, bias=bias) 11 | 12 | 13 | class ConvBlock(nn.Module): 14 | def __init__(self, in_planes, out_planes): 15 | super(ConvBlock, self).__init__() 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv1 = conv3x3(in_planes, int(out_planes / 2)) 18 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) 19 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) 20 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) 21 | self.conv3 = conv3x3(int(out_planes / 4), 
int(out_planes / 4)) 22 | 23 | if in_planes != out_planes: 24 | self.downsample = nn.Sequential( 25 | nn.BatchNorm2d(in_planes), 26 | nn.ReLU(True), 27 | nn.Conv2d(in_planes, out_planes, 28 | kernel_size=1, stride=1, bias=False), 29 | ) 30 | else: 31 | self.downsample = None 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out1 = self.bn1(x) 37 | out1 = F.relu(out1, True) 38 | out1 = self.conv1(out1) 39 | 40 | out2 = self.bn2(out1) 41 | out2 = F.relu(out2, True) 42 | out2 = self.conv2(out2) 43 | 44 | out3 = self.bn3(out2) 45 | out3 = F.relu(out3, True) 46 | out3 = self.conv3(out3) 47 | 48 | out3 = torch.cat((out1, out2, out3), 1) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(residual) 52 | 53 | out3 += residual 54 | 55 | return out3 56 | 57 | 58 | class HourGlass(nn.Module): 59 | def __init__(self, num_modules, depth, num_features): 60 | super(HourGlass, self).__init__() 61 | self.num_modules = num_modules 62 | self.depth = depth 63 | self.features = num_features 64 | self.dropout = nn.Dropout(0.5) 65 | 66 | self._generate_network(self.depth) 67 | 68 | def _generate_network(self, level): 69 | self.add_module('b1_' + str(level), ConvBlock(256, 256)) 70 | 71 | self.add_module('b2_' + str(level), ConvBlock(256, 256)) 72 | 73 | if level > 1: 74 | self._generate_network(level - 1) 75 | else: 76 | self.add_module('b2_plus_' + str(level), ConvBlock(256, 256)) 77 | 78 | self.add_module('b3_' + str(level), ConvBlock(256, 256)) 79 | 80 | def _forward(self, level, inp): 81 | # Upper branch 82 | up1 = inp 83 | up1 = self._modules['b1_' + str(level)](up1) 84 | up1 = self.dropout(up1) 85 | # Lower branch 86 | low1 = F.max_pool2d(inp, 2, stride=2) 87 | low1 = self._modules['b2_' + str(level)](low1) 88 | 89 | if level > 1: 90 | low2 = self._forward(level - 1, low1) 91 | else: 92 | low2 = low1 93 | low2 = self._modules['b2_plus_' + str(level)](low2) 94 | 95 | low3 = low2 96 | low3 = self._modules['b3_' + str(level)](low3) 97 | up1size = 
up1.size() 98 | rescale_size = (up1size[2], up1size[3]) 99 | up2 = F.upsample(low3, size=rescale_size, mode='bilinear') 100 | 101 | return up1 + up2 102 | 103 | def forward(self, x): 104 | return self._forward(self.depth, x) 105 | 106 | 107 | class FAN_use(nn.Module): 108 | def __init__(self): 109 | super(FAN_use, self).__init__() 110 | self.num_modules = 1 111 | 112 | # Base part 113 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 114 | self.bn1 = nn.BatchNorm2d(64) 115 | self.conv2 = ConvBlock(64, 128) 116 | self.conv3 = ConvBlock(128, 128) 117 | self.conv4 = ConvBlock(128, 256) 118 | 119 | # Stacking part 120 | hg_module = 0 121 | self.add_module('m' + str(hg_module), HourGlass(1, 4, 256)) 122 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) 123 | self.add_module('conv_last' + str(hg_module), 124 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 125 | self.add_module('l' + str(hg_module), nn.Conv2d(256, 126 | 68, kernel_size=1, stride=1, padding=0)) 127 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) 128 | 129 | if hg_module < self.num_modules - 1: 130 | self.add_module( 131 | 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 132 | self.add_module('al' + str(hg_module), nn.Conv2d(68, 133 | 256, kernel_size=1, stride=1, padding=0)) 134 | 135 | self.avgpool = nn.MaxPool2d((2, 2), 2) 136 | self.conv6 = nn.Conv2d(68, 1, 3, 2, 1) 137 | self.fc = nn.Linear(28 * 28, 512) 138 | self.bn5 = nn.BatchNorm2d(68) 139 | self.relu = nn.ReLU(True) 140 | 141 | def forward(self, x): 142 | x = F.relu(self.bn1(self.conv1(x)), True) 143 | x = F.max_pool2d(self.conv2(x), 2) 144 | x = self.conv3(x) 145 | x = self.conv4(x) 146 | 147 | previous = x 148 | 149 | i = 0 150 | hg = self._modules['m' + str(i)](previous) 151 | 152 | ll = hg 153 | ll = self._modules['top_m_' + str(i)](ll) 154 | 155 | ll = self._modules['bn_end' + str(i)](self._modules['conv_last' + str(i)](ll)) 156 | tmp_out = 
self._modules['l' + str(i)](F.relu(ll)) 157 | 158 | net = self.relu(self.bn5(tmp_out)) 159 | net = self.conv6(net) 160 | net = net.view(-1, net.shape[-2] * net.shape[-1]) 161 | net = self.relu(net) 162 | net = self.fc(net) 163 | return net 164 | -------------------------------------------------------------------------------- /models/networks/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.networks.base_network import BaseNetwork 3 | from models.networks.loss import * 4 | from models.networks.discriminator import MultiscaleDiscriminator, ImageDiscriminator 5 | from models.networks.generator import ModulateGenerator 6 | from models.networks.encoder import ResSEAudioEncoder, ResNeXtEncoder, ResSESyncEncoder, FanEncoder 7 | import util.util as util 8 | 9 | 10 | def find_network_using_name(target_network_name, filename): 11 | target_class_name = target_network_name + filename 12 | module_name = 'models.networks.' + filename 13 | network = util.find_class_in_module(target_class_name, module_name) 14 | 15 | assert issubclass(network, BaseNetwork), \ 16 | "Class %s should be a subclass of BaseNetwork" % network 17 | 18 | return network 19 | 20 | 21 | def modify_commandline_options(parser, is_train): 22 | opt, _ = parser.parse_known_args() 23 | 24 | netG_cls = find_network_using_name(opt.netG, 'generator') 25 | parser = netG_cls.modify_commandline_options(parser, is_train) 26 | if is_train: 27 | netD_cls = find_network_using_name(opt.netD, 'discriminator') 28 | parser = netD_cls.modify_commandline_options(parser, is_train) 29 | netA_cls = find_network_using_name(opt.netA, 'encoder') 30 | parser = netA_cls.modify_commandline_options(parser, is_train) 31 | # parser = netA_sync_cls.modify_commandline_options(parser, is_train) 32 | 33 | return parser 34 | 35 | 36 | def create_network(cls, opt): 37 | net = cls(opt) 38 | net.print_network() 39 | if len(opt.gpu_ids) > 0: 40 | 
assert(torch.cuda.is_available()) 41 | net.cuda() 42 | net.init_weights(opt.init_type, opt.init_variance) 43 | return net 44 | 45 | 46 | def define_networks(opt, name, type): 47 | netG_cls = find_network_using_name(name, type) 48 | return create_network(netG_cls, opt) 49 | 50 | def define_G(opt): 51 | netG_cls = find_network_using_name(opt.netG, 'generator') 52 | return create_network(netG_cls, opt) 53 | 54 | 55 | def define_D(opt): 56 | netD_cls = find_network_using_name(opt.netD, 'discriminator') 57 | return create_network(netD_cls, opt) 58 | 59 | def define_A(opt): 60 | netA_cls = find_network_using_name(opt.netA, 'encoder') 61 | return create_network(netA_cls, opt) 62 | 63 | def define_A_sync(opt): 64 | netA_cls = find_network_using_name(opt.netA_sync, 'encoder') 65 | return create_network(netA_cls, opt) 66 | 67 | 68 | def define_E(opt): 69 | # there exists only one encoder type 70 | netE_cls = find_network_using_name(opt.netE, 'encoder') 71 | return create_network(netE_cls, opt) 72 | 73 | 74 | def define_V(opt): 75 | # there exists only one encoder type 76 | netV_cls = find_network_using_name(opt.netV, 'encoder') 77 | return create_network(netV_cls, opt) 78 | 79 | 80 | def define_P(opt): 81 | netP_cls = find_network_using_name(opt.netP, 'encoder') 82 | return create_network(netP_cls, opt) 83 | 84 | 85 | def define_F_rec(opt): 86 | netF_rec_cls = find_network_using_name(opt.netF_rec, 'encoder') 87 | return create_network(netF_rec_cls, opt) 88 | -------------------------------------------------------------------------------- /models/networks/__pycache__/FAN_feature_extractor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/FAN_feature_extractor.cpython-36.pyc -------------------------------------------------------------------------------- 
/models/networks/__pycache__/FAN_feature_extractor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/FAN_feature_extractor.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/Voxceleb_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/Voxceleb_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/Voxceleb_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/Voxceleb_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/architecture.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/architecture.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/architecture.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/architecture.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/audio_architecture.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/audio_architecture.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/audio_architecture.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/audio_architecture.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/audio_network.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/audio_network.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/base_network.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/base_network.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/base_network.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/base_network.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/discriminator.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/discriminator.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/discriminator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/discriminator.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/encoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/encoder.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/encoder.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/encoder.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/generator.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/generator.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/generator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/generator.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/normalization.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/normalization.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/normalization.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/normalization.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/stylegan2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/stylegan2.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/stylegan2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/stylegan2.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/vision_network.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/vision_network.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/vision_network.cpython-37.pyc: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/__pycache__/vision_network.cpython-37.pyc
# --------------------------------------------------------------------------------
# /models/networks/architecture.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from models.networks.encoder import VGGEncoder
from util import util
from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
import torch.nn.utils.spectral_norm as spectral_norm


# VGG architecture, used for the perceptual loss using a pretrained VGG network
class VGG19(torch.nn.Module):
    """ImageNet-pretrained torchvision VGG19 ``features`` cut into five slices.

    Parameters are frozen unless ``requires_grad`` is True. ``forward``
    returns the output of every slice, shallowest first.
    """

    def __init__(self, requires_grad=False):
        super(VGG19, self).__init__()
        features = torchvision.models.vgg19(pretrained=True).features
        # Half-open [start, stop) index ranges into `features` for slice1..slice5.
        bounds = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))
        for slice_no, (start, stop) in enumerate(bounds, start=1):
            stage = torch.nn.Sequential()
            for idx in range(start, stop):
                # Children keep the original layer index as their name, so
                # state_dict keys look like 'slice<k>.<idx>.weight'.
                stage.add_module(str(idx), features[idx])
            setattr(self, 'slice%d' % slice_no, stage)
        if not requires_grad:
            for weight in self.parameters():
                weight.requires_grad = False

    def forward(self, X):
        outputs = []
        hidden = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            hidden = stage(hidden)
            outputs.append(hidden)
        return outputs
class VGGFace19(torch.nn.Module):
    """VGG-Face feature extractor (a ``VGGEncoder``, i.e. vgg19_bn) cut into
    six slices, used for perceptual losses.

    Weights come from ``opt.VGGFace_pretrain_path``. Parameters are frozen
    unless ``requires_grad`` is True.
    """

    def __init__(self, opt, requires_grad=False):
        super(VGGFace19, self).__init__()
        self.model = VGGEncoder(opt)
        self.opt = opt
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
        # host; self.model has not been moved to any device at this point.
        ckpt = torch.load(opt.VGGFace_pretrain_path, map_location='cpu')
        print("=> loading checkpoint '{}'".format(opt.VGGFace_pretrain_path))
        util.copy_state_dict(ckpt, self.model)
        vgg_pretrained_features = self.model.model.features
        len_features = len(self.model.model.features)
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.slice6 = torch.nn.Sequential()

        # NOTE(review): these boundaries match VGG19 without batch norm, but
        # the backbone is vgg19_bn, so the cuts do not fall after the same
        # layers — confirm this is intended before changing (pretrained
        # loss weights depend on it).
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        for x in range(30, len_features):
            self.slice6.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return five feature maps: slices 3, 4, 5 and slice 6 twice.

        ``X`` is reshaped to (-1, opt.output_nc, opt.crop_size, opt.crop_size).
        """
        X = X.view(-1, self.opt.output_nc, self.opt.crop_size, self.opt.crop_size)
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        h_relu6 = self.slice6(h_relu5)
        # NOTE(review): h_relu6 is repeated, presumably so the list length
        # matches a 5-term perceptual loss — confirm.
        out = [h_relu3, h_relu4, h_relu5, h_relu6, h_relu6]
        return out


# Returns a function that creates a normalization function
# that does not condition on semantic map
def get_nonspade_norm_layer(opt, norm_type='instance'):
    """Return ``add_norm_layer(layer)``, which wraps ``layer`` according to
    ``norm_type``.

    ``norm_type`` is an optional 'spectral' prefix (spectral norm applied to
    the layer itself) followed by 'batch', 'syncbatch', 'instance', 'none' or
    nothing. ``opt`` is not used.

    The returned function raises ValueError for an unknown normalization name.
    """

    def get_out_channel(layer):
        # Output channels of the wrapped layer.
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        return layer.weight.size(0)

    def add_norm_layer(layer):
        # norm_type is only read here, so no `nonlocal` declaration is needed.
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]
        else:
            subnorm_type = norm_type

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'syncbatch':
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError('normalization layer %s is not recognized' % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer
# --------------------------------------------------------------------------------
# /models/networks/audio_network.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn


class ResNetSE(nn.Module):
    """Squeeze-excitation ResNet over a 1-channel (frequency x time) map,
    pooled over time with learned attention and projected to ``nOut``.

    Args:
        block: residual block class exposing ``expansion`` (e.g. SEBasicBlock).
        layers: number of blocks in each of the four stages.
        num_filters: channel width of each of the four stages.
        nOut: output feature size.
        encoder_type: 'SAP' (attention-weighted sum over time) or 'ASP'
            (weighted mean concatenated with weighted standard deviation).
        n_mels, n_mel_T: stages 2-4 halve the frequency axis three times,
            so the input frequency size must be n_mels * n_mel_T for the
            attention layers (built for n_mels * n_mel_T / 8 bins) to match.
        log_input: stored but not used by forward().

    Raises:
        ValueError: if ``encoder_type`` is neither 'SAP' nor 'ASP'.
    """

    def __init__(self, block, layers, num_filters, nOut, encoder_type='SAP', n_mels=80, n_mel_T=1, log_input=True, **kwargs):
        super(ResNetSE, self).__init__()

        print('Embedding size is %d, encoder %s.' % (nOut, encoder_type))

        self.inplanes = num_filters[0]
        self.encoder_type = encoder_type
        self.n_mels = n_mels
        self.log_input = log_input

        self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.bn1 = nn.BatchNorm2d(num_filters[0])

        self.layer1 = self._make_layer(block, num_filters[0], layers[0])
        self.layer2 = self._make_layer(block, num_filters[1], layers[1], stride=(2, 2))
        self.layer3 = self._make_layer(block, num_filters[2], layers[2], stride=(2, 2))
        self.layer4 = self._make_layer(block, num_filters[3], layers[3], stride=(2, 2))

        # NOTE: not used by forward().
        self.instancenorm = nn.InstanceNorm1d(n_mels)

        outmap_size = int(self.n_mels * n_mel_T / 8)

        # Per-(channel, frequency) attention weights, softmax-normalized over time.
        self.attention = nn.Sequential(
            nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
            nn.Softmax(dim=2),
        )

        if self.encoder_type == "SAP":
            out_dim = num_filters[3] * outmap_size
        elif self.encoder_type == "ASP":
            out_dim = num_filters[3] * outmap_size * 2
        else:
            raise ValueError('Undefined encoder')

        self.fc = nn.Linear(out_dim, nOut)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks; only the first block
        strides and gets a 1x1 projection shortcut when shapes change."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def new_parameter(self, *size):
        """Return a new Xavier-normal-initialized Parameter of shape ``size``."""
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out

    def forward(self, x):
        """Map a (batch, 1, freq, time) input to (batch, nOut)."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.bn1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Merge channel and frequency axes: (batch, C * F, time).
        x = x.reshape(x.size()[0], -1, x.size()[-1])

        w = self.attention(x)

        if self.encoder_type == "SAP":
            x = torch.sum(x * w, dim=2)
        elif self.encoder_type == "ASP":
            mu = torch.sum(x * w, dim=2)
            # Clamp keeps sqrt away from zero/negative round-off.
            sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
            x = torch.cat((mu, sg), 1)

        x = x.view(x.size()[0], -1)
        x = self.fc(x)

        return x


class SEBasicBlock(nn.Module):
    """Two-conv residual block with squeeze-excitation before the residual add.

    Note the order conv1 -> ReLU -> bn1 in the first conv.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
        super(SEBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class SEBottleneck(nn.Module):
    """1x1 -> 3x3 (strided) -> 1x1 bottleneck block widening to
    ``planes * expansion`` channels, with squeeze-excitation."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
        super(SEBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(planes * self.expansion, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SELayer(nn.Module):
    """Squeeze-excitation: scale each channel by a sigmoid gate computed from
    its global average through a ``channel // reduction`` bottleneck MLP."""

    def __init__(self, channel, reduction=8):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y
# --------------------------------------------------------------------------------
# /models/networks/base_network.py:
# --------------------------------------------------------------------------------
import torch.nn as nn
from torch.nn import init


class BaseNetwork(nn.Module):
    """Common base for project networks: option hook, size report, weight init."""

    def __init__(self):
        super(BaseNetwork, self).__init__()

    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def print_network(self):
        """Print the class name and total parameter count in millions."""
        if isinstance(self, list):
            self = self[0]
        num_params = sum(param.numel() for param in self.parameters())
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              'To see the architecture, do print(network).'
              % (type(self).__name__, num_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        """Initialize Conv*/Linear weights and BatchNorm2d affine parameters.

        Args:
            init_type: 'normal', 'xavier', 'xavier_uniform', 'kaiming',
                'orthogonal', or 'none' (module's own reset_parameters()).
            gain: std for 'normal' and for BatchNorm2d weights (around 1.0);
                gain for 'xavier' and 'orthogonal'. 'xavier_uniform' always
                uses gain 1.0 and 'kaiming' ignores it.

        Raises:
            NotImplementedError: for an unknown init_type (raised when a
                Conv/Linear module is reached).
        """
        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('BatchNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children
        # NOTE: self.apply() above already visits every descendant; children
        # with their own init_weights are re-initialized here a second time.
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)
# --------------------------------------------------------------------------------
# /models/networks/discriminator.py:
# --------------------------------------------------------------------------------
import torch.nn as nn
import numpy as np
from models.networks.base_network import BaseNetwork
import util.util as util
import torch
from models.networks.architecture import get_nonspade_norm_layer
import torch.nn.functional as F


class MultiscaleDiscriminator(BaseNetwork):
    """``opt.num_D`` sub-discriminators, each applied to a further
    2x average-pool-downsampled copy of the input."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument('--netD_subarch', type=str, default='n_layer',
                            help='architecture of each discriminator')
        parser.add_argument('--num_D', type=int, default=2,
                            help='number of discriminators to be used in multiscale')
        opt, _ = parser.parse_known_args()

        # define properties of each discriminator of the multiscale discriminator
        subnetD = util.find_class_in_module(opt.netD_subarch + 'discriminator',
                                            'models.networks.discriminator')
        subnetD.modify_commandline_options(parser, is_train)

        return parser

    def __init__(self, opt):
        super(MultiscaleDiscriminator, self).__init__()
        self.opt = opt

        for i in range(opt.num_D):
            subnetD = self.create_single_discriminator(opt)
            self.add_module('discriminator_%d' % i, subnetD)

    def create_single_discriminator(self, opt):
        """Build one sub-discriminator; raises ValueError for an unknown subarch."""
        subarch = opt.netD_subarch
        if subarch == 'n_layer':
            netD = NLayerDiscriminator(opt)
        else:
            raise ValueError('unrecognized discriminator subarchitecture %s' % subarch)
        return netD

    def downsample(self, input):
        return F.avg_pool2d(input, kernel_size=3,
                            stride=2, padding=[1, 1],
                            count_include_pad=False)

    # Returns a list with one entry per sub-discriminator (opt.num_D entries).
    # Each entry is that discriminator's output list (see
    # NLayerDiscriminator.forward), or [final_output] when GAN feature loss
    # is disabled.
    def forward(self, input):
        result = []
        get_intermediate_features = not self.opt.no_ganFeat_loss
        for D in self.children():
            out = D(input)
            if not get_intermediate_features:
                out = [out]
            result.append(out)
            input = self.downsample(input)

        return result


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument('--n_layers_D', type=int, default=4,
                            help='# layers in each discriminator')
        return parser

    def __init__(self, opt):

        super(NLayerDiscriminator, self).__init__()
        self.opt = opt

        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        nf = opt.ndf
        input_nc = self.compute_D_input_nc(opt)

        norm_layer = get_nonspade_norm_layer(opt, opt.norm_D)
        sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
                     nn.LeakyReLU(0.2, False)]]

        for n in range(1, opt.n_layers_D):
            nf_prev = nf
            nf = min(nf * 2, 512)
            stride = 1 if n == opt.n_layers_D - 1 else 2
            sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw,
                                               stride=stride, padding=padw)),
                          nn.LeakyReLU(0.2, False)
                          ]]

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        # We divide the layers into groups to extract intermediate layer outputs
        for n in range(len(sequence)):
            self.add_module('model' + str(n), nn.Sequential(*sequence[n]))

    def compute_D_input_nc(self, opt):
        """Input channels: label + image channels (+1 for dontcare, +1 for
        instance maps) when ``opt.D_input == "concat"``, else 3."""
        if opt.D_input == "concat":
            input_nc = opt.label_nc + opt.output_nc
            if opt.contain_dontcare_label:
                input_nc += 1
            if not opt.no_instance:
                input_nc += 1
        else:
            input_nc = 3
        return input_nc

    def forward(self, input):
        """Return [input, out_0, ..., out_last] when GAN feature loss is on,
        otherwise only the final prediction map."""
        results = [input]
        for submodel in self.children():
            intermediate_output = submodel(results[-1])
            results.append(intermediate_output)

        get_intermediate_features = not self.opt.no_ganFeat_loss
        if get_intermediate_features:
            # NOTE(review): this list includes the raw input as its first
            # element; confirm feature-matching losses expect that.
            return results[0:]
        else:
            return results[-1]


class AudioSubDiscriminator(BaseNetwork):
    def __init__(self, opt, nc, audio_nc):
        super(AudioSubDiscriminator, self).__init__()
        norm_layer = get_nonspade_norm_layer(opt, opt.norm_D)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        sequence = []
        sequence += [norm_layer(nn.Conv1d(nc, nc, 3, 2, 1)),
                     nn.ReLU()
                     ]
        sequence += [norm_layer(nn.Conv1d(nc, audio_nc, 3, 2, 1)),
                     nn.ReLU()
                     ]

        # NOTE: self.conv is not used by forward().
        self.conv = nn.Sequential(*sequence)
        self.cosine = nn.CosineSimilarity()
        self.mapping = nn.Linear(audio_nc, audio_nc)

    def forward(self, result, audio):
        # Crop rows [H // 2, H - 2) and columns [W // 3, 2 * W // 3) of the
        # spatial axes 3 and 4. The previous code applied these bounds to
        # axes 0 and 1 instead.
        # NOTE(review): assumes `result` is 5-D with H = shape[3] and
        # W = shape[4]; AdaptiveAvgPool2d on such a tensor still looks
        # unfinished — confirm before using this class.
        height, width = result.shape[3], result.shape[4]
        region = result[:, :, :, height // 2:height - 2, width // 3:2 * width // 3]
        visual = self.avgpool(region)
        cos = self.cosine(visual, self.mapping(audio))
        return cos


class ImageDiscriminator(BaseNetwork):
    """Defines a PatchGAN discriminator"""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument('--n_layers_D', type=int, default=4,
                            help='# layers in each discriminator')
        return parser

    def __init__(self, opt, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator
        Parameters:
            opt                 -- options; opt.D_input and opt.label_nc
                                   (+ opt.output_nc for "concat") set the input channels
            n_layers (int)      -- the number of strided conv layers in the discriminator
            norm_layer          -- normalization layer
        """
        super(ImageDiscriminator, self).__init__()
        use_bias = norm_layer == nn.InstanceNorm2d
        if opt.D_input == "concat":
            input_nc = opt.label_nc + opt.output_nc
        else:
            input_nc = opt.label_nc
        ndf = 64
        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)


class FeatureDiscriminator(BaseNetwork):
    """Dropout + linear classifier over 512-d features (opt.num_labels outputs)."""

    def __init__(self, opt):
        super(FeatureDiscriminator, self).__init__()
        self.opt = opt
        self.fc = nn.Linear(512, opt.num_labels)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x0 = x.view(-1, 512)
        net = self.dropout(x0)
        net = self.fc(net)
        return net
# --------------------------------------------------------------------------------
# /models/networks/encoder.py:
# --------------------------------------------------------------------------------
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
import torchvision.models.mobilenet
from util import util
from models.networks.audio_network import ResNetSE, SEBasicBlock
import torch
from models.networks.FAN_feature_extractor import FAN_use
from torchvision.models.vgg import vgg19_bn
from models.networks.vision_network import ResNeXt50


class ResSEAudioEncoder(BaseNetwork):
    """ResNetSE audio encoder returning (features, class scores).

    Args:
        opt: options; uses opt.num_classes and (by default) opt.n_mel_T.
        nOut: feature size.
        n_mel_T: overrides opt.n_mel_T when given.
    """

    def __init__(self, opt, nOut=2048, n_mel_T=None):
        super(ResSEAudioEncoder, self).__init__()
        self.nOut = nOut
        # Number of filters
        num_filters = [32, 64, 128, 256]
        if n_mel_T is None:  # use it when use audio identity
            n_mel_T = opt.n_mel_T
        self.model = ResNetSE(SEBasicBlock, [3, 4, 6, 3], num_filters, self.nOut, n_mel_T=n_mel_T)
        self.fc = nn.Linear(self.nOut, opt.num_classes)

    def forward_feature(self, x):
        # A 5-D input (bz, clip_len, c, f, t) is folded into the batch axis.
        input_size = x.size()
        if len(input_size) == 5:
            bz, clip_len, c, f, t = input_size
            x = x.view(bz * clip_len, c, f, t)
        out = self.model(x)
        return out

    def forward(self, x):
        out = self.forward_feature(x)
        score = self.fc(out)
        return out, score


class ResSESyncEncoder(ResSEAudioEncoder):
    """ResSEAudioEncoder with 512-d output over single mel frames (n_mel_T=1)."""

    def __init__(self, opt):
        super(ResSESyncEncoder, self).__init__(opt, nOut=512, n_mel_T=1)


class ResNeXtEncoder(ResNeXt50):
    def __init__(self, opt):
        super(ResNeXtEncoder, self).__init__(opt)


class VGGEncoder(BaseNetwork):
    """torchvision vgg19_bn classifier with opt.num_classes outputs."""

    def __init__(self, opt):
        super(VGGEncoder, self).__init__()
        self.model = vgg19_bn(num_classes=opt.num_classes)

    def forward(self, x):
        return self.model(x)


class FanEncoder(BaseNetwork):
    """FAN_use image encoder plus a clip-level classifier and mouth /
    head-pose projection heads (the heads are not used by forward())."""

    def __init__(self, opt):
        super(FanEncoder, self).__init__()
        self.opt = opt
        pose_dim = self.opt.pose_dim
        self.model = FAN_use()
        self.classifier = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, opt.num_classes))

        # mapper to mouth subspace
        self.to_mouth = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
        self.mouth_embed = nn.Sequential(nn.ReLU(), nn.Linear(512, 512-pose_dim))
        self.mouth_fc = nn.Sequential(nn.ReLU(), nn.Linear(512*opt.clip_len, opt.num_classes))

        # mapper to head pose subspace
        self.to_headpose = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
        self.headpose_embed = nn.Sequential(nn.ReLU(), nn.Linear(512, pose_dim))
        self.headpose_fc = nn.Sequential(nn.ReLU(), nn.Linear(pose_dim*opt.clip_len, opt.num_classes))

    def load_pretrain(self):
        """Copy the checkpoint at opt.FAN_pretrain_path into self.model."""
        # map_location='cpu' allows loading a GPU-saved checkpoint without CUDA.
        # NOTE(review): assumes util.copy_state_dict copies into the model's
        # existing tensors (keeping their device) — confirm.
        check_point = torch.load(self.opt.FAN_pretrain_path, map_location='cpu')
        print("=> loading checkpoint '{}'".format(self.opt.FAN_pretrain_path))
        util.copy_state_dict(check_point, self.model)

    def forward_feature(self, x):
        net = self.model(x)
        return net

    def forward(self, x):
        """Return per-frame 512-d features and class scores of the features
        averaged over groups of opt.num_clips frames."""
        x0 = x.view(-1, self.opt.output_nc, self.opt.crop_size, self.opt.crop_size)
        net = self.forward_feature(x0)
        scores = self.classifier(net.view(-1, self.opt.num_clips, 512).mean(1))
        return net, scores
# --------------------------------------------------------------------------------
# /models/networks/loss.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.architecture import VGG19, VGGFace19


# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically the same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input.
class GANLoss(nn.Module):
    """Adversarial loss supporting the 'ls', 'original', 'w' and 'hinge' modes.

    Args:
        gan_mode: 'ls' (MSE against a constant target), 'original'
            (binary cross-entropy with logits), 'w' (Wasserstein: signed
            mean of the prediction) or 'hinge'.
        target_real_label: value of the target tensor for real samples.
        target_fake_label: value of the target tensor for fake samples.
        tensor: tensor constructor used to build the cached target tensors.
        opt: option namespace; stored but not read by this class.

    Raises:
        ValueError: if ``gan_mode`` is not one of the supported modes.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor, opt=None):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # 1-element target tensors, created lazily on first use and cached.
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.zero_tensor = None
        self.Tensor = tensor
        self.gan_mode = gan_mode
        self.opt = opt
        if gan_mode not in ('ls', 'original', 'w', 'hinge'):
            raise ValueError('Unexpected gan_mode {}'.format(gan_mode))

    def get_target_tensor(self, input, target_is_real):
        """Return a view shaped like ``input`` filled with the real/fake label.

        NOTE(review): the cached tensor is built with ``self.Tensor``, so its
        device follows that constructor (CPU for the default
        ``torch.FloatTensor``); GPU callers presumably pass a CUDA tensor
        type -- confirm at the construction site.
        """
        if target_is_real:
            if self.real_label_tensor is None:
                self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
                self.real_label_tensor.requires_grad_(False)
            # expand_as broadcasts the 1-element tensor without copying.
            return self.real_label_tensor.expand_as(input)
        else:
            if self.fake_label_tensor is None:
                self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
                self.fake_label_tensor.requires_grad_(False)
            return self.fake_label_tensor.expand_as(input)

    def get_zero_tensor(self, input):
        """Return a cached zero tensor broadcast to the shape of ``input``."""
        if self.zero_tensor is None:
            self.zero_tensor = self.Tensor(1).fill_(0)
            self.zero_tensor.requires_grad_(False)
        return self.zero_tensor.expand_as(input)

    def loss(self, input, target_is_real, for_discriminator=True):
        """Compute the scalar loss of one prediction tensor for ``self.gan_mode``."""
        if self.gan_mode == 'original':  # cross entropy loss
            target_tensor = self.get_target_tensor(input, target_is_real)
            loss = F.binary_cross_entropy_with_logits(input, target_tensor)
            return loss
        elif self.gan_mode == 'ls':
            target_tensor = self.get_target_tensor(input, target_is_real)
            return F.mse_loss(input, target_tensor)
        elif self.gan_mode == 'hinge':
            if for_discriminator:
                if target_is_real:
                    minval = torch.min(input - 1, self.get_zero_tensor(input))
                    loss = -torch.mean(minval)
                else:
                    minval = torch.min(-input - 1, self.get_zero_tensor(input))
                    loss = -torch.mean(minval)
            else:
                assert target_is_real, "The generator's hinge loss must be aiming for real"
                loss = -torch.mean(input)
            return loss
        else:
            # wgan
            if target_is_real:
                return -input.mean()
            else:
                return input.mean()

    def __call__(self, input, target_is_real, for_discriminator=True):
        """Score a prediction tensor or a list of them (multiscale discriminator).

        For a list, each element is scored. When the element is itself a list
        (intermediate features), only its last entry is scored. Each result is
        reshaped to ``(bs, -1)`` and averaged over dim 1, then the results are
        averaged over the list.
        """
        if isinstance(input, list):
            loss = 0
            for pred_i in input:
                if isinstance(pred_i, list):
                    pred_i = pred_i[-1]
                loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
                bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
                new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
                loss += new_loss
            return loss / len(input)
        else:
            return self.loss(input, target_is_real, for_discriminator)


# Perceptual loss that uses a pretrained VGG network
class VGGLoss(nn.Module):
    """Weighted L1 distance between feature maps of a VGG network.

    Args:
        opt: option namespace (not read by this class).
        vgg: feature extractor returning a list of feature maps. A new
            ``VGG19()`` is built for this instance when omitted.
    """

    def __init__(self, opt, vgg=None):
        super(VGGLoss, self).__init__()
        # Build the default network here rather than in the signature. A call
        # in a default argument runs once at import time, and every VGGLoss
        # created without ``vgg`` would then share that single instance.
        if vgg is None:
            vgg = VGG19()
        self.vgg = vgg.cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y, layer=0):
        """Sum the weighted L1 losses of feature maps whose index is ``>= layer``.

        The target features ``y`` are detached, so gradients flow only through ``x``.
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            if i >= layer:
                loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss


# KL Divergence loss used in VAE with an image encoder
class KLDLoss(nn.Module):
    """Return ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``."""

    def forward(self, mu, logvar):
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


class CrossEntropyLoss(nn.Module):
    """Cross Entropy Loss

    It calculates the cross-entropy loss given cls_score and label.
    """

    def forward(self, cls_score, label):
        loss_cls = F.cross_entropy(cls_score, label)
        return loss_cls


class SumLogSoftmaxLoss(nn.Module):
    """Negated mean of ``log_softmax(x, dim=1)`` plus the constant ``-log(x.size(1))``.

    The second term is the mean log-softmax of a constant tensor, which equals
    ``-log(C)`` for ``C = x.size(1)``.
    """

    def forward(self, x):
        out = F.log_softmax(x, dim=1)
        loss = - torch.mean(out) + torch.mean(F.log_softmax(torch.ones_like(out), dim=1))
        return loss


class L2SoftmaxLoss(nn.Module):
    """MSE between ``softmax(x)`` over dim 1 and the uniform distribution ``1 / x.size(1)``.

    The target of the most recent call is kept in ``self.label``.
    """

    def __init__(self):
        super(L2SoftmaxLoss, self).__init__()
        # dim=1 matches the class axis used by the uniform target in forward().
        # The implicit-dim form is deprecated and selects dim 0 for 3-D inputs.
        self.softmax = nn.Softmax(dim=1)
        self.L2loss = nn.MSELoss()
        self.label = None

    def forward(self, x):
        out = self.softmax(x)
        # Build the target on the prediction's device and dtype instead of
        # forcing a CUDA float32 tensor (which fails on CPU-only machines).
        self.label = torch.full_like(out, 1.0 / x.size(1))
        loss = self.L2loss(out, self.label)
        return loss


class SoftmaxContrastiveLoss(nn.Module):
    """Contrastive loss over inverse L2 distances between paired embeddings.

    Row ``i`` of ``face_feat`` is paired with row ``i`` of ``audio_feat``. After
    L2-normalising both, the inverse pairwise distances act as logits, and the
    target for row ``i`` is column ``i``. Only ``mode='max'`` is implemented.
    """

    _MODES = ('max', 'confusion')

    def __init__(self):
        super(SoftmaxContrastiveLoss, self).__init__()
        self.cross_ent = nn.CrossEntropyLoss()

    def l2_norm(self, x):
        """Return ``x`` scaled to unit L2 norm along dim 1."""
        x_norm = F.normalize(x, p=2, dim=1)
        return x_norm

    def l2_sim(self, feature1, feature2):
        """Return distances ``[i, j] = ||feature1[i] - feature2[j]||_2``.

        Assumes both inputs are 2-D with the same number of rows -- TODO
        confirm at call sites (the broadcast below requires it).
        """
        Feature = feature1.expand(feature1.size(0), feature1.size(0), feature1.size(1)).transpose(0, 1)
        return torch.norm(Feature - feature2, p=2, dim=2)

    def _check_mode(self, mode):
        """Raise ValueError for unsupported modes.

        The previous ``assert mode in 'max' or 'confusion'`` always passed,
        because it parses as ``(mode in 'max') or 'confusion'``. It would also
        be stripped entirely under ``python -O``.
        """
        if mode not in self._MODES:
            raise ValueError('{} must be in max or confusion'.format(mode))
        if mode != 'max':
            raise ValueError("mode '{}' is not implemented".format(mode))

    def _cross_dist(self, face_feat, audio_feat):
        """Return inverse L2 distances between the normalised embeddings."""
        face_feat = self.l2_norm(face_feat)
        audio_feat = self.l2_norm(audio_feat)
        return 1.0 / self.l2_sim(face_feat, audio_feat)

    @torch.no_grad()
    def evaluate(self, face_feat, audio_feat, mode='max'):
        """Return the fraction of rows whose closest audio embedding is their own pair.

        Raises:
            ValueError: if ``mode`` is not 'max'.
        """
        self._check_mode(mode)
        cross_dist = self._cross_dist(face_feat, audio_feat)
        label = torch.arange(face_feat.size(0)).to(cross_dist.device)
        max_idx = torch.argmax(cross_dist, dim=1)
        acc = torch.sum(label == max_idx) * 1.0 / label.size(0)
        return acc

    def forward(self, face_feat, audio_feat, mode='max'):
        """Return the cross-entropy of the inverse-distance logits against the diagonal.

        Raises:
            ValueError: if ``mode`` is not 'max'.
        """
        self._check_mode(mode)
        cross_dist = self._cross_dist(face_feat, audio_feat)
        label = torch.arange(face_feat.size(0)).to(cross_dist.device)
        loss = F.cross_entropy(cross_dist, label)
        return loss

# --------------------------------------------------------------------------------
# /models/networks/sync_batchnorm/__init__.py:
# --------------------------------------------------------------------------------
# 1 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
# 2 | from .batchnorm import patch_sync_batchnorm, convert_model
# 3 | from .replicate import DataParallelWithCallback, patch_replication_callback
# 4 |
# --------------------------------------------------------------------------------
# /models/networks/sync_batchnorm/__pycache__/__init__.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/__init__.cpython-36.pyc
# --------------------------------------------------------------------------------
# /models/networks/sync_batchnorm/__pycache__/__init__.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/__init__.cpython-37.pyc
# --------------------------------------------------------------------------------
/models/networks/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__pycache__/comm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/comm.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/comm.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__pycache__/replicate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/replicate.cpython-36.pyc 
-------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__pycache__/replicate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/replicate.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__pycache__/scatter_gather.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/scatter_gather.cpython-36.pyc -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__pycache__/scatter_gather.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/models/networks/sync_batchnorm/__pycache__/scatter_gather.cpython-37.pyc -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import contextlib 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from torch.nn.modules.batchnorm import _BatchNorm 8 | 9 | try: 10 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 11 | except ImportError: 12 | ReduceAddCoalesced = Broadcast = None 13 | 14 | try: 15 | from jactorch.parallel.comm import SyncMaster 16 | from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback 17 | except ImportError: 18 | from .comm import SyncMaster 19 | 
from .replicate import DataParallelWithCallback 20 | 21 | __all__ = [ 22 | 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', 23 | 'patch_sync_batchnorm', 'convert_model' 24 | ] 25 | 26 | 27 | def _sum_ft(tensor): 28 | """sum over the first and last dimention""" 29 | return tensor.sum(dim=0).sum(dim=-1) 30 | 31 | 32 | def _unsqueeze_ft(tensor): 33 | """add new dimensions at the front and the tail""" 34 | return tensor.unsqueeze(0).unsqueeze(-1) 35 | 36 | 37 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 38 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 39 | 40 | 41 | class _SynchronizedBatchNorm(_BatchNorm): 42 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 43 | assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' 44 | 45 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 46 | 47 | self._sync_master = SyncMaster(self._data_parallel_master) 48 | 49 | self._is_parallel = False 50 | self._parallel_id = None 51 | self._slave_pipe = None 52 | 53 | def forward(self, input): 54 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 55 | if not (self._is_parallel and self.training): 56 | return F.batch_norm( 57 | input, self.running_mean, self.running_var, self.weight, self.bias, 58 | self.training, self.momentum, self.eps) 59 | 60 | # Resize the input to (B, C, -1). 61 | input_shape = input.size() 62 | input = input.view(input.size(0), self.num_features, -1) 63 | 64 | # Compute the sum and square-sum. 65 | sum_size = input.size(0) * input.size(2) 66 | input_sum = _sum_ft(input) 67 | input_ssum = _sum_ft(input ** 2) 68 | 69 | # Reduce-and-broadcast the statistics. 
70 | if self._parallel_id == 0: 71 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 72 | else: 73 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 74 | 75 | # Compute the output. 76 | if self.affine: 77 | # MJY:: Fuse the multiplication for speed. 78 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 79 | else: 80 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 81 | 82 | # Reshape it. 83 | return output.view(input_shape) 84 | 85 | def __data_parallel_replicate__(self, ctx, copy_id): 86 | self._is_parallel = True 87 | self._parallel_id = copy_id 88 | 89 | # parallel_id == 0 means master device. 90 | if self._parallel_id == 0: 91 | ctx.sync_master = self._sync_master 92 | else: 93 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 94 | 95 | def _data_parallel_master(self, intermediates): 96 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 97 | 98 | # Always using same "device order" makes the ReduceAdd operation faster. 
99 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 100 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 101 | 102 | to_reduce = [i[1][:2] for i in intermediates] 103 | to_reduce = [j for i in to_reduce for j in i] # flatten 104 | target_gpus = [i[1].sum.get_device() for i in intermediates] 105 | 106 | sum_size = sum([i[1].sum_size for i in intermediates]) 107 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 108 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 109 | 110 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 111 | 112 | outputs = [] 113 | for i, rec in enumerate(intermediates): 114 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 115 | 116 | return outputs 117 | 118 | def _compute_mean_std(self, sum_, ssum, size): 119 | """Compute the mean and standard-deviation with sum and square-sum. This method 120 | also maintains the moving average on the master device.""" 121 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 122 | mean = sum_ / size 123 | sumvar = ssum - sum_ * mean 124 | unbias_var = sumvar / (size - 1) 125 | bias_var = sumvar / size 126 | 127 | if hasattr(torch, 'no_grad'): 128 | with torch.no_grad(): 129 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 130 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 131 | else: 132 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 133 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 134 | 135 | return mean, bias_var.clamp(self.eps) ** -0.5 136 | 137 | 138 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 139 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 140 | mini-batch. 141 | 142 | .. 
math:: 143 | 144 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 145 | 146 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 147 | standard-deviation are reduced across all devices during training. 148 | 149 | For example, when one uses `nn.DataParallel` to wrap the network during 150 | training, PyTorch's implementation normalize the tensor on each device using 151 | the statistics only on that device, which accelerated the computation and 152 | is also easy to implement, but the statistics might be inaccurate. 153 | Instead, in this synchronized version, the statistics will be computed 154 | over all training samples distributed on multiple devices. 155 | 156 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 157 | as the built-in PyTorch implementation. 158 | 159 | The mean and standard-deviation are calculated per-dimension over 160 | the mini-batches and gamma and beta are learnable parameter vectors 161 | of size C (where C is the input size). 162 | 163 | During training, this layer keeps a running estimate of its computed mean 164 | and variance. The running sum is kept with a default momentum of 0.1. 165 | 166 | During evaluation, this running mean/variance is used for normalization. 167 | 168 | Because the BatchNorm is done over the `C` dimension, computing statistics 169 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 170 | 171 | Args: 172 | num_features: num_features from an expected input of size 173 | `batch_size x num_features [x width]` 174 | eps: a value added to the denominator for numerical stability. 175 | Default: 1e-5 176 | momentum: the value used for the running_mean and running_var 177 | computation. Default: 0.1 178 | affine: a boolean value that when set to ``True``, gives the layer learnable 179 | affine parameters. 
Default: ``True`` 180 | 181 | Shape:: 182 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 183 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 184 | 185 | Examples: 186 | >>> # With Learnable Parameters 187 | >>> m = SynchronizedBatchNorm1d(100) 188 | >>> # Without Learnable Parameters 189 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 190 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 191 | >>> output = m(input) 192 | """ 193 | 194 | def _check_input_dim(self, input): 195 | if input.dim() != 2 and input.dim() != 3: 196 | raise ValueError('expected 2D or 3D input (got {}D input)' 197 | .format(input.dim())) 198 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 199 | 200 | 201 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 202 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 203 | of 3d inputs 204 | 205 | .. math:: 206 | 207 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 208 | 209 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 210 | standard-deviation are reduced across all devices during training. 211 | 212 | For example, when one uses `nn.DataParallel` to wrap the network during 213 | training, PyTorch's implementation normalize the tensor on each device using 214 | the statistics only on that device, which accelerated the computation and 215 | is also easy to implement, but the statistics might be inaccurate. 216 | Instead, in this synchronized version, the statistics will be computed 217 | over all training samples distributed on multiple devices. 218 | 219 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 220 | as the built-in PyTorch implementation. 221 | 222 | The mean and standard-deviation are calculated per-dimension over 223 | the mini-batches and gamma and beta are learnable parameter vectors 224 | of size C (where C is the input size). 
225 | 226 | During training, this layer keeps a running estimate of its computed mean 227 | and variance. The running sum is kept with a default momentum of 0.1. 228 | 229 | During evaluation, this running mean/variance is used for normalization. 230 | 231 | Because the BatchNorm is done over the `C` dimension, computing statistics 232 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 233 | 234 | Args: 235 | num_features: num_features from an expected input of 236 | size batch_size x num_features x height x width 237 | eps: a value added to the denominator for numerical stability. 238 | Default: 1e-5 239 | momentum: the value used for the running_mean and running_var 240 | computation. Default: 0.1 241 | affine: a boolean value that when set to ``True``, gives the layer learnable 242 | affine parameters. Default: ``True`` 243 | 244 | Shape:: 245 | - Input: :math:`(N, C, H, W)` 246 | - Output: :math:`(N, C, H, W)` (same shape as input) 247 | 248 | Examples: 249 | >>> # With Learnable Parameters 250 | >>> m = SynchronizedBatchNorm2d(100) 251 | >>> # Without Learnable Parameters 252 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 253 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 254 | >>> output = m(input) 255 | """ 256 | 257 | def _check_input_dim(self, input): 258 | if input.dim() != 4: 259 | raise ValueError('expected 4D input (got {}D input)' 260 | .format(input.dim())) 261 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 262 | 263 | 264 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 265 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 266 | of 4d inputs 267 | 268 | .. math:: 269 | 270 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 271 | 272 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 273 | standard-deviation are reduced across all devices during training. 
274 | 275 | For example, when one uses `nn.DataParallel` to wrap the network during 276 | training, PyTorch's implementation normalize the tensor on each device using 277 | the statistics only on that device, which accelerated the computation and 278 | is also easy to implement, but the statistics might be inaccurate. 279 | Instead, in this synchronized version, the statistics will be computed 280 | over all training samples distributed on multiple devices. 281 | 282 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 283 | as the built-in PyTorch implementation. 284 | 285 | The mean and standard-deviation are calculated per-dimension over 286 | the mini-batches and gamma and beta are learnable parameter vectors 287 | of size C (where C is the input size). 288 | 289 | During training, this layer keeps a running estimate of its computed mean 290 | and variance. The running sum is kept with a default momentum of 0.1. 291 | 292 | During evaluation, this running mean/variance is used for normalization. 293 | 294 | Because the BatchNorm is done over the `C` dimension, computing statistics 295 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 296 | or Spatio-temporal BatchNorm 297 | 298 | Args: 299 | num_features: num_features from an expected input of 300 | size batch_size x num_features x depth x height x width 301 | eps: a value added to the denominator for numerical stability. 302 | Default: 1e-5 303 | momentum: the value used for the running_mean and running_var 304 | computation. Default: 0.1 305 | affine: a boolean value that when set to ``True``, gives the layer learnable 306 | affine parameters. 
Default: ``True`` 307 | 308 | Shape:: 309 | - Input: :math:`(N, C, D, H, W)` 310 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 311 | 312 | Examples: 313 | >>> # With Learnable Parameters 314 | >>> m = SynchronizedBatchNorm3d(100) 315 | >>> # Without Learnable Parameters 316 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 317 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 318 | >>> output = m(input) 319 | """ 320 | 321 | def _check_input_dim(self, input): 322 | if input.dim() != 5: 323 | raise ValueError('expected 5D input (got {}D input)' 324 | .format(input.dim())) 325 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 326 | 327 | 328 | @contextlib.contextmanager 329 | def patch_sync_batchnorm(): 330 | import torch.nn as nn 331 | 332 | backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d 333 | 334 | nn.BatchNorm1d = SynchronizedBatchNorm1d 335 | nn.BatchNorm2d = SynchronizedBatchNorm2d 336 | nn.BatchNorm3d = SynchronizedBatchNorm3d 337 | 338 | yield 339 | 340 | nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup 341 | 342 | 343 | def convert_model(module): 344 | """Traverse the input module and its child recursively 345 | and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d 346 | to SynchronizedBatchNorm*N*d 347 | 348 | Args: 349 | module: the input module needs to be convert to SyncBN model 350 | 351 | Examples: 352 | >>> import torch.nn as nn 353 | >>> import torchvision 354 | >>> # m is a standard pytorch model 355 | >>> m = torchvision.models.resnet18(True) 356 | >>> m = nn.DataParallel(m) 357 | >>> # after convert, m is using SyncBN 358 | >>> m = convert_model(m) 359 | """ 360 | if isinstance(module, torch.nn.DataParallel): 361 | mod = module.module 362 | mod = convert_model(mod) 363 | mod = DataParallelWithCallback(mod) 364 | return mod 365 | 366 | mod = module 367 | for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, 368 | 
torch.nn.modules.batchnorm.BatchNorm2d, 369 | torch.nn.modules.batchnorm.BatchNorm3d], 370 | [SynchronizedBatchNorm1d, 371 | SynchronizedBatchNorm2d, 372 | SynchronizedBatchNorm3d]): 373 | if isinstance(module, pth_module): 374 | mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) 375 | mod.running_mean = module.running_mean 376 | mod.running_var = module.running_var 377 | if module.affine: 378 | mod.weight.data = module.weight.data.clone().detach() 379 | mod.bias.data = module.bias.data.clone().detach() 380 | 381 | for name, child in module.named_children(): 382 | mod.add_module(name, convert_model(child)) 383 | 384 | return mod 385 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | 5 | __all__ = ['BatchNorm2dReimpl'] 6 | 7 | 8 | class BatchNorm2dReimpl(nn.Module): 9 | """ 10 | A re-implementation of batch normalization, used for testing the numerical 11 | stability. 
12 | 13 | Author: acgtyrant 14 | See also: 15 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 16 | """ 17 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 18 | super().__init__() 19 | 20 | self.num_features = num_features 21 | self.eps = eps 22 | self.momentum = momentum 23 | self.weight = nn.Parameter(torch.empty(num_features)) 24 | self.bias = nn.Parameter(torch.empty(num_features)) 25 | self.register_buffer('running_mean', torch.zeros(num_features)) 26 | self.register_buffer('running_var', torch.ones(num_features)) 27 | self.reset_parameters() 28 | 29 | def reset_running_stats(self): 30 | self.running_mean.zero_() 31 | self.running_var.fill_(1) 32 | 33 | def reset_parameters(self): 34 | self.reset_running_stats() 35 | init.uniform_(self.weight) 36 | init.zeros_(self.bias) 37 | 38 | def forward(self, input_): 39 | batchsize, channels, height, width = input_.size() 40 | numel = batchsize * height * width 41 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 42 | sum_ = input_.sum(1) 43 | sum_of_square = input_.pow(2).sum(1) 44 | mean = sum_ / numel 45 | sumvar = sum_of_square - sum_ * mean 46 | 47 | self.running_mean = ( 48 | (1 - self.momentum) * self.running_mean 49 | + self.momentum * mean.detach() 50 | ) 51 | unbias_var = sumvar / (numel - 1) 52 | self.running_var = ( 53 | (1 - self.momentum) * self.running_var 54 | + self.momentum * unbias_var.detach() 55 | ) 56 | 57 | bias_var = sumvar / numel 58 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 59 | output = ( 60 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 61 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 62 | 63 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 64 | 65 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | import queue 2 | import 
collections 3 | import threading 4 | 5 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 6 | 7 | 8 | class FutureResult(object): 9 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 10 | 11 | def __init__(self): 12 | self._result = None 13 | self._lock = threading.Lock() 14 | self._cond = threading.Condition(self._lock) 15 | 16 | def put(self, result): 17 | with self._lock: 18 | assert self._result is None, 'Previous result has\'t been fetched.' 19 | self._result = result 20 | self._cond.notify() 21 | 22 | def get(self): 23 | with self._lock: 24 | if self._result is None: 25 | self._cond.wait() 26 | 27 | res = self._result 28 | self._result = None 29 | return res 30 | 31 | 32 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 33 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 34 | 35 | 36 | class SlavePipe(_SlavePipeBase): 37 | """Pipe for master-slave communication.""" 38 | 39 | def run_slave(self, msg): 40 | self.queue.put((self.identifier, msg)) 41 | ret = self.result.get() 42 | self.queue.put(True) 43 | return ret 44 | 45 | 46 | class SyncMaster(object): 47 | """An abstract `SyncMaster` object. 48 | 49 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 50 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 51 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 52 | and passed to a registered callback. 53 | - After receiving the messages, the master device should gather the information and determine to message passed 54 | back to each slave devices. 55 | """ 56 | 57 | def __init__(self, master_callback): 58 | """ 59 | 60 | Args: 61 | master_callback: a callback to be invoked after having collected messages from slave devices. 
62 | """ 63 | self._master_callback = master_callback 64 | self._queue = queue.Queue() 65 | self._registry = collections.OrderedDict() 66 | self._activated = False 67 | 68 | def __getstate__(self): 69 | return {'master_callback': self._master_callback} 70 | 71 | def __setstate__(self, state): 72 | self.__init__(state['master_callback']) 73 | 74 | def register_slave(self, identifier): 75 | """ 76 | Register an slave device. 77 | 78 | Args: 79 | identifier: an identifier, usually is the device id. 80 | 81 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 82 | 83 | """ 84 | if self._activated: 85 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 86 | self._activated = False 87 | self._registry.clear() 88 | future = FutureResult() 89 | self._registry[identifier] = _MasterRegistry(future) 90 | return SlavePipe(identifier, self._queue, future) 91 | 92 | def run_master(self, master_msg): 93 | """ 94 | Main entry for the master device in each forward pass. 95 | The messages were first collected from each devices (including the master device), and then 96 | an callback will be invoked to compute the message to be sent back to each devices 97 | (including the master device). 98 | 99 | Args: 100 | master_msg: the message that the master want to send to itself. This will be placed as the first 101 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 102 | 103 | Returns: the message to be sent back to the master device. 104 | 105 | """ 106 | self._activated = True 107 | 108 | intermediates = [(0, master_msg)] 109 | for i in range(self.nr_slaves): 110 | intermediates.append(self._queue.get()) 111 | 112 | results = self._master_callback(intermediates) 113 | assert results[0][0] == 0, 'The first result should belongs to the master.' 
114 | 115 | for i, res in results: 116 | if i == 0: 117 | continue 118 | self._registry[i].result.put(res) 119 | 120 | for i in range(self.nr_slaves): 121 | assert self._queue.get() is True 122 | 123 | return results[0][1] 124 | 125 | @property 126 | def nr_slaves(self): 127 | return len(self._registry) 128 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | 4 | from torch.nn.parallel.data_parallel import DataParallel 5 | from .scatter_gather import scatter_kwargs 6 | 7 | __all__ = [ 8 | 'CallbackContext', 9 | 'execute_replication_callbacks', 10 | 'DataParallelWithCallback', 11 | 'patch_replication_callback' 12 | ] 13 | 14 | 15 | class CallbackContext(object): 16 | pass 17 | 18 | 19 | def execute_replication_callbacks(modules): 20 | """ 21 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 22 | 23 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 24 | 25 | Note that, as all modules are isomorphism, we assign each sub-module with a context 26 | (shared among multiple copies of this module on different devices). 27 | Through this context, different copies can share some information. 28 | 29 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 30 | of any slave copies. 
31 | """ 32 | master_copy = modules[0] 33 | nr_modules = len(list(master_copy.modules())) 34 | ctxs = [CallbackContext() for _ in range(nr_modules)] 35 | 36 | for i, module in enumerate(modules): 37 | for j, m in enumerate(module.modules()): 38 | if hasattr(m, '__data_parallel_replicate__'): 39 | m.__data_parallel_replicate__(ctxs[j], i) 40 | 41 | 42 | class DataParallelWithCallback(DataParallel): 43 | """ 44 | Data Parallel with a replication callback. 45 | 46 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 47 | original `replicate` function. 48 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 49 | 50 | Examples: 51 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 52 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 53 | # sync_bn.__data_parallel_replicate__ will be invoked. 54 | """ 55 | def __init__(self, module, device_ids=None, output_device=None, dim=0, chunk_size=None): 56 | super(DataParallelWithCallback, self).__init__(module) 57 | 58 | if not torch.cuda.is_available(): 59 | self.module = module 60 | self.device_ids = [] 61 | return 62 | 63 | if device_ids is None: 64 | device_ids = list(range(torch.cuda.device_count())) 65 | if output_device is None: 66 | output_device = device_ids[0] 67 | self.dim = dim 68 | self.module = module 69 | self.device_ids = device_ids 70 | self.output_device = output_device 71 | self.chunk_size = chunk_size 72 | 73 | if len(self.device_ids) == 1: 74 | self.module.cuda(device_ids[0]) 75 | 76 | def forward(self, *inputs, **kwargs): 77 | if not self.device_ids: 78 | return self.module(*inputs, **kwargs) 79 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids, self.chunk_size) 80 | if len(self.device_ids) == 1: 81 | return self.module(*inputs[0], **kwargs[0]) 82 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 83 | outputs = self.parallel_apply(replicas, inputs, 
kwargs) 84 | return self.gather(outputs, self.output_device) 85 | 86 | def scatter(self, inputs, kwargs, device_ids, chunk_size): 87 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim, chunk_size=self.chunk_size) 88 | 89 | def replicate(self, module, device_ids): 90 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | 95 | 96 | def patch_replication_callback(data_parallel): 97 | """ 98 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 99 | Useful when you have customized `DataParallel` implementation. 100 | 101 | Examples: 102 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 103 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 104 | > patch_replication_callback(sync_bn) 105 | # this is equivalent to 106 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 107 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 108 | """ 109 | 110 | assert isinstance(data_parallel, DataParallel) 111 | 112 | old_replicate = data_parallel.replicate 113 | 114 | @functools.wraps(old_replicate) 115 | def new_replicate(module, device_ids): 116 | modules = old_replicate(module, device_ids) 117 | execute_replication_callbacks(modules) 118 | return modules 119 | 120 | data_parallel.replicate = new_replicate 121 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/scatter_gather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.parallel._functions import Scatter, Gather 3 | 4 | 5 | def scatter(inputs, target_gpus, dim=0, chunk_size=None): 6 | r""" 7 | Slices tensors into approximately equal chunks and 8 | distributes them across given GPUs. Duplicates 9 | references to objects that are not tensors. 
10 | """ 11 | def scatter_map(obj): 12 | if isinstance(obj, torch.Tensor): 13 | return Scatter.apply(target_gpus, chunk_size, dim, obj) 14 | if isinstance(obj, tuple) and len(obj) > 0: 15 | return list(zip(*map(scatter_map, obj))) 16 | if isinstance(obj, list) and len(obj) > 0: 17 | return list(map(list, zip(*map(scatter_map, obj)))) 18 | if isinstance(obj, dict) and len(obj) > 0: 19 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 20 | return [obj for targets in target_gpus] 21 | 22 | # After scatter_map is called, a scatter_map cell will exist. This cell 23 | # has a reference to the actual function scatter_map, which has references 24 | # to a closure that has a reference to the scatter_map cell (because the 25 | # fn is recursive). To avoid this reference cycle, we set the function to 26 | # None, clearing the cell 27 | try: 28 | res = scatter_map(inputs) 29 | finally: 30 | scatter_map = None 31 | return res 32 | 33 | 34 | def scatter_kwargs(inputs, kwargs, target_gpus, dim=0, chunk_size=None): 35 | r"""Scatter with support for kwargs dictionary""" 36 | inputs = scatter(inputs, target_gpus, dim, chunk_size) if inputs else [] 37 | kwargs = scatter(kwargs, target_gpus, dim, chunk_size) if kwargs else [] 38 | if len(inputs) < len(kwargs): 39 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 40 | elif len(kwargs) < len(inputs): 41 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 42 | inputs = tuple(inputs) 43 | kwargs = tuple(kwargs) 44 | return inputs, kwargs 45 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | 4 | 5 | class TorchTestCase(unittest.TestCase): 6 | def assertTensorClose(self, x, y): 7 | adiff = float((x - y).abs().max()) 8 | if (y == 0).all(): 9 | rdiff = 'NaN' 10 | else: 11 | rdiff = float((adiff / 
y).abs().max()) 12 | 13 | message = ( 14 | 'Tensor close check failed\n' 15 | 'adiff={}\n' 16 | 'rdiff={}\n' 17 | ).format(adiff, rdiff) 18 | self.assertTrue(torch.allclose(x, y), message) 19 | 20 | -------------------------------------------------------------------------------- /models/networks/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | from math import * 8 | 9 | def P2sRt(P): 10 | ''' decompositing camera matrix P. 11 | Args: 12 | P: (3, 4). Affine Camera Matrix. 13 | Returns: 14 | s: scale factor. 15 | R: (3, 3). rotation matrix. 16 | t2d: (2,). 2d translation. 17 | ''' 18 | t3d = P[:, 3] 19 | R1 = P[0:1, :3] 20 | R2 = P[1:2, :3] 21 | s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2.0 22 | r1 = R1 / np.linalg.norm(R1) 23 | r2 = R2 / np.linalg.norm(R2) 24 | r3 = np.cross(r1, r2) 25 | 26 | R = np.concatenate((r1, r2, r3), 0) 27 | return s, R, t3d 28 | 29 | def matrix2angle(R): 30 | ''' compute three Euler angles from a Rotation Matrix. Ref: http://www.gregslabaugh.net/publications/euler.pdf 31 | Args: 32 | R: (3,3). rotation matrix 33 | Returns: 34 | x: yaw 35 | y: pitch 36 | z: roll 37 | ''' 38 | # assert(isRotationMatrix(R)) 39 | 40 | if R[2, 0] != 1 and R[2, 0] != -1: 41 | x = -asin(max(-1, min(R[2, 0], 1))) 42 | y = atan2(R[2, 1] / cos(x), R[2, 2] / cos(x)) 43 | z = atan2(R[1, 0] / cos(x), R[0, 0] / cos(x)) 44 | 45 | else: # Gimbal lock 46 | z = 0 # can be anything 47 | if R[2, 0] == -1: 48 | x = np.pi / 2 49 | y = z + atan2(R[0, 1], R[0, 2]) 50 | else: 51 | x = -np.pi / 2 52 | y = -z + atan2(-R[0, 1], -R[0, 2]) 53 | 54 | return [x, y, z] 55 | 56 | def angle2matrix(angles): 57 | ''' get rotation matrix from three rotation angles(radian). The same as in 3DDFA. 58 | Args: 59 | angles: [3,]. x, y, z angles 60 | x: yaw. 61 | y: pitch. 
62 | z: roll. 63 | Returns: 64 | R: 3x3. rotation matrix. 65 | ''' 66 | # x, y, z = np.deg2rad(angles[0]), np.deg2rad(angles[1]), np.deg2rad(angles[2]) 67 | # x, y, z = angles[0], angles[1], angles[2] 68 | y, x, z = angles[0], angles[1], angles[2] 69 | 70 | # x 71 | Rx=np.array([[1, 0, 0], 72 | [0, cos(x), -sin(x)], 73 | [0, sin(x), cos(x)]]) 74 | # y 75 | Ry=np.array([[ cos(y), 0, sin(y)], 76 | [ 0, 1, 0], 77 | [-sin(y), 0, cos(y)]]) 78 | # z 79 | Rz=np.array([[cos(z), -sin(z), 0], 80 | [sin(z), cos(z), 0], 81 | [ 0, 0, 1]]) 82 | R = Rz.dot(Ry).dot(Rx) 83 | return R.astype(np.float32) 84 | 85 | def tensor2im(input_image, imtype=np.uint8): 86 | """"Converts a Tensor array into a numpy image array. 87 | 88 | Parameters: 89 | input_image (tensor) -- the input image tensor array 90 | imtype (type) -- the desired type of the converted numpy array 91 | """ 92 | if not isinstance(input_image, np.ndarray): 93 | if isinstance(input_image, torch.Tensor): # get the data from a variable 94 | image_tensor = input_image.data 95 | else: 96 | return input_image 97 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 98 | if image_numpy.shape[0] == 1: # grayscale to RGB 99 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 100 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 101 | else: # if it is a numpy array, do nothing 102 | image_numpy = input_image 103 | return image_numpy.astype(imtype) 104 | 105 | 106 | def diagnose_network(net, name='network'): 107 | """Calculate and print the mean of average absolute(gradients) 108 | 109 | Parameters: 110 | net (torch network) -- Torch network 111 | name (str) -- the name of the network 112 | """ 113 | mean = 0.0 114 | count = 0 115 | for param in net.parameters(): 116 | if param.grad is not None: 117 | mean += torch.mean(torch.abs(param.grad.data)) 118 | count += 1 119 | if count > 0: 120 | mean = mean / count 121 | print(name) 122 | 
print(mean) 123 | 124 | 125 | def save_image(image_numpy, image_path): 126 | """Save a numpy image to the disk 127 | 128 | Parameters: 129 | image_numpy (numpy array) -- input numpy array 130 | image_path (str) -- the path of the image 131 | """ 132 | image_pil = Image.fromarray(image_numpy) 133 | image_pil.save(image_path) 134 | 135 | 136 | def print_numpy(x, val=True, shp=False): 137 | """Print the mean, min, max, median, std, and size of a numpy array 138 | 139 | Parameters: 140 | val (bool) -- if print the values of the numpy array 141 | shp (bool) -- if print the shape of the numpy array 142 | """ 143 | x = x.astype(np.float64) 144 | if shp: 145 | print('shape,', x.shape) 146 | if val: 147 | x = x.flatten() 148 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 149 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 150 | 151 | 152 | def mkdirs(paths): 153 | """create empty directories if they don't exist 154 | 155 | Parameters: 156 | paths (str list) -- a list of directory paths 157 | """ 158 | if isinstance(paths, list) and not isinstance(paths, str): 159 | for path in paths: 160 | mkdir(path) 161 | else: 162 | mkdir(paths) 163 | 164 | 165 | def mkdir(path): 166 | """create a single empty directory if it didn't exist 167 | 168 | Parameters: 169 | path (str) -- a single directory path 170 | """ 171 | if not os.path.exists(path): 172 | os.makedirs(path) 173 | -------------------------------------------------------------------------------- /models/networks/vision_network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from models.networks.base_network import BaseNetwork 4 | from torchvision.models.resnet import ResNet, Bottleneck 5 | from util import util 6 | import torch 7 | 8 | model_urls = { 9 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 10 | } 11 | 12 | 13 | class 
ResNeXt50(BaseNetwork): 14 | def __init__(self, opt): 15 | super(ResNeXt50, self).__init__() 16 | self.model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4) 17 | self.opt = opt 18 | # self.reduced_id_dim = opt.reduced_id_dim 19 | self.conv1x1 = nn.Conv2d(512 * Bottleneck.expansion, 512, kernel_size=1, padding=0) 20 | self.fc = nn.Linear(512 * Bottleneck.expansion, opt.num_classes) 21 | # self.fc_pre = nn.Sequential(nn.Linear(512 * Bottleneck.expansion, self.reduced_id_dim), nn.ReLU()) 22 | 23 | 24 | def load_pretrain(self): 25 | check_point = torch.load(model_urls['resnext50_32x4d']) 26 | util.copy_state_dict(check_point, self.model) 27 | 28 | def forward_feature(self, input): 29 | x = self.model.conv1(input) 30 | x = self.model.bn1(x) 31 | x = self.model.relu(x) 32 | x = self.model.maxpool(x) 33 | 34 | x = self.model.layer1(x) 35 | x = self.model.layer2(x) 36 | x = self.model.layer3(x) 37 | x = self.model.layer4(x) 38 | net = self.model.avgpool(x) 39 | net = torch.flatten(net, 1) 40 | x = self.conv1x1(x) 41 | # x = self.fc_pre(x) 42 | return net, x 43 | 44 | def forward(self, input): 45 | input_batch = input.view(-1, self.opt.output_nc, self.opt.crop_size, self.opt.crop_size) 46 | net, x = self.forward_feature(input_batch) 47 | net = net.view(-1, self.opt.num_inputs, 512 * Bottleneck.expansion) 48 | x = F.adaptive_avg_pool2d(x, (7, 7)) 49 | x = x.view(-1, self.opt.num_inputs, 512, 7, 7) 50 | net = torch.mean(net, 1) 51 | x = torch.mean(x, 1) 52 | cls_scores = self.fc(net) 53 | 54 | return [net, x], cls_scores 55 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hangz-nju-cuhk/Talking-Face_PC-AVS/201d2b946caa93a2ce7c798dd080be65959d2592/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import math 4 | import os 5 | from util import util 6 | import torch 7 | import models 8 | import data 9 | import pickle 10 | 11 | 12 | class BaseOptions(): 13 | def __init__(self): 14 | self.initialized = False 15 | 16 | def initialize(self, parser): 17 | # experiment specifics 18 | parser.add_argument('--name', type=str, default='demo', help='name of the experiment. It decides where to store samples and models') 19 | parser.add_argument('--filename_tmpl', type=str, default='{:06}.jpg', help='name of the experiment. It decides where to store samples and models') 20 | parser.add_argument('--data_path', type=str, default='/home/SENSETIME/zhouhang1/Downloads/VoxCeleb2/voxceleb2_train.csv', help='where to load voxceleb train data') 21 | parser.add_argument('--lrw_data_path', type=str, 22 | default='/home/SENSETIME/zhouhang1/Downloads/VoxCeleb2/voxceleb2_train.csv', 23 | help='where to load lrw train data') 24 | 25 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids') 26 | parser.add_argument('--num_classes', type=int, default=5830, help='num classes') 27 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 28 | parser.add_argument('--model', type=str, default='av', help='which model to use, rotate|rotatespade') 29 | parser.add_argument('--trainer', type=str, default='audio', help='which trainer to use, rotate|rotatespade') 30 | parser.add_argument('--norm_G', type=str, default='spectralinstance', help='instance normalization or batch normalization') 31 | parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization') 32 | parser.add_argument('--norm_E', type=str, default='spectralinstance', help='instance normalization or batch normalization') 33 | parser.add_argument('--norm_A', type=str, default='spectralinstance', 
help='instance normalization or batch normalization') 34 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 35 | # input/output sizes 36 | parser.add_argument('--batchSize', type=int, default=2, help='input batch size') 37 | parser.add_argument('--preprocess_mode', type=str, default='resize_and_crop', help='scaling and cropping of images at load time.', choices=("resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none")) 38 | parser.add_argument('--crop_size', type=int, default=224, help='Crop to the width of crop_size (after initially scaling the images to load_size.)') 39 | parser.add_argument('--crop_len', type=int, default=16, help='Crop len') 40 | parser.add_argument('--target_crop_len', type=int, default=0, help='Crop len') 41 | parser.add_argument('--crop', action='store_true', help='whether to crop the image') 42 | parser.add_argument('--clip_len', type=int, default=1, help='num of imgs to process') 43 | parser.add_argument('--pose_dim', type=int, default=12, help='num of imgs to process') 44 | parser.add_argument('--frame_interval', type=int, default=1, help='the interval of frams') 45 | parser.add_argument('--num_clips', type=int, default=1, help='num of clips to process') 46 | parser.add_argument('--num_inputs', type=int, default=1, help='num of inputs to the network') 47 | parser.add_argument('--feature_encoded_dim', type=int, default=2560, help='dim of reduced id feature') 48 | 49 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='The ratio width/height. 
The final height of the load image will be crop_size/aspect_ratio') 50 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 51 | parser.add_argument('--audio_nc', type=int, default=256, help='# of output audio channels') 52 | parser.add_argument('--frame_rate', type=int, default=25, help='fps') 53 | parser.add_argument('--num_frames_per_clip', type=int, default=5, help='num of frames one audio bin') 54 | parser.add_argument('--hop_size', type=int, default=160, help='audio hop size') 55 | parser.add_argument('--generate_interval', type=int, default=1, help='select frames to generate') 56 | parser.add_argument('--dis_feat_rec', action='store_true', help='select frames to generate') 57 | 58 | parser.add_argument('--train_recognition', action='store_true', help='train recognition only') 59 | parser.add_argument('--train_sync', action='store_true', help='train sync only') 60 | parser.add_argument('--train_word', action='store_true', help='train word only') 61 | parser.add_argument('--train_dis_pose', action='store_true', help='train dis pose') 62 | parser.add_argument('--generate_from_audio_only', action='store_true', help='if specified, generate only from audio features') 63 | parser.add_argument('--noise_pose', action='store_true', help='noise pose to generate a talking face') 64 | parser.add_argument('--style_feature_loss', action='store_true', help='style_feature_loss') 65 | 66 | # for setting inputsf 67 | parser.add_argument('--dataset_mode', type=str, default='voxtest') 68 | parser.add_argument('--landmark_align', action='store_true', help='wether there is landmark_align') 69 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 70 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation') 71 | parser.add_argument('--nThreads', default=1, type=int, help='# threads for 
loading data') 72 | parser.add_argument('--n_mel_T', default=4, type=int, help='# threads for loading data') 73 | parser.add_argument('--num_bins_per_frame', type=int, default=4, help='n_melT') 74 | 75 | parser.add_argument('--max_dataset_size', type=int, default=sys.maxsize, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 76 | parser.add_argument('--load_from_opt_file', action='store_true', help='load the options from checkpoints and use that as default') 77 | parser.add_argument('--use_audio', type=int, default=1, help='use audio as driven input') 78 | parser.add_argument('--use_audio_id', type=int, default=0, help='use audio id') 79 | parser.add_argument('--augment_target', action='store_true', help='whether to use checkpoint') 80 | parser.add_argument('--verbose', action='store_true', help='just add') 81 | 82 | parser.add_argument('--display_winsize', type=int, default=224, help='display window size') 83 | 84 | # for generator 85 | parser.add_argument('--netG', type=str, default='modulate', help='selects model to use for netG (modulate)') 86 | parser.add_argument('--netA', type=str, default='resseaudio', help='selects model to use for netA (audio | spade)') 87 | parser.add_argument('--netA_sync', type=str, default='ressesync', help='selects model to use for netA (audio | spade)') 88 | parser.add_argument('--netV', type=str, default='resnext', help='selects model to use for netV (mobile | id)') 89 | parser.add_argument('--netE', type=str, default='fan', help='selects model to use for netV (mobile | fan)') 90 | parser.add_argument('--netD', type=str, default='multiscale', help='(n_layers|multiscale|image|projection)') 91 | parser.add_argument('--D_input', type=str, default='single', help='(concat|single|hinge)') 92 | parser.add_argument('--driven_type', type=str, default='face', help='selects model to use for netV (heatmap | face)') 93 | 
parser.add_argument('--landmark_type', type=str, default='min', help='selects model to use for netV (mobile | fan)') 94 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 95 | parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]') 96 | parser.add_argument('--feature_fusion', type=str, default='concat', help='style fusion method') 97 | parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution') 98 | 99 | # for instance-wise features 100 | parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input') 101 | parser.add_argument('--input_id_feature', action='store_true', help='if specified, use id feature as style gan input') 102 | parser.add_argument('--load_landmark', action='store_true', help='if specified, load landmarks') 103 | parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer') 104 | parser.add_argument('--style_dim', type=int, default=2580, help='# of encoder filters in the first conv layer') 105 | 106 | ####################### weight settings ################################################################### 107 | 108 | parser.add_argument('--vgg_face', action='store_true', help='if specified, use VGG feature matching loss') 109 | 110 | parser.add_argument('--VGGFace_pretrain_path', type=str, default='', help='VGGFace pretrain path') 111 | parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss') 112 | parser.add_argument('--lambda_image', type=float, default=1.0, help='weight for image reconstruction') 113 | parser.add_argument('--lambda_vgg', type=float, default=10.0, help='weight for vgg loss') 114 | parser.add_argument('--lambda_vggface', type=float, default=5.0, help='weight for vggface loss') 115 | 
parser.add_argument('--lambda_rotate_D', type=float, default='0.1', 116 | help='rotated D loss weight') 117 | parser.add_argument('--lambda_D', type=float, default=1, 118 | help='D loss weight') 119 | parser.add_argument('--lambda_softmax', type=float, default=1000000, help='weight for softmax loss') 120 | parser.add_argument('--lambda_crossmodal', type=float, default=1, help='weight for softmax loss') 121 | 122 | parser.add_argument('--lambda_contrastive', type=float, default=100, help='if specified, use contrastive loss for img and audio embed') 123 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 124 | 125 | parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss') 126 | parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss') 127 | parser.add_argument('--no_id_loss', action='store_true', help='if specified, do *not* use cls loss') 128 | parser.add_argument('--word_loss', action='store_true', help='if specified, do *not* use cls loss') 129 | parser.add_argument('--no_spectrogram', action='store_true', help='if specified, do *not* use mel spectrogram, use mfcc') 130 | 131 | parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)') 132 | parser.add_argument('--no_TTUR', action='store_true', help='Use TTUR training scheme') 133 | 134 | ############################## optimizer ############################# 135 | parser.add_argument('--optimizer', type=str, default='adam') 136 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 137 | parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam') 138 | parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam') 139 | 140 | parser.add_argument('--no_gaussian_landmark', action='store_true', help='whether to 
use no_gaussian_landmark (1.0 landmark) for rotatespade model') 141 | parser.add_argument('--label_mask', action='store_true', help='whether to use face mask') 142 | parser.add_argument('--positional_encode', action='store_true', help='whether to use positional encode') 143 | parser.add_argument('--use_transformer', action='store_true', help='whether to use transformer') 144 | parser.add_argument('--has_mask', action='store_true', help='whether to use mask in transformer') 145 | parser.add_argument('--heatmap_size', type=float, default=3, help='the size of the heatmap, used in rotatespade model') 146 | 147 | self.initialized = True 148 | return parser 149 | 150 | def gather_options(self): 151 | # initialize parser with basic options 152 | if not self.initialized: 153 | parser = argparse.ArgumentParser( 154 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 155 | parser = self.initialize(parser) 156 | 157 | # get the basic options 158 | opt, unknown = parser.parse_known_args() 159 | 160 | # modify model-related parser options 161 | model_name = opt.model 162 | model_option_setter = models.get_option_setter(model_name) 163 | parser = model_option_setter(parser, self.isTrain) 164 | 165 | # modify dataset-related parser options 166 | dataset_mode = opt.dataset_mode 167 | dataset_modes = opt.dataset_mode.split(',') 168 | 169 | if len(dataset_modes) == 1: 170 | dataset_option_setter = data.get_option_setter(dataset_mode) 171 | parser = dataset_option_setter(parser, self.isTrain) 172 | else: 173 | for dm in dataset_modes: 174 | dataset_option_setter = data.get_option_setter(dm) 175 | parser = dataset_option_setter(parser, self.isTrain) 176 | 177 | opt, unknown = parser.parse_known_args() 178 | 179 | # if there is opt_file, load it. 
180 | # lt options will be overwritten 181 | if opt.load_from_opt_file: 182 | parser = self.update_options_from_file(parser, opt) 183 | 184 | opt = parser.parse_args() 185 | self.parser = parser 186 | return opt 187 | 188 | def print_options(self, opt): 189 | message = '' 190 | message += '----------------- Options ---------------\n' 191 | for k, v in sorted(vars(opt).items()): 192 | comment = '' 193 | default = self.parser.get_default(k) 194 | if v != default: 195 | comment = '\t[default: %s]' % str(default) 196 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 197 | message += '----------------- End -------------------' 198 | print(message) 199 | 200 | def option_file_path(self, opt, makedir=False): 201 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 202 | if makedir: 203 | util.mkdirs(expr_dir) 204 | file_name = os.path.join(expr_dir, 'opt') 205 | return file_name 206 | 207 | def save_options(self, opt): 208 | file_name = self.option_file_path(opt, makedir=True) 209 | with open(file_name + '.txt', 'wt') as opt_file: 210 | for k, v in sorted(vars(opt).items()): 211 | comment = '' 212 | default = self.parser.get_default(k) 213 | if v != default: 214 | comment = '\t[default: %s]' % str(default) 215 | opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)) 216 | 217 | with open(file_name + '.pkl', 'wb') as opt_file: 218 | pickle.dump(opt, opt_file) 219 | 220 | def update_options_from_file(self, parser, opt): 221 | new_opt = self.load_options(opt) 222 | for k, v in sorted(vars(opt).items()): 223 | if hasattr(new_opt, k) and v != getattr(new_opt, k): 224 | new_val = getattr(new_opt, k) 225 | parser.set_defaults(**{k: new_val}) 226 | return parser 227 | 228 | def load_options(self, opt): 229 | file_name = self.option_file_path(opt, makedir=False) 230 | new_opt = pickle.load(open(file_name + '.pkl', 'rb')) 231 | return new_opt 232 | 233 | def parse(self, save=False): 234 | 235 | opt = self.gather_options() 236 | opt.isTrain = 
self.isTrain # train or test 237 | 238 | self.print_options(opt) 239 | if opt.isTrain: 240 | self.save_options(opt) 241 | # Set semantic_nc based on the option. 242 | # This will be convenient in many places 243 | # set gpu ids 244 | str_ids = opt.gpu_ids.split(',') 245 | opt.gpu_ids = [] 246 | for str_id in str_ids: 247 | id = int(str_id) 248 | if id >= 0: 249 | opt.gpu_ids.append(id) 250 | if len(opt.gpu_ids) > 0: 251 | torch.cuda.set_device(opt.gpu_ids[0]) 252 | 253 | 254 | self.opt = opt 255 | return self.opt 256 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self, parser): 6 | BaseOptions.initialize(self, parser) 7 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 8 | parser.add_argument('--input_path', type=str, default='./checkpoints/results/input_path', help='defined input path.') 9 | parser.add_argument('--meta_path_vox', type=str, default='./misc/demo.csv', help='the meta data path') 10 | parser.add_argument('--driving_pose', action='store_true', help='driven pose to generate a talking face') 11 | parser.add_argument('--list_num', type=int, default=0, help='list num') 12 | parser.add_argument('--fitting_iterations', type=int, default=10, help='The iterarions for fit testing') 13 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model') 14 | parser.add_argument('--how_many', type=int, default=float("inf"), help='how many test images to run') 15 | parser.add_argument('--start_ind', type=int, default=0, help='the start id for defined driven') 16 | parser.add_argument('--list_start', type=int, default=0, help='which num in the list to start') 17 | parser.add_argument('--list_end', type=int, default=float("inf"), help='how many test images to run') 18 | parser.add_argument('--save_path', type=str, default='./results/', help='where to save data') 19 | parser.add_argument('--multi_gpu', action='store_true', help='whether to use multi gpus') 20 | parser.add_argument('--defined_driven', action='store_true', help='whether to use defined driven') 21 | parser.add_argument('--gen_video', action='store_true', help='whether to generate videos') 22 | parser.add_argument('--onnx', action='store_true', help='for tddfa') 23 | parser.add_argument('--mode', type=str, default='cpu', help='gpu or cpu mode') 24 | 25 | # parser.set_defaults(preprocess_mode='scale_width_and_crop', crop_size=256, load_size=256, display_winsize=256) 26 | # parser.set_defaults(serial_batches=True) 27 | parser.set_defaults(no_flip=True) 28 | parser.set_defaults(phase='test') 29 | self.isTrain = False 30 | return parser 31 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self, parser): 6 | BaseOptions.initialize(self, parser) 7 | # for displays 8 | parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') 9 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 10 | parser.add_argument('--save_latest_freq', type=int, default=5000, 
help='frequency of saving the latest results') 11 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') 12 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 13 | parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration') 14 | parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed') 15 | parser.add_argument('--tensorboard', default=True, help='if specified, use tensorboard logging. Requires tensorflow installed') 16 | parser.add_argument('--load_pretrain', type=str, default='', 17 | help='load the pretrained model from the specified location') 18 | 19 | # for training 20 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 21 | parser.add_argument('--recognition', action='store_true', help='train only recognition') 22 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 23 | parser.add_argument('--noload_D', action='store_true', help='whether to load D when continue training') 24 | parser.add_argument('--pose_noise', action='store_true', help='whether to use pose noise training') 25 | parser.add_argument('--load_separately', action='store_true', help='whether to continue train by loading separate models') 26 | parser.add_argument('--niter', type=int, default=10, help='# of iter at starting learning rate. This is NOT the total #epochs. 
Totla #epochs is niter + niter_decay') 27 | parser.add_argument('--niter_decay', type=int, default=1000, help='# of iter to linearly decay learning rate to zero') 28 | parser.add_argument('--D_steps_per_G', type=int, default=1, help='number of discriminator iterations per generator iterations.') 29 | 30 | parser.add_argument('--G_pretrain_path', type=str, default='./checkpoints/100_net_G.pth', help='G pretrain path') 31 | parser.add_argument('--D_pretrain_path', type=str, default='', help='D pretrain path') 32 | parser.add_argument('--E_pretrain_path', type=str, default='', help='E pretrain path') 33 | parser.add_argument('--V_pretrain_path', type=str, default='', help='V pretrain path') 34 | parser.add_argument('--A_pretrain_path', type=str, default='', help='E pretrain path') 35 | parser.add_argument('--A_sync_pretrain_path', type=str, default='', help='E pretrain path') 36 | parser.add_argument('--netE_pretrain_path', type=str, default='', help='E pretrain path') 37 | 38 | parser.add_argument('--fix_netV', action='store_true', help='if specified, fix net V') 39 | parser.add_argument('--fix_netE', action='store_true', help='if specified, fix net E') 40 | parser.add_argument('--fix_netE_mouth', action='store_true', help='if specified, fix net E mapper, fc and mapper') 41 | parser.add_argument('--fix_netE_mouth_embed', action='store_true', help='if specified, fix net E mapper, fc and mapper') 42 | parser.add_argument('--fix_netE_headpose', action='store_true', help='if specified, fix net E headpose') 43 | parser.add_argument('--fix_netA_sync', action='store_true', help='if specified fix net A_sync') 44 | parser.add_argument('--fix_netG', action='store_true', help='if specified, fix net G') 45 | parser.add_argument('--fix_netD', action='store_true', help='if specified, fix net D') 46 | parser.add_argument('--no_cross_modal', action='store_true', help='if specified, do *not* use cls loss') 47 | parser.add_argument('--softmax_contrastive', action='store_true', 
help='if specified, use contrastive loss for img and audio embed') 48 | # for discriminators 49 | 50 | parser.add_argument('--baseline_sync', action='store_true', help='train baseline sync') 51 | parser.add_argument('--style_feature_loss', action='store_true', help='to use style feature loss') 52 | # parser.add_argument('--vggface_checkpoint', type=str, default='', help='pth to vggface ckpt') 53 | parser.add_argument('--pretrain', action='store_true', help='Use outsider pretrain') 54 | parser.add_argument('--disentangle', action='store_true', help='whether to use disentangle loss') 55 | self.isTrain = True 56 | return parser 57 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.2.0 2 | torchvision 3 | dominate>=2.3.1 4 | dill 5 | scikit-image 6 | numpy>=1.15.4 7 | scipy>=1.1.0 8 | matplotlib 9 | opencv-python>=3.4.3.18 10 | tensorboard==1.14.0 11 | tqdm 12 | librosa -------------------------------------------------------------------------------- /scripts/align_68.py: -------------------------------------------------------------------------------- 1 | import face_alignment 2 | import os 3 | import cv2 4 | import skimage.transform as trans 5 | import argparse 6 | import torch 7 | import numpy as np 8 | 9 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 10 | 11 | 12 | def get_affine(src): 13 | dst = np.array([[87, 59], 14 | [137, 59], 15 | [112, 120]], dtype=np.float32) 16 | tform = trans.SimilarityTransform() 17 | tform.estimate(src, dst) 18 | M = tform.params[0:2, :] 19 | return M 20 | 21 | 22 | def affine_align_img(img, M, crop_size=224): 23 | warped = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0) 24 | return warped 25 | 26 | 27 | def affine_align_3landmarks(landmarks, M): 28 | new_landmarks = np.concatenate([landmarks, np.ones((3, 1))], 1) 29 | affined_landmarks = np.matmul(new_landmarks, 
M.transpose()) 30 | return affined_landmarks 31 | 32 | 33 | def get_eyes_mouths(landmark): 34 | three_points = np.zeros((3, 2)) 35 | three_points[0] = landmark[36:42].mean(0) 36 | three_points[1] = landmark[42:48].mean(0) 37 | three_points[2] = landmark[60:68].mean(0) 38 | return three_points 39 | 40 | 41 | def get_mouth_bias(three_points): 42 | bias = np.array([112, 120]) - three_points[2] 43 | return bias 44 | 45 | 46 | def align_folder(folder_path, folder_save_path): 47 | 48 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device) 49 | preds = fa.get_landmarks_from_directory(folder_path) 50 | 51 | sumpoints = 0 52 | three_points_list = [] 53 | 54 | for img in preds.keys(): 55 | pred_points = np.array(preds[img]) 56 | if pred_points is None or len(pred_points.shape) != 3: 57 | print('preprocessing failed') 58 | return False 59 | else: 60 | num_faces, size, _ = pred_points.shape 61 | if num_faces == 1 and size == 68: 62 | 63 | three_points = get_eyes_mouths(pred_points[0]) 64 | sumpoints += three_points 65 | three_points_list.append(three_points) 66 | else: 67 | 68 | print('preprocessing failed') 69 | return False 70 | avg_points = sumpoints / len(preds) 71 | M = get_affine(avg_points) 72 | p_bias = None 73 | for i, img_pth in enumerate(preds.keys()): 74 | three_points = three_points_list[i] 75 | affined_3landmarks = affine_align_3landmarks(three_points, M) 76 | bias = get_mouth_bias(affined_3landmarks) 77 | if p_bias is None: 78 | bias = bias 79 | else: 80 | bias = p_bias * 0.2 + bias * 0.8 81 | p_bias = bias 82 | M_i = M.copy() 83 | M_i[:, 2] = M[:, 2] + bias 84 | img = cv2.imread(img_pth) 85 | wrapped = affine_align_img(img, M_i) 86 | img_save_path = os.path.join(folder_save_path, img_pth.split('/')[-1]) 87 | cv2.imwrite(img_save_path, wrapped) 88 | print('cropped files saved at {}'.format(folder_save_path)) 89 | 90 | 91 | def main(): 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument('--folder_path', help='the folder 
which needs processing') 94 | args = parser.parse_args() 95 | 96 | if os.path.isdir(args.folder_path): 97 | home_path = '/'.join(args.folder_path.split('/')[:-1]) 98 | save_img_path = os.path.join(home_path, args.folder_path.split('/')[-1] + '_cropped') 99 | os.makedirs(save_img_path, exist_ok=True) 100 | 101 | align_folder(args.folder_path, save_img_path) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /scripts/prepare_testing_files.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 4 | import argparse 5 | import glob 6 | import csv 7 | import numpy as np 8 | from config.AudioConfig import AudioConfig 9 | 10 | 11 | def mkdir(path): 12 | if not os.path.exists(path): 13 | os.makedirs(path) 14 | 15 | 16 | def proc_frames(src_path, dst_path): 17 | cmd = 'ffmpeg -i \"{}\" -start_number 0 -qscale:v 2 \"{}\"/%06d.jpg -loglevel error -y'.format(src_path, dst_path) 18 | os.system(cmd) 19 | frames = glob.glob(os.path.join(dst_path, '*.jpg')) 20 | return len(frames) 21 | 22 | 23 | def proc_audio(src_mouth_path, dst_audio_path): 24 | audio_command = 'ffmpeg -i \"{}\" -loglevel error -y -f wav -acodec pcm_s16le ' \ 25 | '-ar 16000 \"{}\"'.format(src_mouth_path, dst_audio_path) 26 | os.system(audio_command) 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser() 31 | # parser.add_argument('--dst_dir_path', default='/mnt/lustre/DATAshare3/VoxCeleb2', 32 | # help="dst file position") 33 | parser.add_argument('--dir_path', default='./misc', 34 | help="dst file position") 35 | parser.add_argument('--src_pose_path', default='./misc/Pose_Source/00473.mp4', 36 | help="pose source file position, this could be an mp4 or a folder") 37 | parser.add_argument('--src_audio_path', default='./misc/Audio_Source/00015.mp4', 38 | help="audio source 
file position, it could be an mp3 file or an mp4 video with audio") 39 | parser.add_argument('--src_mouth_frame_path', default=None, 40 | help="mouth frame file position, the video frames synced with audios") 41 | parser.add_argument('--src_input_path', default='./misc/Input/00098.mp4', 42 | help="input file position, it could be a folder with frames, a jpg or an mp4") 43 | parser.add_argument('--csv_path', default='./misc/demo2.csv', 44 | help="path to output index files") 45 | parser.add_argument('--convert_spectrogram', action='store_true', help='whether to convert audio to spectrogram') 46 | 47 | args = parser.parse_args() 48 | dir_path = args.dir_path 49 | mkdir(dir_path) 50 | 51 | # ===================== process input ======================================================= 52 | input_save_path = os.path.join(dir_path, 'Input') 53 | mkdir(input_save_path) 54 | input_name = args.src_input_path.split('/')[-1].split('.')[0] 55 | num_inputs = 1 56 | dst_input_path = os.path.join(input_save_path, input_name) 57 | mkdir(dst_input_path) 58 | if args.src_input_path.split('/')[-1].split('.')[-1] == 'mp4': 59 | num_inputs = proc_frames(args.src_input_path, dst_input_path) 60 | elif os.path.isdir(args.src_input_path): 61 | dst_input_path = args.src_input_path 62 | else: 63 | os.system('cp {} {}'.format(args.src_input_path, os.path.join(dst_input_path, args.src_input_path.split('/')[-1]))) 64 | 65 | 66 | # ===================== process audio ======================================================= 67 | audio_source_save_path = os.path.join(dir_path, 'Audio_Source') 68 | mkdir(audio_source_save_path) 69 | audio_name = args.src_audio_path.split('/')[-1].split('.')[0] 70 | spec_dir = 'None' 71 | dst_audio_path = os.path.join(audio_source_save_path, audio_name + '.mp3') 72 | 73 | if args.src_audio_path.split('/')[-1].split('.')[-1] == 'mp3': 74 | os.system('cp {} {}'.format(args.src_audio_path, dst_audio_path)) 75 | if args.src_mouth_frame_path and 
os.path.isdir(args.src_mouth_frame_path): 76 | dst_mouth_frame_path = args.src_mouth_frame_path 77 | num_mouth_frames = len(glob.glob(os.path.join(args.src_mouth_frame_path, '*.jpg')) + glob.glob(os.path.join(args.src_mouth_frame_path, '*.png'))) 78 | else: 79 | dst_mouth_frame_path = 'None' 80 | num_mouth_frames = 0 81 | else: 82 | mouth_source_save_path = os.path.join(dir_path, 'Mouth_Source') 83 | mkdir(mouth_source_save_path) 84 | dst_mouth_frame_path = os.path.join(mouth_source_save_path, audio_name) 85 | mkdir(dst_mouth_frame_path) 86 | proc_audio(args.src_audio_path, dst_audio_path) 87 | num_mouth_frames = proc_frames(args.src_audio_path, dst_mouth_frame_path) 88 | 89 | if args.convert_spectrogram: 90 | audio = AudioConfig(fft_size=1280, hop_size=160) 91 | wav = audio.read_audio(dst_audio_path) 92 | spectrogram = audio.audio_to_spectrogram(wav) 93 | spec_dir = os.path.join(audio_source_save_path, audio_name + '.npy') 94 | np.save(spec_dir, 95 | spectrogram.astype(np.float32), allow_pickle=False) 96 | 97 | # ===================== process pose ======================================================= 98 | if os.path.isdir(args.src_pose_path): 99 | num_pose_frames = len(glob.glob(os.path.join(args.src_pose_path, '*.jpg')) + glob.glob(os.path.join(args.src_pose_path, '*.png'))) 100 | dst_pose_frame_path = args.src_pose_path 101 | else: 102 | pose_source_save_path = os.path.join(dir_path, 'Pose_Source') 103 | mkdir(pose_source_save_path) 104 | pose_name = args.src_pose_path.split('/')[-1].split('.')[0] 105 | dst_pose_frame_path = os.path.join(pose_source_save_path, pose_name) 106 | mkdir(dst_pose_frame_path) 107 | num_pose_frames = proc_frames(args.src_pose_path, dst_pose_frame_path) 108 | 109 | # ===================== form csv ======================================================= 110 | 111 | with open(args.csv_path, 'w', newline='') as csvfile: 112 | writer = csv.writer(csvfile, delimiter=' ', quoting=csv.QUOTE_MINIMAL) 113 | writer.writerows([[dst_input_path, 
str(num_inputs), dst_pose_frame_path, str(num_pose_frames), 114 | dst_audio_path, dst_mouth_frame_path, str(num_mouth_frames), spec_dir]]) 115 | print('meta-info saved at ' + args.csv_path) 116 | 117 | csvfile.close() -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import dominate 3 | from dominate.tags import * 4 | import os 5 | 6 | 7 | class HTML: 8 | def __init__(self, web_dir, title, refresh=0): 9 | if web_dir.endswith('.html'): 10 | web_dir, html_name = os.path.split(web_dir) 11 | else: 12 | web_dir, html_name = web_dir, 'index.html' 13 | self.title = title 14 | self.web_dir = web_dir 15 | self.html_name = html_name 16 | self.img_dir = os.path.join(self.web_dir, 'images') 17 | if len(self.web_dir) > 0 and not os.path.exists(self.web_dir): 18 | os.makedirs(self.web_dir) 19 | if len(self.web_dir) > 0 and not os.path.exists(self.img_dir): 20 | os.makedirs(self.img_dir) 21 | 22 | self.doc = dominate.document(title=title) 23 | with self.doc: 24 | h1(datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")) 25 | if refresh > 0: 26 | with self.doc.head: 27 | meta(http_equiv="refresh", content=str(refresh)) 28 | 29 | def get_image_dir(self): 30 | return self.img_dir 31 | 32 | def add_header(self, str): 33 | with self.doc: 34 | h3(str) 35 | 36 | def add_table(self, border=1): 37 | self.t = table(border=border, style="table-layout: fixed;") 38 | self.doc.add(self.t) 39 | 40 | def add_images(self, ims, txts, links, width=512): 41 | self.add_table() 42 | with self.t: 43 | with tr(): 44 | for im, txt, link in zip(ims, txts, links): 45 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 46 | with p(): 47 | with 
a(href=os.path.join('images', link)): 48 | img(style="width:%dpx" % (width), src=os.path.join('images', im)) 49 | br() 50 | p(txt.encode('utf-8')) 51 | 52 | def save(self): 53 | html_file = os.path.join(self.web_dir, self.html_name) 54 | f = open(html_file, 'wt') 55 | f.write(self.doc.render()) 56 | f.close() 57 | 58 | 59 | if __name__ == '__main__': 60 | html = HTML('web/', 'test_html') 61 | html.add_header('hello world') 62 | 63 | ims = [] 64 | txts = [] 65 | links = [] 66 | for n in range(4): 67 | ims.append('image_%d.jpg' % n) 68 | txts.append('text_%d' % n) 69 | links.append('image_%d.jpg' % n) 70 | html.add_images(ims, txts, links) 71 | html.save() 72 | -------------------------------------------------------------------------------- /util/iter_counter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | 5 | 6 | # Helper class that keeps track of training iterations 7 | class IterationCounter(): 8 | def __init__(self, opt, dataset_size): 9 | self.opt = opt 10 | self.dataset_size = dataset_size 11 | 12 | self.first_epoch = 1 13 | self.total_epochs = opt.niter + opt.niter_decay if opt.isTrain else 1 14 | self.epoch_iter = 0 # iter number within each epoch 15 | self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, 'iter.txt') 16 | if opt.isTrain and opt.continue_train: 17 | try: 18 | self.first_epoch, self.epoch_iter = np.loadtxt( 19 | self.iter_record_path, delimiter=',', dtype=int) 20 | print('Resuming from epoch %d at iteration %d' % (self.first_epoch, self.epoch_iter)) 21 | except: 22 | print('Could not load iteration record at %s. Starting from beginning.' 
% 23 | self.iter_record_path) 24 | 25 | self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter 26 | 27 | # return the iterator of epochs for the training 28 | def training_epochs(self): 29 | return range(self.first_epoch, self.total_epochs + 1) 30 | 31 | def record_epoch_start(self, epoch): 32 | self.epoch_start_time = time.time() 33 | self.epoch_iter = 0 34 | self.last_iter_time = time.time() 35 | self.current_epoch = epoch 36 | 37 | def record_one_iteration(self): 38 | current_time = time.time() 39 | 40 | # the last remaining batch is dropped (see data/__init__.py), 41 | # so we can assume batch size is always opt.batchSize 42 | self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize 43 | self.last_iter_time = current_time 44 | self.total_steps_so_far += self.opt.batchSize 45 | self.epoch_iter += self.opt.batchSize 46 | 47 | def record_epoch_end(self): 48 | current_time = time.time() 49 | self.time_per_epoch = current_time - self.epoch_start_time 50 | print('End of epoch %d / %d \t Time Taken: %d sec' % 51 | (self.current_epoch, self.total_epochs, self.time_per_epoch)) 52 | if self.current_epoch % self.opt.save_epoch_freq == 0: 53 | np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), 54 | delimiter=',', fmt='%d') 55 | print('Saved current iteration count at %s.' % self.iter_record_path) 56 | 57 | def record_current_iter(self): 58 | np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), 59 | delimiter=',', fmt='%d') 60 | print('Saved current iteration count at %s.' 
% self.iter_record_path) 61 | 62 | def needs_saving(self): 63 | return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize 64 | 65 | def needs_printing(self): 66 | return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize 67 | 68 | def needs_displaying(self): 69 | return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize 70 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | import re 2 | import importlib 3 | import torch 4 | from argparse import Namespace 5 | import numpy as np 6 | from PIL import Image 7 | import os 8 | import argparse 9 | import dill as pickle 10 | import skimage.transform as trans 11 | import cv2 12 | 13 | 14 | def save_obj(obj, name): 15 | with open(name, 'wb') as f: 16 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 17 | 18 | 19 | def load_obj(name): 20 | with open(name, 'rb') as f: 21 | return pickle.load(f) 22 | 23 | # returns a configuration for creating a generator 24 | # |default_opt| should be the opt of the current experiment 25 | # |**kwargs|: if any configuration should be overriden, it can be specified here 26 | 27 | 28 | def copyconf(default_opt, **kwargs): 29 | conf = argparse.Namespace(**vars(default_opt)) 30 | for key in kwargs: 31 | print(key, kwargs[key]) 32 | setattr(conf, key, kwargs[key]) 33 | return conf 34 | 35 | 36 | def tile_images(imgs, picturesPerRow=4): 37 | """ Code borrowed from 38 | https://stackoverflow.com/questions/26521365/cleanly-tile-numpy-array-of-images-stored-in-a-flattened-1d-format/26521997 39 | """ 40 | 41 | # Padding 42 | if imgs.shape[0] % picturesPerRow == 0: 43 | rowPadding = 0 44 | else: 45 | rowPadding = picturesPerRow - imgs.shape[0] % picturesPerRow 46 | if rowPadding > 0: 47 | imgs = np.concatenate([imgs, np.zeros((rowPadding, *imgs.shape[1:]), dtype=imgs.dtype)], axis=0) 48 | 49 | # Tiling Loop (The 
conditionals are not necessary anymore) 50 | tiled = [] 51 | for i in range(0, imgs.shape[0], picturesPerRow): 52 | tiled.append(np.concatenate([imgs[j] for j in range(i, i + picturesPerRow)], axis=1)) 53 | 54 | tiled = np.concatenate(tiled, axis=0) 55 | return tiled 56 | 57 | 58 | # Converts a Tensor into a Numpy array 59 | # |imtype|: the desired type of the converted numpy array 60 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=True): 61 | if isinstance(image_tensor, list): 62 | image_numpy = [] 63 | for i in range(len(image_tensor)): 64 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 65 | return image_numpy 66 | 67 | if image_tensor.dim() == 4: 68 | # transform each image in the batch 69 | images_np = [] 70 | for b in range(image_tensor.size(0)): 71 | one_image = image_tensor[b] 72 | one_image_np = tensor2im(one_image) 73 | images_np.append(one_image_np.reshape(1, *one_image_np.shape)) 74 | images_np = np.concatenate(images_np, axis=0) 75 | if tile: 76 | images_tiled = tile_images(images_np) 77 | return images_tiled 78 | else: 79 | if len(images_np.shape) == 4 and images_np.shape[0] == 1: 80 | images_np = images_np[0] 81 | return images_np 82 | 83 | if image_tensor.dim() == 2: 84 | image_tensor = image_tensor.unsqueeze(0) 85 | image_numpy = image_tensor.detach().cpu().float().numpy() 86 | if normalize: 87 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 88 | else: 89 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 90 | image_numpy = np.clip(image_numpy, 0, 255) 91 | if image_numpy.shape[2] == 1: 92 | image_numpy = image_numpy[:, :, 0] 93 | return image_numpy.astype(imtype) 94 | 95 | 96 | 97 | def save_image(image_numpy, image_path, create_dir=False): 98 | if create_dir: 99 | os.makedirs(os.path.dirname(image_path), exist_ok=True) 100 | if len(image_numpy.shape) == 4: 101 | image_numpy = image_numpy[0] 102 | if len(image_numpy.shape) == 2: 103 | image_numpy = 
np.expand_dims(image_numpy, axis=2) 104 | if image_numpy.shape[2] == 1: 105 | image_numpy = np.repeat(image_numpy, 3, 2) 106 | image_pil = Image.fromarray(image_numpy) 107 | 108 | # save to png 109 | image_pil.save(image_path) 110 | # image_pil.save(image_path.replace('.jpg', '.png')) 111 | 112 | 113 | def save_torch_img(img, save_path): 114 | image_numpy = tensor2im(img,tile=False) 115 | save_image(image_numpy, save_path, create_dir=True) 116 | return image_numpy 117 | 118 | 119 | 120 | def mkdirs(paths): 121 | if isinstance(paths, list) and not isinstance(paths, str): 122 | for path in paths: 123 | mkdir(path) 124 | else: 125 | mkdir(paths) 126 | 127 | 128 | def mkdir(path): 129 | if not os.path.exists(path): 130 | os.makedirs(path) 131 | 132 | 133 | def atoi(text): 134 | return int(text) if text.isdigit() else text 135 | 136 | 137 | def natural_keys(text): 138 | ''' 139 | alist.sort(key=natural_keys) sorts in human order 140 | http://nedbatchelder.com/blog/200712/human_sorting.html 141 | (See Toothy's implementation in the comments) 142 | ''' 143 | return [atoi(c) for c in re.split('(\d+)', text)] 144 | 145 | 146 | def natural_sort(items): 147 | items.sort(key=natural_keys) 148 | 149 | 150 | def str2bool(v): 151 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 152 | return True 153 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 154 | return False 155 | else: 156 | raise argparse.ArgumentTypeError('Boolean value expected.') 157 | 158 | 159 | def find_class_in_module(target_cls_name, module): 160 | target_cls_name = target_cls_name.replace('_', '').lower() 161 | clslib = importlib.import_module(module) 162 | cls = None 163 | for name, clsobj in clslib.__dict__.items(): 164 | if name.lower() == target_cls_name: 165 | cls = clsobj 166 | 167 | if cls is None: 168 | print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)) 169 | exit(0) 170 | 171 | return cls 172 | 173 | 174 | def 
save_network(net, label, epoch, opt): 175 | save_filename = '%s_net_%s.pth' % (epoch, label) 176 | save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename) 177 | torch.save(net.cpu().state_dict(), save_path) 178 | if len(opt.gpu_ids) and torch.cuda.is_available(): 179 | net.cuda() 180 | 181 | 182 | def load_network(net, label, epoch, opt): 183 | save_filename = '%s_net_%s.pth' % (epoch, label) 184 | save_dir = os.path.join(opt.checkpoints_dir, opt.name) 185 | save_path = os.path.join(save_dir, save_filename) 186 | weights = torch.load(save_path) 187 | net.load_state_dict(weights) 188 | return net 189 | 190 | 191 | def copy_state_dict(state_dict, model, strip=None, replace=None): 192 | tgt_state = model.state_dict() 193 | copied_names = set() 194 | for name, param in state_dict.items(): 195 | if strip is not None and replace is None and name.startswith(strip): 196 | name = name[len(strip):] 197 | if strip is not None and replace is not None: 198 | name = name.replace(strip, replace) 199 | if name not in tgt_state: 200 | continue 201 | if isinstance(param, torch.nn.Parameter): 202 | param = param.data 203 | if param.size() != tgt_state[name].size(): 204 | print('mismatch:', name, param.size(), tgt_state[name].size()) 205 | continue 206 | tgt_state[name].copy_(param) 207 | copied_names.add(name) 208 | 209 | missing = set(tgt_state.keys()) - copied_names 210 | if len(missing) > 0: 211 | print("missing keys in state_dict:", missing) 212 | 213 | 214 | 215 | def freeze_model(net): 216 | for param in net.parameters(): 217 | param.requires_grad = False 218 | ############################################################################### 219 | # Code from 220 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py 221 | # Modified so it complies with the Citscape label map colors 222 | ############################################################################### 223 | def uint82bin(n, count=8): 224 | """returns the binary of integer n, count refers 
to amount of bits""" 225 | return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)]) 226 | 227 | def build_landmark_dict(ldmk_path): 228 | with open(ldmk_path) as f: 229 | lines = f.readlines() 230 | ldmk_dict = {} 231 | paths = [] 232 | for line in lines: 233 | info = line.strip().split() 234 | key = info[-1] 235 | if "/" in key: 236 | key = key.split("/")[-1] 237 | # key = int(key.split(".")[0]) 238 | value = info[:-1] 239 | paths.append(key) 240 | value = [float(it) for it in value] 241 | if len(info) == 106 * 2 + 1: # landmark+name 242 | value = [float(it) for it in info[:106 * 2]] 243 | elif len(info) == 106 * 2 + 1 + 6: # affmat+landmark+name 244 | value = [float(it) for it in info[6:106 * 2 + 6]] 245 | elif len(info) == 20 * 2 + 2: # mouth landmark+name 246 | value = [float(it) for it in info[:-1]] 247 | ldmk_dict[key] = value 248 | return ldmk_dict, paths 249 | 250 | 251 | def get_affine(src, dst): 252 | tform = trans.SimilarityTransform() 253 | tform.estimate(src, dst) 254 | M = tform.params[0:2, :] 255 | return M 256 | 257 | def affine_align_img(img, M, crop_size=224): 258 | warped = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0) 259 | return warped 260 | 261 | def calc_loop_idx(idx, loop_num): 262 | flag = -1 * ((idx // loop_num % 2) * 2 - 1) 263 | new_idx = -flag * (flag - 1) // 2 + flag * (idx % loop_num) 264 | return (new_idx + loop_num) % loop_num 265 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ntpath 3 | import time 4 | from . import util 5 | from . 
import html 6 | import scipy.misc 7 | import torch 8 | import torchvision.utils as vutils 9 | from torch.utils.tensorboard import SummaryWriter 10 | try: 11 | from StringIO import StringIO # Python 2.7 12 | except ImportError: 13 | from io import BytesIO # Python 3.x 14 | 15 | class Visualizer(): 16 | def __init__(self, opt): 17 | self.opt = opt 18 | self.tf_log = opt.isTrain and opt.tf_log 19 | self.tensorboard = opt.isTrain and opt.tensorboard 20 | self.use_html = opt.isTrain and not opt.no_html 21 | self.win_size = opt.display_winsize 22 | self.name = opt.name 23 | if self.tf_log: 24 | import tensorflow as tf 25 | self.tf = tf 26 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') 27 | self.writer = tf.summary.FileWriter(self.log_dir) 28 | 29 | if self.tensorboard: 30 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') 31 | self.writer = SummaryWriter(self.log_dir, comment=opt.name) 32 | 33 | if self.use_html: 34 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 35 | self.img_dir = os.path.join(self.web_dir, 'images') 36 | print('create web directory %s...' 
% self.web_dir) 37 | util.mkdirs([self.web_dir, self.img_dir]) 38 | if opt.isTrain: 39 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 40 | with open(self.log_name, "a") as log_file: 41 | now = time.strftime("%c") 42 | log_file.write('================ Training Loss (%s) ================\n' % now) 43 | 44 | # |visuals|: dictionary of images to display or save 45 | def display_current_results(self, visuals, epoch, step): 46 | 47 | ## convert tensors to numpy arrays 48 | 49 | 50 | if self.tf_log: # show images in tensorboard output 51 | img_summaries = [] 52 | visuals = self.convert_visuals_to_numpy(visuals) 53 | for label, image_numpy in visuals.items(): 54 | # Write the image to a string 55 | try: 56 | s = StringIO() 57 | except: 58 | s = BytesIO() 59 | if len(image_numpy.shape) >= 4: 60 | image_numpy = image_numpy[0] 61 | scipy.misc.toimage(image_numpy).save(s, format="jpeg") 62 | # Create an Image object 63 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) 64 | # Create a Summary value 65 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) 66 | 67 | # Create and write Summary 68 | summary = self.tf.Summary(value=img_summaries) 69 | self.writer.add_summary(summary, step) 70 | 71 | if self.tensorboard: # show images in tensorboard output 72 | img_summaries = [] 73 | for label, image_numpy in visuals.items(): 74 | # Write the image to a string 75 | try: 76 | s = StringIO() 77 | except: 78 | s = BytesIO() 79 | # if len(image_numpy.shape) >= 4: 80 | # image_numpy = image_numpy[0] 81 | # scipy.misc.toimage(image_numpy).save(s, format="jpeg") 82 | # Create an Image object 83 | # self.writer.add_image(tag=label, img_tensor=image_numpy, global_step=step, dataformats='HWC') 84 | # Create a Summary value 85 | batch_size = image_numpy.size(0) 86 | x = vutils.make_grid(image_numpy[:min(batch_size, 16)], normalize=True, scale_each=True) 87 | 
self.writer.add_image(label, x, step) 88 | 89 | 90 | if self.use_html: # save images to a html file 91 | for label, image_numpy in visuals.items(): 92 | if isinstance(image_numpy, list): 93 | for i in range(len(image_numpy)): 94 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s_%d.png' % (epoch, step, label, i)) 95 | util.save_image(image_numpy[i], img_path) 96 | else: 97 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s.png' % (epoch, step, label)) 98 | if len(image_numpy.shape) >= 4: 99 | image_numpy = image_numpy[0] 100 | util.save_image(image_numpy, img_path) 101 | 102 | # update website 103 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=5) 104 | for n in range(epoch, 0, -1): 105 | webpage.add_header('epoch [%d]' % n) 106 | ims = [] 107 | txts = [] 108 | links = [] 109 | 110 | for label, image_numpy in visuals.items(): 111 | if isinstance(image_numpy, list): 112 | for i in range(len(image_numpy)): 113 | img_path = 'epoch%.3d_iter%.3d_%s_%d.png' % (n, step, label, i) 114 | ims.append(img_path) 115 | txts.append(label+str(i)) 116 | links.append(img_path) 117 | else: 118 | img_path = 'epoch%.3d_iter%.3d_%s.png' % (n, step, label) 119 | ims.append(img_path) 120 | txts.append(label) 121 | links.append(img_path) 122 | if len(ims) < 10: 123 | webpage.add_images(ims, txts, links, width=self.win_size) 124 | else: 125 | num = int(round(len(ims)/2.0)) 126 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) 127 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) 128 | webpage.save() 129 | 130 | # errors: dictionary of error labels and values 131 | def plot_current_errors(self, errors, step): 132 | if self.tf_log: 133 | for tag, value in errors.items(): 134 | value = value.mean().float() 135 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) 136 | self.writer.add_summary(summary, step) 137 | 138 | if self.tensorboard: 139 | for 
tag, value in errors.items(): 140 | value = value.mean().float() 141 | self.writer.add_scalar(tag=tag, scalar_value=value, global_step=step) 142 | 143 | # errors: same format as |errors| of plotCurrentErrors 144 | def print_current_errors(self, opt, epoch, i, errors, t): 145 | message = opt.name + ' (epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) 146 | for k, v in errors.items(): 147 | #print(v) 148 | #if v != 0: 149 | v = v.mean().float() 150 | message += '%s: %.3f ' % (k, v) 151 | 152 | print(message) 153 | with open(self.log_name, "a") as log_file: 154 | log_file.write('%s\n' % message) 155 | 156 | def convert_visuals_to_numpy(self, visuals): 157 | for key, t in visuals.items(): 158 | tile = self.opt.batchSize > 8 159 | if 'input_label' == key: 160 | t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile) 161 | else: 162 | t = util.tensor2im(t, tile=tile) 163 | visuals[key] = t 164 | return visuals 165 | 166 | # save image to the disk 167 | def save_images(self, webpage, visuals, image_path): 168 | visuals = self.convert_visuals_to_numpy(visuals) 169 | 170 | image_dir = webpage.get_image_dir() 171 | short_path = ntpath.basename(image_path[0]) 172 | name = os.path.splitext(short_path)[0] 173 | 174 | webpage.add_header(name) 175 | ims = [] 176 | txts = [] 177 | links = [] 178 | 179 | for label, image_numpy in visuals.items(): 180 | image_name = os.path.join(label, '%s.png' % (name)) 181 | save_path = os.path.join(image_dir, image_name) 182 | util.save_image(image_numpy, save_path, create_dir=True) 183 | 184 | ims.append(image_name) 185 | txts.append(label) 186 | links.append(image_name) 187 | webpage.add_images(ims, txts, links, width=self.win_size) 188 | --------------------------------------------------------------------------------