├── .gitignore ├── .idea └── .gitignore ├── LICENSE ├── README.md ├── bin ├── f_sync_score.py ├── gfpgan_trans.py ├── inference.py ├── inference.sh ├── multi_process.py ├── parse_options.sh ├── processData.py ├── process_data.sh ├── score.sh ├── score_video.py ├── train_syncnet.sh └── train_wav2lip.sh ├── configs ├── train_config.yaml └── train_config_96.yaml ├── data ├── original_data │ └── README.md ├── preProcessed_data │ └── README.md ├── pretrain_model │ └── README ├── syncnet_checkpoint │ └── README.md └── test_data │ ├── input │ ├── .gitignore │ └── test.jpg │ └── pr_data │ └── README.md ├── inference_util ├── FaceParseModel.py ├── InferenceUtil.py ├── README.md ├── __init__.py ├── resnet.py └── swap.py ├── models ├── BaseConv2D.py ├── BaseTranspose.py ├── Discriminator.py ├── FaceCreator.py ├── NoNormConv.py ├── README.md ├── SyncNetModel.py ├── SyncNetModel_P.py └── __init__.py ├── process_util ├── DataProcessor.py ├── FaceDetector.py ├── ParamsUtil.py ├── PreProcessor.py ├── SyncnetPScore.py ├── SyncnetScore.py ├── VideoExtractor.py └── __init__.py ├── requirement.txt ├── testcase ├── dataProcessTest.py ├── discriminatorTest.py ├── faceEncodeTest.py ├── faceTest.py ├── melspecTest.py ├── paramsUtilsTest.py ├── processorTest.py ├── syncnetTrainTest.py ├── syncnetmodelTest.py ├── testFaceDataset.py ├── testSyncDataset.py ├── videoExtractorTest.py └── w2ltrainTest.py ├── trains ├── README.md ├── syncnet_train.py └── wl_train.py └── wldatasets ├── FaceDataset.py ├── SyncNetDataset.py └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | /data/test_data/input/ 163 | /data/test_data/pr_data/cctvm0000001/ 164 | /data/test_data/pr_data/eval.txt 165 | /data/test_data/pr_data/train.txt 166 | /data/test_data/output/ 167 | /data/pretrain_model/GFPGANv1.3.pth 168 | /data/pretrain_model/RealESRGAN_x2plus.pth 169 | /data/pretrain_model/RestoreFormer.pth 170 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # 默认忽略的文件 2 | /shelf/ 3 | /workspace.xml 4 | # 基于编辑器的 HTTP 客户端请求 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wav2lipup 2 | 3 | An optimized Wav2Lip implementation. 4 | sentences 5 | start end text -------------------------------------------------------------------------------- /bin/f_sync_score.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | 5 | def process_file_list(data_root): 6 | filelist = [] 7 | for file in Path(data_root).glob('*/*.mp4'): 8 | if file.is_file(): 9 | filelist.append(str(file)) 10 | with open(data_root + '/' + 'process.txt', 'w') as f: 11 | f.write('\n'.join(filelist)) 12 | return filelist 13 | 14 | 15 | def continues_file_list(data_root): 16 | full_txt = data_root + '/' + 'process.txt' 17 | full_list = get_list(full_txt) 18 | processed_txt = data_root + '/' + 'processed.txt' 19 | processed_list = get_list(processed_txt) 20 | 21 | p_list = clear_pv(full_list, processed_list) 22 | 23 | return p_list 24 | 25 | 26 | def get_list(inputText): 27 | list = [] 28 | with open(inputText, 'r') as f: 29 | for line in f: 30 | line = line.strip() 31 | list.append(line) 32 | return list 33 | 34 | 35 | def clear_pv(all_list, exclude_list): 36 | for item in exclude_list: 37 | if item in all_list: 38 | all_list.remove(item) 39 | 40 | return all_list 41 | 42 | 43 | def main(): 44 | args = parse_args() 45 | data_root = args.data_root 46 | break_p = args.break_point 47 | 48 | # write the list of videos to be processed into a text file 49 | if break_p == 0: 50 | process_list = process_file_list(data_root) 51 | else: 52 | process_list = continues_file_list(data_root) 53 | 54 | 55 | def parse_args(): 56 | # parse args and config 57 | parser = argparse.ArgumentParser( 58 | description="score for the video") 59 | parser.add_argument("--data_root", help='Root folder of the dataset', required=True, type=str) 60 | parser.add_argument('--init_model', help='Path to the pre-trained SyncNet model', required=True, 61 | default='../data/pre-models/syncnet_v2.model') 62 | parser.add_argument('--num_worker', help='multiprocessor number', default=6, type=int) 63 | parser.add_argument('--batch_size', help='produce img batch', default=20, type=int) 64 | parser.add_argument('--break_point', help='resume scoring from the break point', default=0, type=int) 65 | args = parser.parse_args() 66 | 67 | return args 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /bin/gfpgan_trans.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from basicsr.utils import imwrite 9 | from gfpgan.utils import GFPGANer 10 | 11 | 12 | def main(): 13 | """Inference demo for GFPGAN (for users). 14 | """ 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | '-i', 18 | '--input', 19 | type=str, 20 | default='inputs/whole_imgs', 21 | help='Input image or folder. 
Default: inputs/whole_imgs') 22 | parser.add_argument('-o', '--output', type=str, default='results', help='Output folder. Default: results') 23 | # we use version to select models, which is more user-friendly 24 | parser.add_argument( 25 | '-v', '--version', type=str, default='1.3', help='GFPGAN model version. Option: 1 | 1.2 | 1.3. Default: 1.3') 26 | parser.add_argument( 27 | '-s', '--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2') 28 | 29 | parser.add_argument( 30 | '--bg_upsampler', type=str, default='realesrgan', help='background upsampler. Default: realesrgan') 31 | parser.add_argument( 32 | '--bg_tile', 33 | type=int, 34 | default=400, 35 | help='Tile size for background sampler, 0 for no tile during testing. Default: 400') 36 | parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces') 37 | parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face') 38 | parser.add_argument('--aligned', action='store_true', help='Input are aligned faces') 39 | parser.add_argument( 40 | '--ext', 41 | type=str, 42 | default='auto', 43 | help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto') 44 | parser.add_argument('-w', '--weight', type=float, default=0.5, help='Adjustable weights.') 45 | args = parser.parse_args() 46 | 47 | args = parser.parse_args() 48 | 49 | # ------------------------ input & output ------------------------ 50 | if args.input.endswith('/'): 51 | args.input = args.input[:-1] 52 | if os.path.isfile(args.input): 53 | img_list = [args.input] 54 | else: 55 | img_list = sorted(glob.glob(os.path.join(args.input, '*'))) 56 | 57 | os.makedirs(args.output, exist_ok=True) 58 | 59 | # ------------------------ set up background upsampler ------------------------ 60 | if args.bg_upsampler == 'realesrgan': 61 | if not torch.cuda.is_available(): # CPU 62 | import warnings 63 | warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. 
' 64 | 'If you really want to use it, please modify the corresponding codes.') 65 | bg_upsampler = None 66 | else: 67 | from basicsr.archs.rrdbnet_arch import RRDBNet 68 | from realesrgan import RealESRGANer 69 | model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) 70 | bg_upsampler = RealESRGANer( 71 | scale=2, 72 | model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', 73 | model=model, 74 | tile=args.bg_tile, 75 | tile_pad=10, 76 | pre_pad=0, 77 | half=True) # need to set False in CPU mode 78 | else: 79 | bg_upsampler = None 80 | 81 | # ------------------------ set up GFPGAN restorer ------------------------ 82 | if args.version == '1': 83 | arch = 'original' 84 | channel_multiplier = 1 85 | model_name = 'GFPGANv1' 86 | url = 'https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth' 87 | elif args.version == '1.2': 88 | arch = 'clean' 89 | channel_multiplier = 2 90 | model_name = 'GFPGANCleanv1-NoCE-C2' 91 | url = 'https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth' 92 | elif args.version == '1.3': 93 | arch = 'clean' 94 | channel_multiplier = 2 95 | model_name = 'GFPGANv1.3' 96 | url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth' 97 | elif args.version == '1.4': 98 | arch = 'clean' 99 | channel_multiplier = 2 100 | model_name = 'GFPGANv1.4' 101 | url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth' 102 | elif args.version == 'RestoreFormer': 103 | arch = 'RestoreFormer' 104 | channel_multiplier = 2 105 | model_name = 'RestoreFormer' 106 | url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth' 107 | else: 108 | raise ValueError(f'Wrong model version {args.version}.') 109 | 110 | # determine model paths 111 | model_path = os.path.join('experiments/pretrained_models', model_name + '.pth') 112 | if not os.path.isfile(model_path): 113 | model_path = os.path.join('gfpgan/weights', model_name + '.pth') 114 | if not os.path.isfile(model_path): 115 | # download pre-trained models from url 116 | model_path = url 117 | 118 | restorer = GFPGANer( 119 | model_path=model_path, 120 | upscale=args.upscale, 121 | arch=arch, 122 | channel_multiplier=channel_multiplier, 123 | bg_upsampler=bg_upsampler) 124 | 125 | # ------------------------ restore ------------------------ 126 | for img_path in img_list: 127 | # read image 128 | img_name = os.path.basename(img_path) 129 | print(f'Processing {img_name} ...') 130 | basename, ext = os.path.splitext(img_name) 131 | input_img = cv2.imread(img_path, cv2.IMREAD_COLOR) 132 | 133 | # restore faces and background if necessary 134 | cropped_faces, restored_faces, restored_img = restorer.enhance( 135 | input_img, 136 | has_aligned=args.aligned, 137 | only_center_face=args.only_center_face, 138 | paste_back=True, 139 | weight=args.weight) 140 | 141 | # save faces 142 | for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)): 143 | # save cropped face 144 | save_crop_path = os.path.join(args.output, 'cropped_faces', f'{basename}_{idx:02d}.png') 145 | imwrite(cropped_face, save_crop_path) 146 | # save restored face 147 | if args.suffix is not None: 148 | save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png' 149 | else: 150 | save_face_name = f'{basename}_{idx:02d}.png' 151 | save_restore_path = os.path.join(args.output, 'restored_faces', save_face_name) 152 | imwrite(restored_face, 
save_restore_path) 153 | # save comparison image 154 | cmp_img = np.concatenate((cropped_face, restored_face), axis=1) 155 | imwrite(cmp_img, os.path.join(args.output, 'cmp', f'{basename}_{idx:02d}.png')) 156 | 157 | # save restored img 158 | if restored_img is not None: 159 | if args.ext == 'auto': 160 | extension = ext[1:] 161 | else: 162 | extension = args.ext 163 | 164 | if args.suffix is not None: 165 | save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}_{args.suffix}.{extension}') 166 | else: 167 | save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}.{extension}') 168 | imwrite(restored_img, save_restore_path) 169 | 170 | print(f'Results are in the [{args.output}] folder.') 171 | 172 | 173 | if __name__ == '__main__': 174 | main() 175 | -------------------------------------------------------------------------------- /bin/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | from pathlib import Path 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | from basicsr import imwrite 10 | from basicsr.archs.rrdbnet_arch import RRDBNet 11 | from gfpgan.utils import GFPGANer 12 | from realesrgan import RealESRGANer 13 | from torch import nn 14 | from tqdm import tqdm 15 | 16 | from inference_util import init_parser, swap_regions 17 | from inference_util.InferenceUtil import InferenceUtil 18 | from models.FaceCreator import FaceCreator 19 | from process_util.ParamsUtil import ParamsUtil 20 | 21 | hp = ParamsUtil() 22 | mel_step_size = 16 23 | pretrain_model_url = { 24 | 'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth', 25 | } 26 | 27 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 28 | 29 | 30 | class MyDataParallel(nn.DataParallel): 31 | def __getattr__(self, name): 32 | try: 33 | return super().__getattr__(name) 34 | except AttributeError: 35 | return getattr(self.module, name) 36 | 37 | 38 | def parse_args(): 39 | # parse args and config 40 | parser = argparse.ArgumentParser(description="inference video") 41 | parser.add_argument('--checkpoint_path', type=str, help='Name of saved checkpoint to load weights from', 42 | required=True) 43 | parser.add_argument('--gfpgan_checkpoint', type=str, help='Name of saved checkpoint to load weights from', 44 | required=True) 45 | parser.add_argument('--esrgan_checkpoint', type=str, help='Name of saved checkpoint to load weights from', 46 | required=True) 47 | parser.add_argument('--segmentation_path', type=str, 48 | help='Name of saved checkpoint of segmentation network', required=True) 49 | parser.add_argument('--face', type=str, help='Filepath of video/image that contains faces to use', 50 | required=True) 51 | parser.add_argument('--audio', type=str, 52 | help='Filepath of video/audio file to use as raw audio source', required=True) 53 | parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.', 54 | default='results/result_voice.mp4') 55 | parser.add_argument('--static', type=bool, 56 | help='If True, then use only first video frame for inference', default=False) 57 | parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)', 58 | default=25., required=False) 59 | parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0], 60 | help='Padding (top, bottom, left, right). 
Please adjust to include chin at least') 61 | 62 | parser.add_argument('--face_det_batch_size', type=int, 63 | help='Batch size for face detection', default=64) 64 | parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=64) 65 | parser.add_argument('--resize_factor', default=1, type=int, 66 | help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p') 67 | parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1], 68 | help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. ' 69 | 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width') 70 | 71 | parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1], 72 | help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.' 73 | 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).') 74 | parser.add_argument('--rotate', default=False, action='store_true', 75 | help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.' 76 | 'Use if you get a flipped result, despite feeding a normal looking video') 77 | parser.add_argument('--nosmooth', default=False, action='store_true', 78 | help='Prevent smoothing face detections over a short temporal window') 79 | parser.add_argument('--no_segmentation', default=False, action='store_true', 80 | help='Prevent using face segmentation') 81 | parser.add_argument( 82 | '-s', '--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2') 83 | parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face') 84 | parser.add_argument('--aligned', action='store_true', help='Input are aligned faces') 85 | parser.add_argument( 86 | '--bg_upsampler', type=str, default='realesrgan', help='background upsampler. Default: realesrgan') 87 | parser.add_argument( 88 | '--bg_tile', 89 | type=int, 90 | default=400, 91 | help='Tile size for background sampler, 0 for no tile during testing. 
Default: 400') 92 | parser.add_argument('-w', '--weight', type=float, default=0.5, help='Adjustable weights.') 93 | 94 | return parser.parse_args() 95 | 96 | 97 | def _load(checkpoint_path): 98 | if device == 'cuda': 99 | checkpoint = torch.load(checkpoint_path) 100 | else: 101 | checkpoint = torch.load(checkpoint_path, 102 | map_location=lambda storage, loc: storage) 103 | return checkpoint 104 | 105 | 106 | def load_model(path): 107 | model = FaceCreator() 108 | print("Load checkpoint from: {}".format(path)) 109 | checkpoint = _load(path) 110 | s = checkpoint["state_dict"] 111 | new_s = {} 112 | for k, v in s.items(): 113 | new_s[k.replace('module.', '')] = v 114 | model.load_state_dict(new_s) 115 | 116 | model = model.to(device) 117 | return model.eval() 118 | 119 | 120 | def set_realesrgan(args): 121 | use_half = False 122 | if torch.cuda.is_available(): # set False in CPU/MPS mode 123 | no_half_gpu_list = ['1650', '1660'] # set False for GPUs that don't support f16 124 | if not True in [gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list]: 125 | use_half = True 126 | model_net = RRDBNet( 127 | num_in_ch=3, 128 | num_out_ch=3, 129 | num_feat=64, 130 | num_block=23, 131 | num_grow_ch=32, 132 | scale=2, 133 | ) 134 | 135 | up_sampler = RealESRGANer( 136 | scale=2, 137 | model_path=args.esrgan_checkpoint, 138 | model=model_net, 139 | tile=args.bg_tile, 140 | tile_pad=10, 141 | pre_pad=0, 142 | half=use_half 143 | ) 144 | 145 | if not torch.cuda.is_available(): # CPU 146 | import warnings 147 | warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA.' 148 | 'The unoptimized RealESRGAN is slow on CPU. ' 149 | 'If you want to disable it, please remove `--bg_upsampler` and `--face_upsample` in command.', 150 | category=RuntimeWarning) 151 | return up_sampler 152 | 153 | 154 | def main(): 155 | args = parse_args() 156 | face_src = args.face 157 | fps = args.fps 158 | infer_util = InferenceUtil(args.fps, args.checkpoint_path) 159 | full_frames = [] 160 | if not Path(face_src).exists(): 161 | print('--face argument No such file or directory: {}'.format(face_src)) 162 | extr_v = '' 163 | elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg'] and Path(args.face).is_file(): 164 | full_frames = [cv2.imread(args.face)] 165 | fps = args.fps 166 | else: 167 | ex_dir = Path(args.face).parent.as_posix() 168 | tmp_name = Path(args.face).stem +"_25.mp4" 169 | ffmpeg_cmd = "ffmpeg -loglevel error -y -i {0} -vf scale=720:1280 yy-async 1 -r {1} {2}/{3}".format(args.face, args.fps, ex_dir,tmp_name) 170 | output = subprocess.call(ffmpeg_cmd, shell=True, stdout=None) 171 | args.face = ex_dir+'/'+tmp_name 172 | video_stream = cv2.VideoCapture(args.face) 173 | fps = video_stream.get(cv2.CAP_PROP_FPS) 174 | 175 | print('-------Reading video frames-------') 176 | while True: 177 | success, frame = video_stream.read() 178 | if not success: 179 | video_stream.release() 180 | break 181 | if args.resize_factor > 1: 182 | frame = cv2.resize(frame, (frame.shape[1] // args.resize_factor, frame.shape[0] // args.resize_factor)) 183 | 184 | if args.rotate: 185 | frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE) 186 | y1, y2, x1, x2 = args.crop 187 | if x2 == -1: x2 = frame.shape[1] 188 | if y2 == -1: y2 = frame.shape[0] 189 | 190 | frame = frame[y1:y2, x1:x2] 191 | full_frames.append(frame) 192 | 193 | # extr_v = infer_util.src_video_process(face_src) 194 | # frame_names = infer_util.load_img_names(extr_v) 195 | # args.img_path = extr_v 196 | # full_frames = 
infer_util.get_frames(extr_v,frame_names) 197 | print('Number of frames for inference: {}'.format(len(full_frames))) 198 | """faces = infer_util.faces_detect(extr_v, args.face_det_batch_size) 199 | if faces is None: 200 | raise ValueError('No faces in video!Face not detected! Ensure the video contains a face in all the frames.') 201 | print('Number of frames and faces for inference: {}'.format(len(faces)))""" 202 | 203 | if not args.audio.endswith('.wav'): 204 | args.audio = infer_util.extract_audio(args.audio) 205 | 206 | mel = infer_util.get_mel(args.audio) 207 | print('The mel shape is {}'.format(mel.shape)) 208 | 209 | if np.isnan(mel.reshape(-1)).sum() > 0: 210 | raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again') 211 | 212 | mel_chunks = [] 213 | mel_idx_multiplier = 80. / fps 214 | i = 0 215 | while 1: 216 | start_idx = int(i * mel_idx_multiplier) 217 | if start_idx + mel_step_size > len(mel[0]): 218 | mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:]) 219 | break 220 | mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size]) 221 | i += 1 222 | 223 | print("Length of mel chunks: {}".format(len(mel_chunks))) 224 | 225 | full_frames = full_frames[:len(mel_chunks)] 226 | batch_size = args.wav2lip_batch_size 227 | gen = infer_util.gen_data(full_frames.copy(), mel_chunks, args) 228 | print("Model loading {}".format(args.checkpoint_path)) 229 | model = load_model(args.checkpoint_path) 230 | print("Model loaded") 231 | print("Loading segementation network...") 232 | seg_net = init_parser(args.segmentation_path) 233 | print("Model loaded") 234 | 235 | img_path = Path(args.outfile).parent.as_posix() 236 | output_path = img_path + '/output' 237 | 238 | j = 0 239 | prog_bar = tqdm(enumerate(gen), total=int(np.ceil(float(len(mel_chunks)) / batch_size)), leave=False) 240 | for i, (img_batch, mel_batch, frames, coords) in prog_bar: 241 | if i == 0: 242 | 243 | arch = 'clean' 244 | channel_multiplier = 2 245 | 246 | frame_h, frame_w = frames[0].shape[:-1] 247 | out = cv2.VideoWriter(img_path + '/result.avi', cv2.VideoWriter_fourcc(*'XVID'), fps, (frame_w, frame_h)) 248 | if args.bg_upsampler == 'realesrgan': 249 | bg_upsampler = set_realesrgan(args) 250 | else: 251 | bg_upsampler = None 252 | 253 | gfpgan = GFPGANer(model_path=args.gfpgan_checkpoint, upscale=1, 254 | arch=arch, channel_multiplier=channel_multiplier, 255 | bg_upsampler=bg_upsampler, device=device) 256 | 257 | img_batch = torch.tensor(np.transpose(img_batch, (0, 3, 1, 2)), dtype=torch.float).to(device) 258 | mel_batch = torch.tensor(np.transpose(mel_batch, (0, 3, 1, 2)), dtype=torch.float).to(device) 259 | 260 | print("batch write message:", len(img_batch), len(frames), len(coords)) 261 | # cuda_ids = [int(d_id) for d_id in os.environ.get('CUDA_VISIBLE_DEVICES').split(',')] 262 | with torch.no_grad(): 263 | # pred = MyDataParallel(model(mel_batch, img_batch), device_ids=cuda_ids) 264 | pred = model(mel_batch, img_batch) 265 | 266 | pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255 267 | 268 | for p, f, c in zip(pred, frames, coords): 269 | j += 1 270 | y1, y2, x1, x2 = c 271 | p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1)) 272 | 273 | if not args.no_segmentation: 274 | p = swap_regions(f[y1:y2, x1:x2], p, seg_net) 275 | head_h, head_w, _ = p.shape 276 | f[y1:y2, x1:x2] = p 277 | else: 278 | head_h, head_w, _ = p.shape 279 | width_cut = int(head_w * 0.2) 280 | f[y1:y2, x1 + width_cut:x2 - width_cut] = p[:, width_cut:head_w - width_cut] 281 | 282 | 
cf, rf, ri = gfpgan.enhance(f, 283 | has_aligned=args.aligned, 284 | only_center_face=args.only_center_face, 285 | paste_back=True, 286 | weight=args.weight) 287 | 288 | hd_f = np.clip(ri, 0, 255).astype(np.uint8) 289 | 290 | imwrite(f, output_path + '/{}_src.jpg'.format(j)) 291 | #imwrite(hd_f, output_path + '/{}.jpg'.format(j)) 292 | out.write(hd_f) 293 | 294 | out.release() 295 | 296 | # process_hd_video(args.img_path + '/result.avi', args) 297 | command = 'ffmpeg -f image2 -i {}/{}_src.jpg -tag:v DIVX {}'.format(output_path, '%d', img_path + '/result_src.avi') 298 | output = subprocess.call(command, shell=True, stdout=None) 299 | outfile_2 = Path(args.outfile).parent.as_posix()+'/output_src.mp4' 300 | command = 'ffmpeg -loglevel error -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 301 | img_path + '/result_src.avi', 302 | outfile_2) 303 | output = subprocess.call(command, shell=True, stdout=None) 304 | 305 | command = 'ffmpeg -loglevel error -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 306 | img_path + '/result.avi', 307 | args.outfile) 308 | output = subprocess.call(command, shell=True, stdout=None) 309 | print('Finish processed {}'.format(args.outfile)) 310 | 311 | 312 | if __name__ == '__main__': 313 | main() 314 | -------------------------------------------------------------------------------- /bin/inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpus=0 4 | 5 | CUDA_VISIBLE_DEVICES=${gpus} python inference.py \ 6 | --esrgan_checkpoint=../data/pretrain_model/RealESRGAN_x2plus.pth \ 7 | --gfpgan_checkpoint=../data/pretrain_model/GFPGANv1.3.pth \ 8 | --segmentation_path=../data/pretrain_model/segments.pth \ 9 | --checkpoint_path=../data/checkpoint/checkpoint_step000339000.pth \ 10 | --face=../data/temp/self-1.mp4 \ 11 | --bg_tile=0 \ 12 | --audio=../data/temp/audio-1.wav \ 13 | --outfile=../data/temp/output.mp4 -------------------------------------------------------------------------------- /bin/multi_process.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing 3 | from functools import partial 4 | from pathlib import Path 5 | 6 | from tqdm import tqdm 7 | 8 | from process_util.DataProcessor import DataProcessor 9 | 10 | dataProcessor = DataProcessor() 11 | 12 | 13 | def get_processed_files(inputdir, outputdir): 14 | files = [] 15 | for f in Path.glob(Path(inputdir), '**/*.mp4'): 16 | if f.is_file(): 17 | files.append(f.as_posix()) 18 | files.sort() 19 | total_files = len(files) 20 | print('total files to processed:{}'.format(total_files)) 21 | 22 | dones = get_processed_data(outputdir) 23 | done_files = [] 24 | if dones is not None and len(dones) > 0: 25 | print('break point continue!') 26 | for done_file in dones: 27 | d_root = Path(done_file).parts[-2] 28 | d_name = Path(done_file).parts[-1] 29 | d_s = d_name.split('_') 30 | d_d = d_s[-3] 31 | d_f = d_s[-2] + '_' + d_s[-1] + ".mp4" 32 | d_full = inputdir + '/' + d_root + '/' + d_d + '/' + d_f 33 | done_files.append(d_full) 34 | done_bar = tqdm(enumerate(done_files), total=len(done_files), leave=False) 35 | for item in done_files: 36 | files.remove(item) 37 | done_bar.set_description('produce break point!{}'.format(item)) 38 | return files 39 | 40 | 41 | def get_processed_data(processed_data_root): 42 | done_dir = [] 43 | for done in Path.glob(Path(processed_data_root), '*/*'): 44 | if done.is_dir(): 45 | done_dir.append(done) 46 | return done_dir 47 | 48 | 49 | def 
process_data(inputdir, outputdir): 50 | files = get_processed_files(inputdir, outputdir) 51 | 52 | proc_f = partial(dataProcessor.processVideoFile, processed_data_root=outputdir) 53 | 54 | num_p = 4 55 | pool = multiprocessing.Pool(num_p) 56 | pool.map(proc_f, files) 57 | pool.close() 58 | 59 | 60 | def parse_args(): 61 | # parse args and config 62 | parser = argparse.ArgumentParser( 63 | description="process the datasets for wav2lip") 64 | parser.add_argument("--data_root", help='Root folder of the preprocessed dataset', required=True, type=str) 65 | args = parser.parse_args() 66 | 67 | return args 68 | 69 | 70 | def main(): 71 | args = parse_args() 72 | data_root = args.data_root 73 | original_dir = data_root + '/original_data' 74 | preProcess_dir = data_root + '/preProcessed_data' 75 | process_dir = data_root + '/processed_data' 76 | 77 | process_data(preProcess_dir, process_dir) 78 | 79 | 80 | if __name__ == '__main__': 81 | main() 82 | -------------------------------------------------------------------------------- /bin/parse_options.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); 4 | # Arnab Ghoshal, Karel Vesely 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # Parse command-line options. 21 | # To be sourced by another script (as in ". parse_options.sh"). 22 | # Option format is: --option-name arg 23 | # and shell variable "option_name" gets set to value "arg." 24 | # The exception is --help, which takes no arguments, but prints the 25 | # $help_message variable (if defined). 26 | 27 | 28 | ### 29 | ### The --config file options have lower priority to command line 30 | ### options, so we need to import them first... 31 | ### 32 | 33 | # Now import all the configs specified by command-line, in left-to-right order 34 | for ((argpos=1; argpos<$#; argpos++)); do 35 | if [ "${!argpos}" == "--config" ]; then 36 | argpos_plus1=$((argpos+1)) 37 | config=${!argpos_plus1} 38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 39 | . $config # source the config file. 40 | fi 41 | done 42 | 43 | 44 | ### 45 | ### No we process the command line options 46 | ### 47 | while true; do 48 | [ -z "${1:-}" ] && break; # break if there are no arguments 49 | case "$1" in 50 | # If the enclosing script is called with --help option, print the help 51 | # message and exit. Scripts should put help messages in $help_message 52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; 53 | else printf "$help_message\n" 1>&2 ; fi; 54 | exit 0 ;; 55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" 56 | exit 1 ;; 57 | # If the first command-line argument begins with "--" (e.g. --foo-bar), 58 | # then work out the variable name as $name, which will equal "foo_bar". 
59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; 60 | # Next we test whether the variable in question is undefined-- if so it's 61 | # an invalid option and we die. Note: $0 evaluates to the name of the 62 | # enclosing script. 63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar 64 | # is undefined. We then have to wrap this test inside "eval" because 65 | # foo_bar is itself inside a variable ($name). 66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; 67 | 68 | oldval="`eval echo \\$$name`"; 69 | # Work out whether we seem to be expecting a Boolean argument. 70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then 71 | was_bool=true; 72 | else 73 | was_bool=false; 74 | fi 75 | 76 | # Set the variable to the right value-- the escaped quotes make it work if 77 | # the option had spaces, like --cmd "queue.pl -sync y" 78 | eval $name=\"$2\"; 79 | 80 | # Check that Boolean-valued arguments are really Boolean. 81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then 82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 83 | exit 1; 84 | fi 85 | shift 2; 86 | ;; 87 | *) break; 88 | esac 89 | done 90 | 91 | 92 | # Check for an empty argument to the --cmd option, which can easily occur as a 93 | # result of scripting errors. 94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; 95 | 96 | 97 | true; # so this script returns exit code 0. -------------------------------------------------------------------------------- /bin/processData.py: -------------------------------------------------------------------------------- 1 | """ 2 | Process the original videos into the corresponding directories. 3 | Main structure of the root directory: 4 | ./root 5 | ├── original_data # directory holding the original data 6 | ├── preProcessed_data # directory for clips cut to about 5 s 7 | ├── processed_data # directory for the files processed for training 8 | 9 | 10 | """ 11 | import argparse 12 | import logging 13 | from functools import partial 14 | from pathlib import Path 15 | 16 | from sklearn.model_selection import train_test_split 17 | from torch import multiprocessing 18 | from tqdm import tqdm 19 | 20 | from process_util.DataProcessor import DataProcessor 21 | from process_util.PreProcessor import PreProcessor 22 | 23 | 24 | def orignal_process(inputdir): 25 | dirs = [] 26 | root = Path(inputdir) 27 | # collect all subdirectories under the root 28 | for dir in Path.rglob(root, '*'): 29 | if dir.is_dir(): 30 | dirs.append(str(dir)) 31 | # rename all files to zero-padded numeric file names, e.g. 0000.mp4 32 | temp_dir = inputdir + '/temp' 33 | Path(temp_dir).mkdir(parents=True, exist_ok=True) 34 | for i, dir in tqdm(enumerate(dirs), desc='process the original datasets:', total=len(dirs), unit='video'): 35 | j = 0 36 | files = [] 37 | for f in Path.glob(Path(dir), '**/*.mp4'): 38 | j = j + 1 39 | newfilename = temp_dir + '/{0:04}.mp4'.format(j) 40 | files.append(newfilename) 41 | Path.rename(Path(str(f)), Path(newfilename)) 42 | for nf in files: 43 | fname = dir + '/' + Path(nf).name 44 | Path.rename(Path(nf), Path(fname)) 45 | 46 | Path(temp_dir).rmdir() 47 | 48 | 49 | def preProcess(inputdir, outputdir, preprocess_type): 50 | processer = PreProcessor() 51 | videos = get_video_fils(inputdir, 'mp4') 52 | probar = tqdm(enumerate(videos), total=len(videos), leave=False) 53 | for i, video in probar: 54 | if preprocess_type == 'Time': 55 | processer.videosPreProcessByTime(video, 56 | s_time=5, 57 | input_dir=inputdir, 58 | output_dir=outputdir, 59 | ) 60 | else: 61 | processer.videosPreProcessByASR(video, 62 | input_dir=inputdir, 63 | output_dir=outputdir) 64 | 
probar.set_description('Processed video: {}'.format(video)) 65 | 66 | 67 | def get_processed_files(inputdir, outputdir): 68 | files = [] 69 | for f in Path.glob(Path(inputdir), '**/*.mp4'): 70 | if f.is_file(): 71 | files.append(f.as_posix()) 72 | files.sort() 73 | total_files = len(files) 74 | print('total files to processed:{}'.format(total_files)) 75 | 76 | dones = get_processed_data(outputdir) 77 | done_files = [] 78 | if dones is not None and len(dones) > 0: 79 | print('break point continue!') 80 | for done_file in dones: 81 | d_root = Path(done_file).parts[-2] 82 | d_name = Path(done_file).parts[-1] 83 | d_s = d_name.split('_') 84 | d_d = d_s[-3] 85 | d_f = d_s[-2] + '_' + d_s[-1] + ".mp4" 86 | d_full = inputdir + '/' + d_root + '/' + d_d + '/' + d_f 87 | done_files.append(d_full) 88 | done_bar = tqdm(enumerate(done_files), total=len(done_files), leave=False) 89 | for item in done_files: 90 | files.remove(item) 91 | done_bar.set_description('produce break point!{}'.format(item)) 92 | return files 93 | 94 | 95 | def process_data(inputdir, outputdir): 96 | dataProcessor = DataProcessor() 97 | results = [] 98 | files = get_processed_files(inputdir, outputdir) 99 | proc_f = partial(dataProcessor.processVideoFile, processed_data_root=outputdir) 100 | 101 | num_p = int(multiprocessing.cpu_count() - 2) 102 | ctx = multiprocessing.get_context('spawn') 103 | pool = ctx.Pool(num_p) 104 | prog_bar = tqdm(pool.imap(proc_f, files), total=len(files)) 105 | for result in prog_bar: 106 | results.append(result) 107 | pool.close() 108 | return results 109 | 110 | 111 | def get_processed_data(processed_data_root): 112 | done_dir = [] 113 | for done in Path.glob(Path(processed_data_root), '*/*'): 114 | if done.is_dir(): 115 | done_dir.append(done) 116 | return done_dir 117 | 118 | 119 | def train_file_write(inputdir): 120 | train_txt = inputdir + '/train.txt' 121 | eval_txt = inputdir + '/eval.txt' 122 | Path(train_txt).write_text('') 123 | Path(eval_txt).write_text('') 124 | result_list = [] 125 | for line in Path.glob(Path(inputdir), '*/*'): 126 | if line.is_dir(): 127 | dirs = line.parts 128 | input_line = str(dirs[-2] + '/' + dirs[-1]) 129 | result_list.append(input_line) 130 | if len(result_list) < 14: 131 | test_result = eval_result = train_result = result_list 132 | else: 133 | train_result, test_result = train_test_split(result_list, test_size=0.15, random_state=42, shuffle=True) 134 | test_result, eval_result = train_test_split(test_result, test_size=0.5, random_state=42) 135 | 136 | for file_name, data_set in zip(("train.txt", "test.txt", "eval.txt"), (train_result, test_result, eval_result)): 137 | with open(inputdir + '/' + file_name, 'w', encoding='utf-8') as fi: 138 | fi.write("\n".join(data_set)) 139 | 140 | 141 | def clear_data(inputdir): 142 | train_txt = inputdir + '/train.txt' 143 | test_txt = inputdir + '/test.txt' 144 | eval_txt = inputdir + '/eval.txt' 145 | train_list = get_list(train_txt) 146 | test_list = get_list(test_txt) 147 | eval_list = get_list(eval_txt) 148 | bad_list = [] 149 | for line in tqdm(Path.glob(Path(inputdir), '*/*')): 150 | if line.is_dir(): 151 | imgs = [] 152 | for img in line.glob('**/*.jpg'): 153 | if img.is_file(): 154 | imgs.append(int(img.stem)) 155 | if imgs is None or len(imgs) < 25 or len(imgs) < max(imgs): 156 | print('delete empty or bad video!{}'.format(line)) 157 | dirs = line.parts 158 | bad_line = str(dirs[-2] + '/' + dirs[-1]) 159 | bad_list.append(bad_line) 160 | 161 | train_list = clear_badv(train_list, bad_list) 162 | test_list = 
clear_badv(test_list, bad_list) 163 | eval_list = clear_badv(eval_list, bad_list) 164 | 165 | with open(inputdir + '/bad_v.txt', 'w', encoding='utf-8') as fw: 166 | fw.write("\n".join(bad_list)) 167 | with open(inputdir + '/train.txt', 'w', encoding='utf-8') as fw: 168 | fw.write("\n".join(train_list)) 169 | with open(inputdir + '/test.txt', 'w', encoding='utf-8') as fw: 170 | fw.write("\n".join(test_list)) 171 | with open(inputdir + '/eval.txt', 'w', encoding='utf-8') as fw: 172 | fw.write("\n".join(eval_list)) 173 | 174 | 175 | def sync_data(inputdir): 176 | train_txt = inputdir + '/train.txt' 177 | test_txt = inputdir + '/test.txt' 178 | eval_txt = inputdir + '/eval.txt' 179 | exclude_txt = inputdir + '/bad_off.txt' 180 | train_list = get_list(train_txt) 181 | test_list = get_list(test_txt) 182 | eval_list = get_list(eval_txt) 183 | exclude_list = get_list(exclude_txt) 184 | train_list = clear_badv(train_list, exclude_list) 185 | test_list = clear_badv(test_list, exclude_list) 186 | eval_list = clear_badv(eval_list, exclude_list) 187 | 188 | with open(inputdir + '/train.txt', 'w', encoding='utf-8') as fw: 189 | fw.write("\n".join(train_list)) 190 | 191 | with open(inputdir + '/test.txt', 'w', encoding='utf-8') as fw: 192 | fw.write("\n".join(test_list)) 193 | 194 | with open(inputdir + '/eval.txt', 'w', encoding='utf-8') as fw: 195 | fw.write("\n".join(eval_list)) 196 | 197 | 198 | def get_list(inputText): 199 | list = [] 200 | with open(inputText, 'r') as f: 201 | for line in f: 202 | line = line.strip() 203 | list.append(line) 204 | return list 205 | 206 | 207 | def clear_badv(all_list, exclude_list): 208 | for item in exclude_list: 209 | if item in all_list: 210 | all_list.remove(item) 211 | 212 | return all_list 213 | 214 | 215 | def get_video_fils(input_dir, type): 216 | inputPath = input_dir 217 | fileType = type 218 | files = [] 219 | for file in Path.glob(Path(inputPath), '**/*.{}'.format(fileType)): 220 | if file.is_file(): 221 | files.append(file) 222 | files.sort() 223 | return files 224 | 225 | 226 | def main(): 227 | logging.basicConfig(level=logging.ERROR) 228 | args = parse_args() 229 | data_root = args.data_root 230 | p_step = args.process_step 231 | preprocess_type = args.preprocess_type 232 | original_dir = data_root + '/original_data' 233 | preProcess_dir = data_root + '/preProcessed_data' 234 | process_dir = data_root + '/processed_data' 235 | 236 | if p_step == 0: 237 | print("produce the step {}".format(p_step)) 238 | orignal_process(original_dir) 239 | elif p_step == 1: 240 | print("produce the step {}".format(p_step)) 241 | preProcess(original_dir, preProcess_dir, preprocess_type) 242 | elif p_step == 2: 243 | print("produce the step {}".format(p_step)) 244 | process_data(preProcess_dir, process_dir) 245 | elif p_step == 3: 246 | print("produce the step {}".format(p_step)) 247 | train_file_write(process_dir) 248 | clear_data(process_dir) 249 | elif p_step == 4: 250 | print("produce the step {}".format(p_step)) 251 | sync_data(process_dir) 252 | else: 253 | print('wrong step number, finished!') 254 | 255 | 256 | def parse_args(): 257 | # parse args and config 258 | parser = argparse.ArgumentParser( 259 | description="process the datasets for wav2lip") 260 | parser.add_argument("--data_root", help='Root folder of the preprocessed dataset', required=True, type=str) 261 | parser.add_argument("--preprocess_type", help='ASR or time split', default='ASR', type=str) 262 | parser.add_argument("--process_step", help='process data\'s step 1 orig,2.pre 3.pro 4.write file', 
default=0, 263 | type=int) 264 | args = parser.parse_args() 265 | 266 | return args 267 | 268 | 269 | if __name__ == '__main__': 270 | main() 271 | -------------------------------------------------------------------------------- /bin/process_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export TEMP=/tmp 4 | stage=0 5 | stop_stage=100 6 | gpus=0 7 | export IMAGEIO_FFMPEG_EXE=/usr/local/bin/ffmpeg 8 | export IMAGEIO_USE_GPU=True 9 | 10 | source ./parse_options.sh || exit 1 11 | 12 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 13 | CUDA_VISIBLE_DEVICES=${gpus} python processData.py --data_root=../data \ 14 | --process_step=0 || exit -1 15 | fi 16 | 17 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 18 | CUDA_VISIBLE_DEVICES=${gpus} python processData.py --data_root=../data \ 19 | --preprocess_type=ASR \ 20 | --process_step=1 || exit -1 21 | fi 22 | 23 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 24 | CUDA_VISIBLE_DEVICES=${gpus} python processData.py --data_root=../data \ 25 | --process_step=2 || exit -1 26 | fi 27 | 28 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then 29 | CUDA_VISIBLE_DEVICES=${gpus} python processData.py --data_root=../data \ 30 | --process_step=3 || exit -1 31 | fi 32 | 33 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then 34 | CUDA_VISIBLE_DEVICES=${gpus} python processData.py --data_root=../data \ 35 | --process_step=4 || exit -1 36 | fi -------------------------------------------------------------------------------- /bin/score.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export TEMP=/tmp 4 | stage=0 5 | stop_stage=100 6 | export IMAGEIO_FFMPEG_EXE=/usr/local/bin/ffmpeg 7 | export IMAGEIO_USE_GPU=True 8 | 9 | source ./parse_options.sh || exit 1 10 | 11 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 12 | python score_video.py --data_root=../data/processed_data \ 13 | --checkpoint_path=../data/syncnet_checkpoint/sync_checkpoint_step000370000.pth \ 14 | --num_worker=2 \ 15 | --batch_size=20 16 | fi -------------------------------------------------------------------------------- /bin/score_video.py: -------------------------------------------------------------------------------- 1 | """ 2 | Check whether each video is in sync by scoring it with the trained SyncNet; videos with a large offset are discarded. 3 | 4 | 5 | """ 6 | import argparse 7 | from functools import partial 8 | from pathlib import Path 9 | 10 | import torch 11 | from torch import multiprocessing 12 | from tqdm import tqdm 13 | 14 | from process_util.SyncnetScore import SyncnetScore 15 | 16 | 17 | def main(): 18 | args = parse_args() 19 | data_root = args.data_root 20 | checkpoint = args.checkpoint_path 21 | num_worker = args.num_worker 22 | batch_size = args.batch_size 23 | 24 | train_txt = data_root + '/train.txt' 25 | test_txt = data_root + '/test.txt' 26 | eval_txt = data_root + '/eval.txt' 27 | dir_list = get_dirList(train_txt) 28 | test_list = get_dirList(test_txt) 29 | eval_list = get_dirList(eval_txt) 30 | dir_list += eval_list 31 | dir_list += test_list 32 | Path(data_root + '/score.txt').write_text('') 33 | Path(data_root + '/bad_off.txt').write_text('') 34 | score_tools = SyncnetScore() 35 | proc_f = partial(score_tools.score_video, checkpoint=checkpoint, data_root=data_root, batch_size=batch_size) 36 | multiprocessing.set_start_method('spawn', force=True) 37 | pool = multiprocessing.Pool(processes=num_worker) 38 | # ctx = multiprocessing.get_context('spawn') 39 | # pool = 
ctx.Pool(num_worker) 40 | prog_bar = tqdm(pool.imap(proc_f, dir_list), total=len(dir_list)) 41 | results = [] 42 | bad_offset_f = [] 43 | for v_file, offset, conf in prog_bar: 44 | results.append('video:{} offset:{} conf:{}'.format(v_file, offset, conf)) 45 | with open(data_root + '/score.txt', 'a') as fw: 46 | fw.write("{},{},{}\n".format(v_file, offset, conf)) 47 | if offset < -1 or offset > 1: 48 | bad_offset_f.append(v_file) 49 | with open(data_root + '/bad_off.txt', 'a', encoding='utf-8') as fw: 50 | fw.write("{}\n".format(v_file)) 51 | prog_bar.set_description('score file:{} offset:{}'.format(v_file, offset)) 52 | torch.cuda.empty_cache() 53 | pool.close() 54 | pool.join() 55 | 56 | 57 | def get_dirList(path): 58 | dir_list = [] 59 | with open(path, 'r') as f: 60 | for line in f: 61 | line = line.strip() 62 | dir_list.append(line) 63 | return dir_list 64 | 65 | 66 | def parse_args(): 67 | # parse args and config 68 | parser = argparse.ArgumentParser( 69 | description="score for the video") 70 | parser.add_argument("--data_root", help='Root folder of the preprocessed dataset', required=True, type=str) 71 | parser.add_argument('--checkpoint_path', help='Load he pre-trained ', required=True, ) 72 | parser.add_argument('--num_worker', help='multiprocessor number', default=6, type=int) 73 | parser.add_argument('--batch_size', help='produce img batch', default=20, type=int) 74 | args = parser.parse_args() 75 | 76 | return args 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /bin/train_syncnet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpus=0,1 4 | 5 | CUDA_VISIBLE_DEVICES=${gpus} python ../trains/syncnet_train.py \ 6 | --data_root=../data/processed_data \ 7 | --checkpoint_dir=../data/syncnet_checkpoint \ 8 | --config_file=../configs/train_config.yaml \ 9 | --train_type=train -------------------------------------------------------------------------------- /bin/train_wav2lip.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpus=0 4 | 5 | CUDA_VISIBLE_DEVICES=${gpus} python ../trains/wl_train.py \ 6 | --data_root=../data/processed_data \ 7 | --checkpoint_dir=../data/checkpoint \ 8 | --syncnet_checkpoint_path=../data/syncnet_checkpoint/sync_checkpoint_step000517000.pth 9 | -------------------------------------------------------------------------------- /configs/train_config.yaml: -------------------------------------------------------------------------------- 1 | train_config: 2 | ########################################audio train config################################## 3 | #Number of mel-spectrogram channels and local conditioning dimensionality 4 | # Whether to rescale audio prior to preprocessing 5 | # Rescaling value 6 | num_mels: 80 7 | rescale: True 8 | resacling_max: 0.9 9 | # Extra window size is filled with 0 paddings to match this parameter 10 | # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) 11 | # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) 12 | # 16000Hz (corresponding to librispeech) (sox --i ) 13 | n_fft: 800 14 | hop_size: 200 15 | win_size: 800 16 | sample_rate: 16000 17 | # Can replace hop_size parameter. 
(Recommended: 12.5) 18 | frame_shift_ms: None 19 | # Mel and Linear spectrograms normalization/scaling and clipping 20 | signal_normalization: True 21 | allow_clipping_in_normalization: True 22 | 23 | # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, 24 | # faster and cleaner convergence 25 | 26 | symmetric_mels: True 27 | 28 | # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not 29 | # be too big to avoid gradient explosion, 30 | # not too small for fast convergence) 31 | max_abs_value: 4. 32 | 33 | preemphasize: True # whether to apply filter 34 | preemphasis: 0.97 35 | 36 | min_level_db: -100 37 | ref_level_db: 20 38 | fmin: 55 39 | fmax: 7600 40 | syncnet_mel_step_size: 16 41 | syncnet_T: 5 42 | use_lws: False 43 | #########################################img training parameters################################### 44 | #image training parameters,this is the face traning 45 | img_size: 288 46 | fps: 25 47 | 48 | batch_size: 2 49 | init_learning_rate: 1e-4 50 | epochs: 1000000000 51 | num_works: 8 52 | checkpoint_interval: 3000 53 | eval_interval: 3000 54 | save_optimizer_state: True 55 | syncnet_wt: 0.0 56 | m_min: 100 57 | m_med: 500 58 | m_max: 1000 59 | 60 | #####################################sycnet parameters########################################## 61 | 62 | syncnet_batch_size: 128 63 | syncnet_learning_rate: 1e-5 64 | syncnet_checkpoint_interval: 1000 65 | syncnet_eval_interval: 1000 66 | syncnet_min: 100 67 | syncnet_med: 500 68 | syncnet_max: 1000 69 | #####################################Disc parameters########################################## 70 | disc_wt: 0.07 71 | disc_initial_learning_rate: 1e-4 72 | 73 | 74 | -------------------------------------------------------------------------------- /configs/train_config_96.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rogerle/wav2lipup/7149a5fb30d52f1af7773dccc991c42acf59ac77/configs/train_config_96.yaml -------------------------------------------------------------------------------- /data/original_data/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rogerle/wav2lipup/7149a5fb30d52f1af7773dccc991c42acf59ac77/data/original_data/README.md -------------------------------------------------------------------------------- /data/preProcessed_data/README.md: -------------------------------------------------------------------------------- 1 | 初步处理的视频文件放在这里,切割成5s左右的文件 -------------------------------------------------------------------------------- /data/pretrain_model/README: -------------------------------------------------------------------------------- 1 | 用于一些预训练模型存放 2 | 主要的模型有 3 | realesrgan的预训练模型 4 | GFPGAN v1.3的预训练模型 5 | segmentation的BiSeNet预训练模型 -------------------------------------------------------------------------------- /data/syncnet_checkpoint/README.md: -------------------------------------------------------------------------------- 1 | syncnet的checkpoint存放目录 -------------------------------------------------------------------------------- /data/test_data/input/.gitignore: -------------------------------------------------------------------------------- 1 | * -------------------------------------------------------------------------------- /data/test_data/input/test.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rogerle/wav2lipup/7149a5fb30d52f1af7773dccc991c42acf59ac77/data/test_data/input/test.jpg -------------------------------------------------------------------------------- /data/test_data/pr_data/README.md: -------------------------------------------------------------------------------- 1 | 测试提取人脸图片的内容 -------------------------------------------------------------------------------- /inference_util/FaceParseModel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision 9 | 10 | from .resnet import Resnet18 11 | # from modules.bn import InPlaceABNSync as BatchNorm2d 12 | 13 | 14 | class ConvBNReLU(nn.Module): 15 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): 16 | super(ConvBNReLU, self).__init__() 17 | self.conv = nn.Conv2d(in_chan, 18 | out_chan, 19 | kernel_size = ks, 20 | stride = stride, 21 | padding = padding, 22 | bias = False) 23 | self.bn = nn.BatchNorm2d(out_chan) 24 | self.init_weight() 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = F.relu(self.bn(x)) 29 | return x 30 | 31 | def init_weight(self): 32 | for ly in self.children(): 33 | if isinstance(ly, nn.Conv2d): 34 | nn.init.kaiming_normal_(ly.weight, a=1) 35 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 36 | 37 | class BiSeNetOutput(nn.Module): 38 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): 39 | super(BiSeNetOutput, self).__init__() 40 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 41 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) 42 | self.init_weight() 43 | 44 | def forward(self, x): 45 | x = self.conv(x) 46 | x = self.conv_out(x) 47 | return x 48 | 49 | def init_weight(self): 50 | for ly in self.children(): 51 | if isinstance(ly, nn.Conv2d): 52 | nn.init.kaiming_normal_(ly.weight, a=1) 53 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 54 | 55 | def get_params(self): 56 | wd_params, nowd_params = [], [] 57 | for name, module in self.named_modules(): 58 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 59 | wd_params.append(module.weight) 60 | if not module.bias is None: 61 | nowd_params.append(module.bias) 62 | elif isinstance(module, nn.BatchNorm2d): 63 | nowd_params += list(module.parameters()) 64 | return wd_params, nowd_params 65 | 66 | 67 | class AttentionRefinementModule(nn.Module): 68 | def __init__(self, in_chan, out_chan, *args, **kwargs): 69 | super(AttentionRefinementModule, self).__init__() 70 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 71 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) 72 | self.bn_atten = nn.BatchNorm2d(out_chan) 73 | self.sigmoid_atten = nn.Sigmoid() 74 | self.init_weight() 75 | 76 | def forward(self, x): 77 | feat = self.conv(x) 78 | atten = F.avg_pool2d(feat, feat.size()[2:]) 79 | atten = self.conv_atten(atten) 80 | atten = self.bn_atten(atten) 81 | atten = self.sigmoid_atten(atten) 82 | out = torch.mul(feat, atten) 83 | return out 84 | 85 | def init_weight(self): 86 | for ly in self.children(): 87 | if isinstance(ly, nn.Conv2d): 88 | nn.init.kaiming_normal_(ly.weight, a=1) 89 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 90 | 91 | 92 | class ContextPath(nn.Module): 93 | def __init__(self, *args, **kwargs): 94 | super(ContextPath, 
self).__init__() 95 | self.resnet = Resnet18() 96 | self.arm16 = AttentionRefinementModule(256, 128) 97 | self.arm32 = AttentionRefinementModule(512, 128) 98 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 99 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 100 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 101 | 102 | self.init_weight() 103 | 104 | def forward(self, x): 105 | H0, W0 = x.size()[2:] 106 | feat8, feat16, feat32 = self.resnet(x) 107 | H8, W8 = feat8.size()[2:] 108 | H16, W16 = feat16.size()[2:] 109 | H32, W32 = feat32.size()[2:] 110 | 111 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 112 | avg = self.conv_avg(avg) 113 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest') 114 | 115 | feat32_arm = self.arm32(feat32) 116 | feat32_sum = feat32_arm + avg_up 117 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') 118 | feat32_up = self.conv_head32(feat32_up) 119 | 120 | feat16_arm = self.arm16(feat16) 121 | feat16_sum = feat16_arm + feat32_up 122 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') 123 | feat16_up = self.conv_head16(feat16_up) 124 | 125 | return feat8, feat16_up, feat32_up # x8, x8, x16 126 | 127 | def init_weight(self): 128 | for ly in self.children(): 129 | if isinstance(ly, nn.Conv2d): 130 | nn.init.kaiming_normal_(ly.weight, a=1) 131 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 132 | 133 | def get_params(self): 134 | wd_params, nowd_params = [], [] 135 | for name, module in self.named_modules(): 136 | if isinstance(module, (nn.Linear, nn.Conv2d)): 137 | wd_params.append(module.weight) 138 | if not module.bias is None: 139 | nowd_params.append(module.bias) 140 | elif isinstance(module, nn.BatchNorm2d): 141 | nowd_params += list(module.parameters()) 142 | return wd_params, nowd_params 143 | 144 | 145 | ### This is not used, since I replace this with the resnet feature with the same size 146 | class SpatialPath(nn.Module): 147 | def __init__(self, *args, **kwargs): 148 | super(SpatialPath, self).__init__() 149 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) 150 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 151 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 152 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) 153 | self.init_weight() 154 | 155 | def forward(self, x): 156 | feat = self.conv1(x) 157 | feat = self.conv2(feat) 158 | feat = self.conv3(feat) 159 | feat = self.conv_out(feat) 160 | return feat 161 | 162 | def init_weight(self): 163 | for ly in self.children(): 164 | if isinstance(ly, nn.Conv2d): 165 | nn.init.kaiming_normal_(ly.weight, a=1) 166 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 167 | 168 | def get_params(self): 169 | wd_params, nowd_params = [], [] 170 | for name, module in self.named_modules(): 171 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 172 | wd_params.append(module.weight) 173 | if not module.bias is None: 174 | nowd_params.append(module.bias) 175 | elif isinstance(module, nn.BatchNorm2d): 176 | nowd_params += list(module.parameters()) 177 | return wd_params, nowd_params 178 | 179 | 180 | class FeatureFusionModule(nn.Module): 181 | def __init__(self, in_chan, out_chan, *args, **kwargs): 182 | super(FeatureFusionModule, self).__init__() 183 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 184 | self.conv1 = nn.Conv2d(out_chan, 185 | out_chan//4, 186 | kernel_size = 1, 187 | stride = 1, 188 | padding = 0, 189 | bias = 
False) 190 | self.conv2 = nn.Conv2d(out_chan//4, 191 | out_chan, 192 | kernel_size = 1, 193 | stride = 1, 194 | padding = 0, 195 | bias = False) 196 | self.relu = nn.ReLU(inplace=True) 197 | self.sigmoid = nn.Sigmoid() 198 | self.init_weight() 199 | 200 | def forward(self, fsp, fcp): 201 | fcat = torch.cat([fsp, fcp], dim=1) 202 | feat = self.convblk(fcat) 203 | atten = F.avg_pool2d(feat, feat.size()[2:]) 204 | atten = self.conv1(atten) 205 | atten = self.relu(atten) 206 | atten = self.conv2(atten) 207 | atten = self.sigmoid(atten) 208 | feat_atten = torch.mul(feat, atten) 209 | feat_out = feat_atten + feat 210 | return feat_out 211 | 212 | def init_weight(self): 213 | for ly in self.children(): 214 | if isinstance(ly, nn.Conv2d): 215 | nn.init.kaiming_normal_(ly.weight, a=1) 216 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 217 | 218 | def get_params(self): 219 | wd_params, nowd_params = [], [] 220 | for name, module in self.named_modules(): 221 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 222 | wd_params.append(module.weight) 223 | if not module.bias is None: 224 | nowd_params.append(module.bias) 225 | elif isinstance(module, nn.BatchNorm2d): 226 | nowd_params += list(module.parameters()) 227 | return wd_params, nowd_params 228 | 229 | 230 | class BiSeNet(nn.Module): 231 | def __init__(self, n_classes, *args, **kwargs): 232 | super(BiSeNet, self).__init__() 233 | self.cp = ContextPath() 234 | ## here self.sp is deleted 235 | self.ffm = FeatureFusionModule(256, 256) 236 | self.conv_out = BiSeNetOutput(256, 256, n_classes) 237 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes) 238 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes) 239 | self.init_weight() 240 | 241 | def forward(self, x): 242 | H, W = x.size()[2:] 243 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature 244 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature 245 | feat_fuse = self.ffm(feat_sp, feat_cp8) 246 | 247 | feat_out = self.conv_out(feat_fuse) 248 | feat_out16 = self.conv_out16(feat_cp8) 249 | feat_out32 = self.conv_out32(feat_cp16) 250 | 251 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) 252 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) 253 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) 254 | return feat_out, feat_out16, feat_out32 255 | 256 | def init_weight(self): 257 | for ly in self.children(): 258 | if isinstance(ly, nn.Conv2d): 259 | nn.init.kaiming_normal_(ly.weight, a=1) 260 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 261 | 262 | def get_params(self): 263 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] 264 | for name, child in self.named_children(): 265 | child_wd_params, child_nowd_params = child.get_params() 266 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): 267 | lr_mul_wd_params += child_wd_params 268 | lr_mul_nowd_params += child_nowd_params 269 | else: 270 | wd_params += child_wd_params 271 | nowd_params += child_nowd_params 272 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params 273 | 274 | 275 | if __name__ == "__main__": 276 | net = BiSeNet(19) 277 | net.cuda() 278 | net.eval() 279 | in_ten = torch.randn(16, 3, 640, 480).cuda() 280 | out, out16, out32 = net(in_ten) 281 | print(out.shape) 282 | 283 | net.get_params() -------------------------------------------------------------------------------- 
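The BiSeNet above outputs a 19-class face-parsing map for a 512x512 input; in this repo it is consumed through inference_util/swap.py (init_parser / swap_regions, listed further below), which composites the target image into the regions of the source frame that BiSeNet labels as classes [1, 11, 12, 13]. A minimal usage sketch, assuming a CUDA device and a BiSeNet face-parsing checkpoint under data/pretrain_model (the checkpoint file name and the generated-face file below are hypothetical, the repo does not fix them):

import cv2
from inference_util import init_parser, swap_regions

parsing_net = init_parser('data/pretrain_model/face_parsing_bisenet.pth')  # hypothetical checkpoint name
source_frame = cv2.imread('data/test_data/input/test.jpg')                 # original frame
generated_face = cv2.imread('wav2lip_output_face.jpg')                     # hypothetical generator output
blended = swap_regions(source_frame, generated_face, parsing_net)          # float32 image at source resolution
cv2.imwrite('blended.jpg', blended.astype('uint8'))                        # cast back before writing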
/inference_util/InferenceUtil.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import shutil 4 | import subprocess 5 | import time 6 | from pathlib import Path 7 | 8 | import cv2 9 | import numpy as np 10 | import torch 11 | import torchaudio 12 | import torchaudio.functional as F 13 | from tqdm import tqdm 14 | 15 | from process_util.FaceDetector import FaceDetector 16 | from process_util.ParamsUtil import ParamsUtil 17 | 18 | hp = ParamsUtil() 19 | 20 | 21 | class InferenceUtil(): 22 | def __init__(self, fps, model_path): 23 | self.tmp_dir = os.environ.get('TEMP') 24 | self.fps = fps 25 | self.model_path = model_path 26 | 27 | def get_smoothened_boxes(self,boxes, T): 28 | for i in range(len(boxes)): 29 | if i + T > len(boxes): 30 | window = boxes[len(boxes) - T:] 31 | else: 32 | window = boxes[i: i + T] 33 | boxes[i] = np.mean(window, axis=0) 34 | return boxes 35 | def src_video_process(self, video_f): 36 | v_f = Path(video_f) 37 | v_name = v_f.stem 38 | tmp_video_name = v_name + '.avi' 39 | output_dir = self.tmp_dir + '/' + v_name 40 | extract_dir = self.tmp_dir + '/' + v_name + '/frames' 41 | Path(extract_dir).mkdir(exist_ok=True, parents=True) 42 | print('start extract source video file {}'.format(v_f.as_posix())) 43 | s = time.time() 44 | # change source video to 25fps 45 | ffmpeg_cmd = "ffmpeg -loglevel error -y -i {0} -qscale:v 2 -async 1 -r {1} {2}/{3}".format(v_f, self.fps, 46 | output_dir, 47 | tmp_video_name) 48 | output = subprocess.call(ffmpeg_cmd, shell=True, stdout=None) 49 | # extract video to jpg 50 | ffmpeg_cmd = 'ffmpeg -loglevel error -y -i {0} -qscale:v 2 -threads 6 -f image2 {1}/{2}.jpg'.format( 51 | output_dir + '/' + tmp_video_name, extract_dir, '%d') 52 | output = subprocess.call(ffmpeg_cmd, shell=True, stdout=None) 53 | st = time.time() 54 | print('process the video file {} cost {:.2f}s'.format(video_f, st - s)) 55 | return extract_dir 56 | 57 | def faces_detect(self, f_frames, args): 58 | results = [] 59 | head_exist = [] 60 | 61 | face_detector = FaceDetector() 62 | batch_size = args.face_det_batch_size 63 | print('start detecting faces...') 64 | s = time.time() 65 | while 1: 66 | predictions = [] 67 | try: 68 | for i in tqdm(range(0, len(f_frames), batch_size),leave=False): 69 | predictions.extend(face_detector.faceBatchDetection(f_frames[i:i+batch_size])) 70 | except RuntimeError: 71 | if batch_size == 1: 72 | raise RuntimeError('Image too large to run face detection') 73 | batch_size //=2 74 | print('Recovering from OOM; New batchsize is {}'.format(batch_size)) 75 | continue 76 | break 77 | pady1,pady2,padx1,padx2 =args.pads 78 | 79 | #获取第一帧头像的大小 80 | f_head_rec = None 81 | f_head_img = None 82 | for rect,img in zip(predictions,f_frames): 83 | if rect is not None: 84 | f_head_rec = rect 85 | f_head_img = img 86 | break 87 | 88 | for rect,img in zip(predictions,f_frames): 89 | if rect is None: 90 | head_exist.append(False) 91 | if len(results) == 0: 92 | y1 = max(0,f_head_rec[1]-pady1) 93 | y2 = min(f_head_img.shape[0],f_head_rec[3]+pady2) 94 | x1 = max(0,f_head_rec[0]-padx1) 95 | x2 = min(f_head_img.shape[1],f_head_rec[2]+padx2) 96 | results.append([x1,y1,x2,y2]) 97 | else: 98 | results.append(results[-1]) 99 | else: 100 | head_exist.append(True) 101 | y1 =max(0,rect[1] - pady1) 102 | y2 = min(f_head_img.shape[0],rect[3]+pady2) 103 | x1 = max(0,rect[0] - padx1) 104 | x2 = min(f_head_img.shape[1],rect[2]+padx2) 105 | results.append([x1,y1,x2,y2]) 106 | boxes = np.array(results) 107 | if 
not args.nosmooth: boxes = self.get_smoothened_boxes(boxes, T=5) 108 | results = [[f_frames[y1:y2,x1:x2],(y1,y2,x1,x2)] for f_frames,(x1,y1,x2,y2) in zip(f_frames,boxes)] 109 | del face_detector 110 | return results,head_exist 111 | 112 | def get_frames(self, img_path, names): 113 | imgs = [] 114 | for img in names: 115 | img_dir = img_path + '/' + str(img) + '.jpg' 116 | img = cv2.imread(img_dir) 117 | imgs.append(img) 118 | return imgs 119 | 120 | def load_img_names(self, frame_path): 121 | img_names = [] 122 | for img in Path(frame_path).glob('**/*.jpg'): 123 | img = img.stem 124 | img_names.append(img) 125 | img_names.sort(key=int) 126 | return img_names 127 | 128 | def extract_audio(self, wavfile): 129 | print('Extracting raw audio...') 130 | tmp_audio = self.tmp_dir + '/temp.wav' 131 | ffmpeg_cmd = 'ffmpeg -loglevel error -y -i {} -strict -2 {}'.format(wavfile, tmp_audio) 132 | output = subprocess.call(ffmpeg_cmd, shell=True, stdout=None) 133 | 134 | return tmp_audio 135 | 136 | def get_mel(self, wavefile): 137 | try: 138 | waveform, sf = torchaudio.load(wavefile) 139 | 140 | resample = torchaudio.transforms.Resample(sf, 16000) 141 | waveform = resample(waveform) 142 | waveform = F.preemphasis(waveform, hp.preemphasis) 143 | 144 | specgram = torchaudio.transforms.MelSpectrogram(sample_rate=hp.sample_rate, 145 | n_fft=hp.n_fft, 146 | power=1., 147 | hop_length=hp.hop_size, 148 | win_length=hp.win_size, 149 | f_min=hp.fmin, 150 | f_max=hp.fmax, 151 | n_mels=hp.num_mels, 152 | normalized=hp.signal_normalization) 153 | orig_mel = specgram(waveform) 154 | orig_mel = F.amplitude_to_DB(orig_mel, multiplier=10., amin=hp.min_level_db, 155 | db_multiplier=hp.ref_level_db, top_db=100) 156 | orig_mel = torch.mean(orig_mel, dim=0) 157 | orig_mel = orig_mel.numpy() 158 | except Exception as e: 159 | orig_mel = None 160 | 161 | return orig_mel 162 | 163 | def gen_data(self, frames, mels, args): 164 | img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] 165 | 166 | #*******识别人脸位置坐标,未识别的对应为none ******************* 167 | if args.box[0] == -1: 168 | if not args.static: 169 | face_det_result,head_exist = self.faces_detect(frames,args) 170 | else: 171 | face_det_result,head_exist = self.faces_detect([frames[0]],args) 172 | else: 173 | print('Using the specified bounding box instead of face detaction....') 174 | y1,y2,x1,x2 = args.box 175 | face_det_result = [[(f[y1:y2, x1:x2],(y1,y2,x1,x2))] for f in frames ] 176 | head_exit = [True]*len(frames) 177 | if face_det_result is None: 178 | raise ValueError('No faces in video!Face not detected! Ensure the video contains a face in all the frames.') 179 | 180 | for i, m in tqdm(enumerate(mels),total=len(mels),leave=False): 181 | idx = 0 if args.static else i % len(frames) 182 | frame_to_save = frames[idx].copy() 183 | face, coords = face_det_result[idx].copy() 184 | face = cv2.resize(face, (hp.img_size, hp.img_size)) 185 | 186 | img_batch.append(face) 187 | mel_batch.append(m) 188 | frame_batch.append(frame_to_save) 189 | coords_batch.append(coords) 190 | 191 | if len(img_batch) >= args.wav2lip_batch_size: 192 | img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) 193 | 194 | img_masked = img_batch.copy() 195 | img_masked[:, hp.img_size // 2:] = 0 196 | 197 | img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 
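                # img_masked zeroes the lower half (rows img_size//2 and below) of each crop, so the
                # concatenation along the last axis pairs the masked target with the untouched
                # reference crop: shape (N, img_size, img_size, 6), scaled to [0, 1]. This matches
                # the 6-channel face input of FaceCreator's encoder.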
198 | mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) 199 | 200 | yield img_batch, mel_batch, frame_batch, coords_batch 201 | img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] 202 | 203 | if len(img_batch) > 0: 204 | img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) 205 | 206 | img_masked = img_batch.copy() 207 | img_masked[:, hp.img_size // 2:] = 0 208 | 209 | img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 210 | mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) 211 | 212 | yield img_batch, mel_batch, frame_batch, coords_batch 213 | -------------------------------------------------------------------------------- /inference_util/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rogerle/wav2lipup/7149a5fb30d52f1af7773dccc991c42acf59ac77/inference_util/README.md -------------------------------------------------------------------------------- /inference_util/__init__.py: -------------------------------------------------------------------------------- 1 | from .swap import init_parser, swap_regions -------------------------------------------------------------------------------- /inference_util/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | # from modules.bn import InPlaceABNSync as BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = nn.BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = nn.BatchNorm2d(out_chan) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_chan), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = F.relu(self.bn1(residual)) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum-1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 
128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 69 | self.init_weight() 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = F.relu(self.bn1(x)) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self): 83 | state_dict = modelzoo.load_url(resnet18_url) 84 | self_state_dict = self.state_dict() 85 | for k, v in state_dict.items(): 86 | if 'fc' in k: continue 87 | self_state_dict.update({k: v}) 88 | self.load_state_dict(self_state_dict) 89 | 90 | def get_params(self): 91 | wd_params, nowd_params = [], [] 92 | for name, module in self.named_modules(): 93 | if isinstance(module, (nn.Linear, nn.Conv2d)): 94 | wd_params.append(module.weight) 95 | if not module.bias is None: 96 | nowd_params.append(module.bias) 97 | elif isinstance(module, nn.BatchNorm2d): 98 | nowd_params += list(module.parameters()) 99 | return wd_params, nowd_params 100 | 101 | 102 | if __name__ == "__main__": 103 | net = Resnet18() 104 | x = torch.randn(16, 3, 224, 224) 105 | out = net(x) 106 | print(out[0].size()) 107 | print(out[1].size()) 108 | print(out[2].size()) 109 | net.get_params() -------------------------------------------------------------------------------- /inference_util/swap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import cv2 4 | import numpy as np 5 | 6 | from .FaceParseModel import BiSeNet 7 | 8 | 9 | def init_parser(pth_path): 10 | n_classes = 19 11 | net = BiSeNet(n_classes=n_classes) 12 | net.cuda() 13 | net.load_state_dict(torch.load(pth_path)) 14 | net.eval() 15 | return net 16 | 17 | 18 | def image_to_parsing(img, net): 19 | img = cv2.resize(img, (512, 512)) 20 | img = img[:,:,::-1] 21 | transform = transforms.Compose([ 22 | transforms.ToTensor(), 23 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 24 | ]) 25 | img = transform(img.copy()) 26 | img = torch.unsqueeze(img, 0) 27 | 28 | with torch.no_grad(): 29 | img = img.cuda() 30 | out = net(img)[0] 31 | parsing = out.squeeze(0).cpu().numpy().argmax(0) 32 | return parsing 33 | 34 | 35 | def get_mask(parsing, classes): 36 | res = parsing == classes[0] 37 | for val in classes[1:]: 38 | res += parsing == val 39 | return res 40 | 41 | 42 | def swap_regions(source, target, net): 43 | parsing = image_to_parsing(source, net) 44 | face_classes = [1, 11, 12, 13] 45 | 46 | mask = get_mask(parsing, face_classes) 47 | mask = np.repeat(np.expand_dims(mask, axis=2), 3, 2) 48 | result = (1 - mask) * cv2.resize(source, (512, 512)) + mask * cv2.resize(target, (512, 512)) 49 | result = cv2.resize(result.astype("float32"), (source.shape[1], source.shape[0])) 50 | return result 51 | -------------------------------------------------------------------------------- /models/BaseConv2D.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class BaseConv2D(nn.Module): 5 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, act='relu', *args, **kwargs): 6 | super().__init__(*args, **kwargs) 7 | self.conv_block = nn.Sequential( 8 | nn.Conv2d(cin, cout, kernel_size=kernel_size, stride=stride, padding=padding), 9 | nn.BatchNorm2d(cout) 10 | ) 11 | if act == 'relu': 12 | 
self.act = nn.ReLU() 13 | if act == 'prelu': 14 | self.act = nn.PReLU() 15 | if act == 'leaky': 16 | self.act = nn.LeakyReLU(0.01, inplace=True) 17 | if act == 'sigmoid': 18 | self.act = nn.Sigmoid() 19 | self.residual = residual 20 | 21 | def forward(self, x): 22 | out = self.conv_block(x) 23 | if self.residual: 24 | out += x 25 | return self.act(out) 26 | -------------------------------------------------------------------------------- /models/BaseTranspose.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class BaseTranspose(nn.Module): 4 | def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs): 5 | super().__init__(*args, **kwargs) 6 | self.conv_block = nn.Sequential( 7 | nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding), 8 | nn.BatchNorm2d(cout) 9 | ) 10 | self.act = nn.ReLU() 11 | 12 | def forward(self, x): 13 | out = self.conv_block(x) 14 | return self.act(out) -------------------------------------------------------------------------------- /models/Discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | from models.BaseConv2D import BaseConv2D 6 | from models.NoNormConv import NoNormConv 7 | 8 | 9 | class Discriminator(nn.Module): 10 | def __init__(self): 11 | super(Discriminator, self).__init__() 12 | 13 | self.face_encoder_blocks = nn.ModuleList([ 14 | nn.Sequential(NoNormConv(3, 32, kernel_size=7, stride=1, padding=3)), # 144,288 15 | 16 | nn.Sequential(NoNormConv(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 144,144 17 | nn.MaxPool2d(kernel_size=5,stride=1,padding=2)), 18 | 19 | nn.Sequential(NoNormConv(64, 128, kernel_size=3, stride=2, padding=1), # 72,72 20 | nn.MaxPool2d(kernel_size=3,stride=1,padding=1)), 21 | 22 | nn.Sequential(NoNormConv(128, 256, kernel_size=3, stride=2, padding=1), # 36,36 23 | nn.MaxPool2d(kernel_size=3,stride=1,padding=1)), 24 | 25 | nn.Sequential(NoNormConv(256, 512, kernel_size=3, stride=2, padding=1), # 18,18 26 | nn.MaxPool2d(kernel_size=3,stride=1,padding=1)), 27 | 28 | nn.Sequential(NoNormConv(512, 512, kernel_size=3, stride=2, padding=1), # 9,9 29 | nn.MaxPool2d(kernel_size=5,stride=1,padding=2)), 30 | 31 | nn.Sequential(NoNormConv(512, 512, kernel_size=3, stride=2, padding=0), # 4,4 32 | nn.MaxPool2d(kernel_size=3,stride=1,padding=1)), 33 | 34 | nn.Sequential(NoNormConv(512, 512, kernel_size=3, stride=1, padding=0), # 3,3 35 | nn.MaxPool2d(kernel_size=1,stride=1,padding=0)), 36 | 37 | nn.Sequential(NoNormConv(512, 512, kernel_size=2, stride=1, padding=0), # 1, 1 38 | nn.MaxPool2d(kernel_size=1,stride=1,padding=0))]) 39 | 40 | #self.binary_pred = nn.Sequential(BaseConv2D(512, 1, kernel_size=1, stride=1, padding=0,act='sigmoid')) 41 | self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), 42 | nn.Sigmoid()) 43 | self.label_noise = .0 44 | 45 | def get_lower_half(self, face_sequences): 46 | return face_sequences[:, :, face_sequences.size(2) // 2:] 47 | 48 | def to_2d(self, face_sequences): 49 | 50 | input_dim_size = len(face_sequences.size()) 51 | if input_dim_size > 4: 52 | face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0) 53 | return face_sequences 54 | 55 | def perceptual_forward(self, false_face_sequences): 56 | false_face_sequences = self.to_2d(false_face_sequences) 57 | false_face_sequences = 
self.get_lower_half(false_face_sequences) 58 | 59 | false_feats = false_face_sequences 60 | for f in self.face_encoder_blocks: 61 | false_feats = f(false_feats) 62 | 63 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 64 | y = torch.ones(len(false_feats), 1,dtype=torch.float).to(device) 65 | x = self.binary_pred(false_feats).view(false_feats.size(0),-1) 66 | try: 67 | false_pred_loss = F.binary_cross_entropy(x,y) 68 | except Exception as e: 69 | print('x value:{}'.format(x)) 70 | raise e 71 | 72 | return false_pred_loss 73 | 74 | def forward(self, face_sequences): 75 | face_sequences = self.to_2d(face_sequences) 76 | face_sequences = self.get_lower_half(face_sequences) 77 | 78 | x = face_sequences 79 | for f in self.face_encoder_blocks: 80 | x = f(x) 81 | x = self.binary_pred(x).view(x.size(0), -1) 82 | return x 83 | -------------------------------------------------------------------------------- /models/FaceCreator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from models.BaseTranspose import BaseTranspose 5 | from models.BaseConv2D import BaseConv2D 6 | 7 | 8 | class FaceCreator(nn.Module): 9 | 10 | def __init__(self): 11 | super(FaceCreator, self).__init__() 12 | 13 | self.face_encoder_block = nn.ModuleList([ 14 | nn.Sequential(BaseConv2D(6, 16, kernel_size=7, stride=1, padding=3, act='relu')), # 输入形状 [5,6,288 288] 15 | 16 | nn.Sequential(BaseConv2D(16, 32, kernel_size=5, stride=2, padding=2, act='relu'), # 144 144 17 | BaseConv2D(32, 32, kernel_size=5, stride=1, padding=2, residual=True, act='relu'), 18 | nn.MaxPool2d(kernel_size=5, stride=1, padding=2)), 19 | 20 | nn.Sequential(BaseConv2D(32, 64, kernel_size=3, stride=2, padding=1, act='relu'), # 72 72 21 | BaseConv2D(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act='relu'), 22 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1)), 23 | 24 | nn.Sequential(BaseConv2D(64, 128, kernel_size=3, stride=2, padding=1, act='relu'), # 转成 36 36 25 | BaseConv2D(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act='relu'), 26 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1)), 27 | 28 | nn.Sequential(BaseConv2D(128, 256, kernel_size=3, stride=2, padding=1, act='relu'), # 转成 18 18 29 | BaseConv2D(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act='relu'), 30 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1)), 31 | 32 | nn.Sequential(BaseConv2D(256, 512, kernel_size=3, stride=2, padding=1, act='relu'), 33 | BaseConv2D(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act='relu'), 34 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1)), # 9 9 35 | 36 | nn.Sequential(BaseConv2D(512, 512, kernel_size=3, stride=2, padding=0, act='relu'), 37 | nn.MaxPool2d(kernel_size=1, stride=1, padding=0)), # 4 4 38 | nn.Sequential(BaseConv2D(512, 512, kernel_size=3, stride=1, padding=0, act='relu'), 39 | nn.MaxPool2d(kernel_size=1, stride=1, padding=0)), # 2 2 40 | nn.Sequential(BaseConv2D(512, 512, kernel_size=2, stride=1, padding=0, act='relu'), 41 | nn.MaxPool2d(kernel_size=1, stride=1, padding=0)) # 1 1 42 | ]) 43 | 44 | self.audio_encoder = nn.Sequential( 45 | # [5,1,80,16] 46 | BaseConv2D(1, 32, kernel_size=3, stride=1, padding=1, act='relu'), 47 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 48 | 49 | BaseConv2D(32, 64, kernel_size=3, stride=(3, 1), padding=1, act='relu'), 50 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 51 | 52 | BaseConv2D(64, 128, kernel_size=3, stride=3, padding=1, 
act='relu'), 53 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 54 | 55 | BaseConv2D(128, 256, kernel_size=3, stride=(3, 2), padding=1, act='relu'), 56 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 57 | 58 | BaseConv2D(256, 512, kernel_size=3, stride=1, padding=0, act='relu'), 59 | nn.MaxPool2d(kernel_size=1, stride=1, padding=0), 60 | BaseConv2D(512, 512, kernel_size=1, stride=1, padding=0, act='relu'), 61 | nn.MaxPool2d(kernel_size=1, stride=1, padding=0) 62 | ) 63 | 64 | self.face_decoder_block = nn.ModuleList([ 65 | nn.Sequential(BaseConv2D(512, 512, kernel_size=1, stride=1, padding=0, act='relu')), # 1 1 66 | 67 | nn.Sequential(BaseTranspose(1024, 512, kernel_size=2, stride=1, padding=0), ), # 2 2 68 | nn.Sequential(BaseTranspose(1024, 512, kernel_size=3, stride=1, padding=0), ), # 4 4 69 | nn.Sequential(BaseTranspose(1024, 256, kernel_size=3, stride=2, padding=0), ), # 9 9 70 | 71 | nn.Sequential(BaseTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1), 72 | BaseConv2D(384, 384, kernel_size=3, stride=1, padding=1, residual=True, act='relu'), 73 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), ), # 18 18 74 | 75 | nn.Sequential(BaseTranspose(640, 256, kernel_size=3, stride=2, padding=1, output_padding=1), 76 | BaseConv2D(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act='relu'), 77 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), ), # 36 36 78 | 79 | nn.Sequential(BaseTranspose(384, 128, kernel_size=3, stride=2, padding=1, output_padding=1), 80 | BaseConv2D(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act='relu'), 81 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), ), # 72 72 82 | 83 | nn.Sequential(BaseTranspose(192, 64, kernel_size=3, stride=2, padding=1, output_padding=1), 84 | BaseConv2D(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act='relu'), 85 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), ), # 144 144 86 | 87 | nn.Sequential(BaseTranspose(96, 64, kernel_size=3, stride=2, padding=1, output_padding=1), 88 | BaseConv2D(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act='relu'), 89 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), ), # 288 288 90 | 91 | ]) 92 | 93 | self.output_block = nn.Sequential(BaseConv2D(80, 32, kernel_size=3, stride=1, padding=1, act='relu'), 94 | nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0), 95 | nn.Sigmoid(), 96 | nn.MaxPool2d(kernel_size=1, stride=1, padding=0)) 97 | 98 | def forward(self, audio_sequences, face_sequences): 99 | 100 | B = audio_sequences.size(0) 101 | 102 | input_dim_size = len(face_sequences.size()) 103 | if input_dim_size > 4: 104 | audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) 105 | face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0) 106 | 107 | audio_embedding = self.audio_encoder(audio_sequences) 108 | 109 | feats = [] 110 | x = face_sequences 111 | for f in self.face_encoder_block: 112 | x = f(x) 113 | feats.append(x) 114 | 115 | x = audio_embedding 116 | for f in self.face_decoder_block: 117 | x = f(x) 118 | try: 119 | x = torch.cat((x, feats[-1]), dim=1) 120 | except Exception as e: 121 | print('exception got: {}'.format(e)) 122 | print('audio size: {}'.format(x.size())) 123 | print('face size {}'.format(feats[-1].size())) 124 | raise e 125 | feats.pop() 126 | 127 | x = self.output_block(x) 128 | 129 | if input_dim_size > 4: 130 | x = torch.split(x, B, dim=0) 131 | outputs = torch.stack(x, dim=2) 132 | else: 133 | 
outputs = x 134 | 135 | return outputs 136 | -------------------------------------------------------------------------------- /models/NoNormConv.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class NoNormConv(nn.Module): 5 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 6 | super().__init__(*args, **kwargs) 7 | self.conv_block = nn.Sequential( 8 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 9 | ) 10 | self.act = nn.LeakyReLU(0.01, inplace=True) 11 | 12 | def forward(self, x): 13 | out = self.conv_block(x) 14 | return self.act(out) 15 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rogerle/wav2lipup/7149a5fb30d52f1af7773dccc991c42acf59ac77/models/README.md -------------------------------------------------------------------------------- /models/SyncNetModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from models.BaseConv2D import BaseConv2D 5 | 6 | 7 | class SyncNetModel(nn.Module): 8 | 9 | def __init__(self): 10 | super(SyncNetModel, self).__init__() 11 | 12 | self.face_encoder = nn.Sequential( 13 | BaseConv2D(15, 32, kernel_size=(7, 7), stride=1, padding=3), 14 | BaseConv2D(32, 32, kernel_size=5, stride=1, padding=1), 15 | BaseConv2D(32, 32, kernel_size=3, stride=1, padding=1), 16 | 17 | BaseConv2D(32, 64, kernel_size=5, stride=(1, 2), padding=1), #140 142 18 | BaseConv2D(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 19 | BaseConv2D(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 20 | 21 | BaseConv2D(64, 128, kernel_size=3, stride=2, padding=1), #7071 22 | BaseConv2D(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 23 | BaseConv2D(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 24 | BaseConv2D(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 25 | 26 | BaseConv2D(128, 256, kernel_size=3, stride=2, padding=1), #35,36 27 | BaseConv2D(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 28 | BaseConv2D(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 29 | 30 | BaseConv2D(256, 512, kernel_size=3, stride=2, padding=1), #18 18 31 | BaseConv2D(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 32 | BaseConv2D(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 33 | 34 | BaseConv2D(512, 512, kernel_size=3, stride=2, padding=1),#9 35 | BaseConv2D(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 36 | 37 | BaseConv2D(512, 512, kernel_size=3, stride=2, padding=0),#4 38 | BaseConv2D(512, 512, kernel_size=3, stride=1, padding=0),#2 39 | BaseConv2D(512, 512, kernel_size=2, stride=1, padding=0),#1 40 | BaseConv2D(512, 512, kernel_size=1, stride=1, padding=0), 41 | 42 | ) 43 | 44 | self.audio_encoder = nn.Sequential( 45 | BaseConv2D(1, 32, kernel_size=3, stride=1, padding=1), 46 | BaseConv2D(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 47 | BaseConv2D(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 48 | 49 | BaseConv2D(32, 64, kernel_size=3, stride=(3, 1), padding=1), 50 | BaseConv2D(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 51 | BaseConv2D(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 52 | 53 | BaseConv2D(64, 
128, kernel_size=3, stride=3, padding=1), 54 | BaseConv2D(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 55 | BaseConv2D(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 56 | 57 | BaseConv2D(128, 256, kernel_size=3, stride=(3, 2), padding=1), 58 | BaseConv2D(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 59 | BaseConv2D(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 60 | 61 | BaseConv2D(256, 512, kernel_size=3, stride=1, padding=0), 62 | BaseConv2D(512, 512, kernel_size=1, stride=1, padding=0) 63 | ) 64 | 65 | def forward(self,audio_sequences,face_sequences): 66 | face_embedding = self.face_encoder(face_sequences) 67 | audio_embedding = self.audio_encoder(audio_sequences) 68 | 69 | audio_embedding = audio_embedding.view(audio_embedding.size(0),-1) 70 | face_embedding = face_embedding.view(face_embedding.size(0),-1) 71 | 72 | audio_embedding = F.normalize(audio_embedding,p=2,dim=1) 73 | face_embedding = F.normalize(face_embedding,p=2,dim=1) 74 | 75 | return audio_embedding,face_embedding 76 | -------------------------------------------------------------------------------- /models/SyncNetModel_P.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | def save(model, filename): 8 | with open(filename, "wb") as f: 9 | torch.save(model, f); 10 | print("%s saved."%filename); 11 | 12 | def load(filename): 13 | net = torch.load(filename) 14 | return net; 15 | 16 | class S(nn.Module): 17 | def __init__(self, num_layers_in_fc_layers = 1024): 18 | super(S, self).__init__(); 19 | 20 | self.__nFeatures__ = 24; 21 | self.__nChs__ = 32; 22 | self.__midChs__ = 32; 23 | 24 | self.netcnnaud = nn.Sequential( 25 | nn.Conv2d(1, 64, kernel_size=(3,3), stride=(1,1), padding=(1,1)), 26 | nn.BatchNorm2d(64), 27 | nn.ReLU(inplace=True), 28 | nn.MaxPool2d(kernel_size=(1,1), stride=(1,1)), 29 | 30 | nn.Conv2d(64, 192, kernel_size=(3,3), stride=(1,1), padding=(1,1)), 31 | nn.BatchNorm2d(192), 32 | nn.ReLU(inplace=True), 33 | nn.MaxPool2d(kernel_size=(3,3), stride=(1,2)), 34 | 35 | nn.Conv2d(192, 384, kernel_size=(3,3), padding=(1,1)), 36 | nn.BatchNorm2d(384), 37 | nn.ReLU(inplace=True), 38 | 39 | nn.Conv2d(384, 256, kernel_size=(3,3), padding=(1,1)), 40 | nn.BatchNorm2d(256), 41 | nn.ReLU(inplace=True), 42 | 43 | nn.Conv2d(256, 256, kernel_size=(3,3), padding=(1,1)), 44 | nn.BatchNorm2d(256), 45 | nn.ReLU(inplace=True), 46 | nn.MaxPool2d(kernel_size=(3,3), stride=(2,2)), 47 | 48 | nn.Conv2d(256, 512, kernel_size=(5,4), padding=(0,0)), 49 | nn.BatchNorm2d(512), 50 | nn.ReLU(), 51 | ); 52 | 53 | self.netfcaud = nn.Sequential( 54 | nn.Linear(512, 512), 55 | nn.BatchNorm1d(512), 56 | nn.ReLU(), 57 | nn.Linear(512, num_layers_in_fc_layers), 58 | ); 59 | 60 | self.netfclip = nn.Sequential( 61 | nn.Linear(512, 512), 62 | nn.BatchNorm1d(512), 63 | nn.ReLU(), 64 | nn.Linear(512, num_layers_in_fc_layers), 65 | ); 66 | 67 | self.netcnnlip = nn.Sequential( 68 | nn.Conv3d(3, 96, kernel_size=(5,7,7), stride=(1,2,2), padding=0), 69 | nn.BatchNorm3d(96), 70 | nn.ReLU(inplace=True), 71 | nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2)), 72 | 73 | nn.Conv3d(96, 256, kernel_size=(1,5,5), stride=(1,2,2), padding=(0,1,1)), 74 | nn.BatchNorm3d(256), 75 | nn.ReLU(inplace=True), 76 | nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)), 77 | 78 | nn.Conv3d(256, 256, kernel_size=(1,3,3), padding=(0,1,1)), 79 | nn.BatchNorm3d(256), 80 | 
nn.ReLU(inplace=True), 81 | 82 | nn.Conv3d(256, 256, kernel_size=(1,3,3), padding=(0,1,1)), 83 | nn.BatchNorm3d(256), 84 | nn.ReLU(inplace=True), 85 | 86 | nn.Conv3d(256, 256, kernel_size=(1,3,3), padding=(0,1,1)), 87 | nn.BatchNorm3d(256), 88 | nn.ReLU(inplace=True), 89 | nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2)), 90 | 91 | nn.Conv3d(256, 512, kernel_size=(1,6,6), padding=0), 92 | nn.BatchNorm3d(512), 93 | nn.ReLU(inplace=True), 94 | ); 95 | 96 | def forward_aud(self, x): 97 | 98 | mid = self.netcnnaud(x); # N x ch x 24 x M 99 | mid = mid.view((mid.size()[0], -1)); # N x (ch x 24) 100 | out = self.netfcaud(mid); 101 | 102 | return out; 103 | 104 | def forward_lip(self, x): 105 | 106 | mid = self.netcnnlip(x); 107 | mid = mid.view((mid.size()[0], -1)); # N x (ch x 24) 108 | out = self.netfclip(mid); 109 | 110 | return out; 111 | 112 | def forward_lipfeat(self, x): 113 | 114 | mid = self.netcnnlip(x); 115 | out = mid.view((mid.size()[0], -1)); # N x (ch x 24) 116 | 117 | return out; -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rogerle/wav2lipup/7149a5fb30d52f1af7773dccc991c42acf59ac77/models/__init__.py -------------------------------------------------------------------------------- /process_util/DataProcessor.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import cv2 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torchaudio 8 | from moviepy.editor import * 9 | from tqdm import tqdm 10 | 11 | from process_util.FaceDetector import FaceDetector 12 | import logging 13 | 14 | 15 | 16 | class DataProcessor(): 17 | logging.basicConfig(level=logging.ERROR) 18 | face_detector = FaceDetector() 19 | ''' 20 | 视频文件处理,把视频文件分解成脸部图片。并分离出音频 21 | ''' 22 | 23 | def processVideoFile(self, vfile, **kwargs): 24 | vCap = cv2.VideoCapture() 25 | 26 | ok = vCap.open(vfile) 27 | 28 | frames = [] 29 | while ok: 30 | success, frame = vCap.read() 31 | if not success: 32 | vCap.release() 33 | break 34 | frames.append(frame) 35 | # 创建解开的目录 36 | split_dir = self.__get_split_path(vfile, kwargs['processed_data_root']) 37 | self.__extract_face_img(frames, split_dir) 38 | self.__extract_audio(vfile, split_dir) 39 | 40 | """ 41 | 提取音频文件到数据处理文件夹 42 | """ 43 | 44 | def __extract_audio(self, vfile, split_dir): 45 | audio_clip = AudioFileClip(str(vfile)) 46 | audiofile = split_dir + '/audio.wav' 47 | audio_clip.write_audiofile(audiofile, logger=None) 48 | 49 | audio_meta_f = split_dir + '/audio_meta.info' 50 | 51 | wavform, sr = torchaudio.load(audiofile) 52 | resample = torchaudio.transforms.Resample(sr, 16000) 53 | wavform = resample(wavform) 54 | torchaudio.save(audiofile, wavform, sample_rate=16000) 55 | 56 | audio_meta = torchaudio.info(audiofile) 57 | with open(audio_meta_f, 'w') as f: 58 | f.write(str(audio_meta)) 59 | 60 | """ 61 | 提取人脸图片放入数据处理文件 62 | """ 63 | 64 | def __extract_face_img(self, frames, split_dir): 65 | prog_bar = tqdm(enumerate(frames), total=len(frames), leave=False) 66 | # faces={} 67 | # face_file = split_dir+'/faces.pkl' 68 | for j, frame in prog_bar: 69 | j = j + 1 70 | face_result = self.face_detector.faceDetec(frame) 71 | scores = face_result['scores'] 72 | boxes = face_result['boxes'] 73 | if scores is None or len(scores) == 0: 74 | print('bad face video,drop it!') 75 | break 76 | else: 77 | idx = scores.index(max(scores)) 78 | box = 
boxes[idx] 79 | x1, y1, x2, y2 = box 80 | file_name = split_dir + '/{}.jpg'.format(j) 81 | face = frame[max(0,int(y1)):min(int(y2),frame.shape[0]), max(0,int(x1)):min(int(x2),frame.shape[1])] 82 | if np.size(face) == 0: 83 | continue 84 | cv2.imwrite(file_name, face) 85 | prog_bar.set_description('Extract Face Image:{}/{}.jpg'.format(split_dir, j)) 86 | return split_dir 87 | 88 | # 写入脸部文件“faces.bin",注意的是这个里面保存的是dict文件 89 | # with open(face_file,'wb') as f: 90 | # pickle.dump(faces,f) 91 | 92 | def __get_split_path(self, vfile, processed_data_root): 93 | vf = Path(vfile) 94 | fdir = vf.parts[-3] 95 | fbase = vf.parts[-2] + '_' + vf.stem 96 | fulldir = processed_data_root + '/' + str(fdir) + '/' + fbase 97 | Path(fulldir).mkdir(parents=True, exist_ok=True) 98 | 99 | return fulldir 100 | -------------------------------------------------------------------------------- /process_util/FaceDetector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from modelscope.pipelines import pipeline 3 | from modelscope.utils.constant import Tasks 4 | import logging 5 | 6 | 7 | class FaceDetector(): 8 | logging.basicConfig(level=logging.ERROR) 9 | # 使用modelscope的mogofacedetector模型进行人脸检测 10 | model_id = 'damo/cv_resnet50_face-detection_retinaface' 11 | #model_id = 'damo/cv_manual_face-detection_tinymog' 12 | 13 | # 初始化人脸热别模型,利用这个模型来识别人脸 14 | def __init__(self): 15 | # 初始化模型,人脸检测可以做 16 | self.mog_face_detection_func = pipeline(Tasks.face_detection, self.model_id) 17 | 18 | """ 19 | 在图片上进行人脸识别,回传人脸识别的原始数据,包含人脸的位置以及五官位置信息 20 | 其中src_img_path是图片的物理路径 21 | """ 22 | 23 | def faceDetec(self, *args): 24 | raw_result = self.mog_face_detection_func(args[0]) 25 | return raw_result 26 | 27 | def faceBatchDetection(self, frames): 28 | frames = frames.copy() 29 | faces = [] 30 | for frame in frames: 31 | face_result = self.mog_face_detection_func(frame) 32 | scores = face_result['scores'] 33 | boxes = face_result['boxes'] 34 | if scores is None or len(scores) == 0: 35 | print('No face detected') 36 | faces.append([-1,-1,-1,-1]) 37 | else: 38 | idx = scores.index(max(scores)) 39 | box = boxes[idx] 40 | x1, y1, x2, y2 = box 41 | coords= [int(x1),int(y1),int(x2),int(y2)] 42 | faces.append(coords) 43 | return faces 44 | 45 | -------------------------------------------------------------------------------- /process_util/ParamsUtil.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | class ParamsUtil(): 4 | def __init__(self,config_path:str = '../configs'): 5 | self.config_path = config_path 6 | cf = self.config_path+'/'+'train_config.yaml' 7 | with open(cf,'r')as f: 8 | datas = yaml.load(f,Loader=yaml.FullLoader) 9 | self.data = datas['train_config'] 10 | 11 | def __getattr__(self, key): 12 | if key not in self.data: 13 | raise AttributeError('Param {0} not defined in {1}'.format(key,self.config_file)) 14 | return self.data[key] 15 | 16 | def set_param(self,key,value): 17 | self.data[key]=value 18 | 19 | -------------------------------------------------------------------------------- /process_util/PreProcessor.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | 4 | from modelscope.pipelines import pipeline 5 | from modelscope.utils.constant import Tasks 6 | from moviepy.editor import * 7 | 8 | from tqdm import tqdm 9 | import logging 10 | 11 | 12 | ''' 13 | 预处理各类文件和数据,用于训练,包括初始视频的处理切割,还有训练前数据集的预处理等。 14 | 
videoProcess()处理视频,audioProcess()处理音频。 15 | 音频处理以2s的静默作为分割,把大视频切小,保证每个视频为一句完整的语句。视频大小应该不大于5s,如果大于5s应该继续处理。 16 | ''' 17 | 18 | 19 | class PreProcessor(): 20 | logging.basicConfig(level=logging.ERROR) 21 | model_id = 'damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' 22 | 23 | ''' 24 | 利用ASR对音频文件进行时间戳分离,并写入相同文件名的json文件,然后根据时间戳对视频进行分割处理 25 | ''' 26 | 27 | def videosPreProcessByASR(self,video,**kwargs): 28 | input_dir = kwargs.get('input_dir') 29 | output_dir = kwargs.get('output_dir') 30 | asr_func = pipeline(task=Tasks.auto_speech_recognition, 31 | model=self.model_id) 32 | # 看是否已经有时间戳,没有的话就做时间戳文件 33 | parent_path = Path(video).parent 34 | file_name = Path(video).stem 35 | jsonfile = Path.joinpath(parent_path, file_name).with_suffix('.json') 36 | print('generate the viedo asr file:{}'.format(jsonfile)) 37 | self.__genTimeStampByASR(video=video, 38 | asr=asr_func) 39 | with open(jsonfile, 'r') as f: 40 | dicts = json.load(f) 41 | timestamps = dicts['timestamps'] 42 | start = 0 43 | videoC = VideoFileClip(str(video)) 44 | movieEnd = int(videoC.duration) 45 | outputD = self.__genOutputDir(input_dir, 46 | output_dir, 47 | video) 48 | for time in timestamps: 49 | tmpEnd = time['end'] 50 | if time['start'] != start: 51 | start = time['start'] 52 | if (tmpEnd - start) < 1000: 53 | continue 54 | else: 55 | startTime = round(start / 1000) 56 | endTime = round(tmpEnd / 1000) 57 | if endTime > movieEnd: 58 | endTime = movieEnd 59 | if endTime > startTime: 60 | self.__genClipVideo(videoC, startTime, endTime, outputD) 61 | else: 62 | continue 63 | start = tmpEnd 64 | return outputD 65 | 66 | ''' 67 | 把视频按时间戳文件进行切割 68 | ''' 69 | 70 | def videosPreProcessByTime(self, video,**kwargs): 71 | S_TIME = kwargs.get('s_time') 72 | input_dir = kwargs.get('input_dir') 73 | output_dir = kwargs.get('output_dir') 74 | outputD = self.__genOutputDir(input_dir, output_dir, video) 75 | i = 0 76 | videoC = VideoFileClip(str(video)) 77 | movieEnd = int(videoC.duration) 78 | 79 | # 按秒数来分割视频,最后一段到结束 80 | while i < movieEnd: 81 | startTime = i 82 | endTime = i + S_TIME 83 | if endTime > movieEnd: 84 | endTime = movieEnd 85 | self.__genClipVideo(videoC, startTime, endTime, outputD) 86 | i = i + S_TIME 87 | 88 | return outputD 89 | ''' 90 | 切割视频文件写入到指定目录 91 | ''' 92 | 93 | def __genClipVideo(self, videoClip, startTime, endTime, outputD): 94 | outputName = '{0}/{1:05}_{2:05}.mp4'.format(outputD, 95 | startTime, 96 | endTime) 97 | clipVideo = videoClip.subclip(startTime, endTime) 98 | clipVideo.write_videofile(outputName,fps=25,logger=None) 99 | 100 | ''' 101 | 处理文件后的输出目录生成并返回目录名称 102 | ''' 103 | 104 | def __genOutputDir(self, input_dir, output_dir, file): 105 | iparts = Path(input_dir).parts 106 | fparts = Path(file).parent.parts 107 | outparts = [] 108 | # 分割文件路径,获取上层需要创建的路径名成 109 | for fp in fparts: 110 | if fp not in iparts: 111 | outparts.append(fp) 112 | # 构建输出路径,并创建不存在的输出目录 113 | op = output_dir 114 | if len(outparts) > 0: 115 | for o in outparts: 116 | op = op + '/{}'.format(o) 117 | Path(op).mkdir(exist_ok=True) 118 | lastPath = Path(file).name.split('.')[0] 119 | op = op + '/{}'.format(lastPath) 120 | Path(op).mkdir(exist_ok=True) 121 | dir = str(op) 122 | return dir 123 | 124 | ''' 125 | 获取所有处理文件,并返回文件列表,type是文件的扩展名,也是文件的类型,内部私有方法 126 | ''' 127 | 128 | 129 | 130 | ''' 131 | 把视频文件的音频剥离出来 132 | ''' 133 | 134 | def __genTimeStampByASR(self, **kwargs): 135 | video = kwargs.get('video') 136 | asr_func = kwargs.get('asr') 137 | 138 | wname = Path(video).name.replace('.mp4', '.wav') 139 | 
import tempfile; temp_dir = tempfile.gettempdir()  # os.environ['TEMP'] only exists on Windows, so use the platform temp dir
140 | print('wav file written to temp dir: {}'.format(temp_dir))
141 | wavfile = temp_dir + '/' + wname
142 | audio_clip = AudioFileClip(str(video))
143 | audio_clip.write_audiofile(wavfile, logger=None)
144 |
145 | rec_result = asr_func(audio_in=wavfile)
146 |
147 | # collect the sentence-level timestamps
148 | sentences = rec_result.get('sentences')
149 | timest = []
150 | for items in sentences:
151 | if items['text'] is not None and items['text'].strip() != '':
152 | timest.append({'start': items['start'], 'text': items['text'], 'end': items['end']})
153 | fname = Path(video).name.replace('.mp4', '.json')
154 | path = Path(video).parent
155 | tf = str(path) + '/' + fname
156 | video_times = {'timestamps': timest}
157 | with open(tf, 'w') as f:
158 | f.write(json.dumps(video_times))
159 | return video_times
160 |
-------------------------------------------------------------------------------- /process_util/SyncnetPScore.py: --------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import numpy as np
3 | from moviepy.editor import *
4 |
5 | from process_util.FaceDetector import FaceDetector
6 |
7 |
8 | class SyncnetPScore():
9 | face_detector = FaceDetector()
10 | def __init__(self, pre_model, batch_size, tmp):
11 | self.pre_model = pre_model
12 | self.batch_size = batch_size
13 | self.tmp_dir = tmp
14 | def scoreVideo(self, v_file):
15 | video_file = v_file
16 | self.__extract_video(video_file)
17 |
18 | def __extract_video(self, video_file):
19 | videoC = VideoFileClip(video_file)
20 | v_end = int(videoC.duration)
21 | outputName = self.tmp_dir + '/' + Path(video_file).stem + '.mp4'
22 | clipVideo = videoC.subclip(0, v_end)
23 | clipVideo.write_videofile(outputName, fps=25)
24 |
25 | new_videoC = VideoFileClip(outputName)
26 | face_list = {}
27 | for idx, frame in enumerate(new_videoC.iter_frames()):
28 | idx = idx + 1
29 | face_result = self.face_detector.faceDetec(frame)
30 | scores = face_result['scores']
31 | boxes = face_result['boxes']
32 | if scores is None or len(scores) == 0:
33 | print('bad face video, drop it!')
34 | continue
35 | else:
36 | best = scores.index(max(scores))  # keep a separate name so the frame index idx is not overwritten
37 | box = boxes[best]
38 | x1, y1, x2, y2 = box
39 | face = frame[int(y1):int(y2), int(x1):int(x2)]
40 | if np.size(face) == 0:
41 | continue
42 | face_list['{}'.format(idx)] = face
43 |
44 |
45 |
46 |
47 |
-------------------------------------------------------------------------------- /process_util/SyncnetScore.py: --------------------------------------------------------------------------------
1 | import random
2 | from pathlib import Path
3 |
4 | import cv2
5 | import numpy
6 | import numpy as np
7 | import torchaudio
8 | import torch
9 | from torch import nn
10 | from torch.nn import functional as F2
11 | import torchaudio.functional as F
12 | from tqdm import tqdm
13 |
14 | from models.SyncNetModel import SyncNetModel
15 |
16 | """
17 | Use an already trained SyncNet to score the data against a threshold; clips that fall outside it are dropped.
18 | """
19 |
20 |
21 | class SyncnetScore():
22 | def __load_checkpoint(self, checkpoint_pth, model):
23 | if torch.cuda.is_available():
24 | checkpoint = torch.load(checkpoint_pth)
25 | else:
26 | checkpoint = torch.load(checkpoint_pth, map_location=lambda storage, loc: storage)
27 |
28 | s = checkpoint["state_dict"]
29 | new_s = {}
30 | for k, v in s.items():
31 | new_s[k.replace('module.', '')] = v
32 | model.load_state_dict(new_s)
33 |
34 | return model
35 | def score_video(self, v_file, **kwargs):
36 | v_dir = kwargs['data_root'] + '/' + v_file
37 | checkpoint = kwargs['checkpoint']
38 | batch_size = kwargs['batch_size']
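# (annotation) The scoring path below mirrors the SyncNet training setup:
# build the model on GPU when one is available, freeze every parameter
# (inference only), restore the trained checkpoint, and then derive an
# (offset, confidence) pair for the clip from the cosine similarity between
# the audio and lip embeddings.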
39 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 40 | model = SyncNetModel().to(device) 41 | for p in model.parameters(): 42 | p.requires_grad = False 43 | model = self.__load_checkpoint(checkpoint, model) 44 | score, conf = self.__score(v_dir, model, batch_size) 45 | return v_file, score, conf 46 | 47 | def __score(self, v_dir, model, batch_size): 48 | files = [] 49 | wavfile = v_dir + '/audio.wav' 50 | for file in Path.glob(Path(v_dir), '**/*.jpg'): 51 | if file.is_file(): 52 | img = file.stem 53 | files.append(int(img)) 54 | files.sort(key=int) 55 | model.eval() 56 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 57 | 58 | #把图形文件名按batchsize做成batch 59 | last_fname = len(files) - 5 60 | original_mel = self.__get_mel(wavfile) 61 | lip_feats = [] 62 | aud_feats = [] 63 | 64 | for i in range(0, last_fname, batch_size): 65 | lip_batch = [] 66 | aud_batch = [] 67 | for fname in range(i + 1, min(last_fname, i + batch_size)): 68 | lip_win = self.__get_lipwin(v_dir, fname) 69 | aud_win = self.__get_aud_windows(original_mel, fname) 70 | lip_batch.append(lip_win) 71 | aud_batch.append(aud_win) 72 | lip_wins=torch.cat(lip_batch,0) 73 | aud_wins=torch.cat(aud_batch,0) 74 | x = lip_wins.to(device) 75 | mel = aud_wins.to(device) 76 | a, v = model(mel, x) 77 | lip_feats.append(v.cpu()) 78 | aud_feats.append(a.cpu()) 79 | 80 | if len(lip_feats) != len(aud_feats): 81 | return 15,15. 82 | lip_feat = torch.cat(lip_feats,0) 83 | aud_feat = torch.cat(aud_feats,0) 84 | a_pad = F2.pad(aud_feat, (0, 0, 15, 15)) 85 | dists = [] 86 | for i in range(0, len(lip_feat)): 87 | s_l = lip_feat[[i], :].repeat(31, 1) 88 | s_a = a_pad[i:i + 31, :] 89 | d = F2.cosine_similarity(s_a, s_l) 90 | dists.append(d) 91 | mdist = torch.mean(torch.stack(dists, 1), 1) 92 | maxval, maxidx = torch.max(mdist, 0) 93 | 94 | offset = 15 - maxidx.item() 95 | conf = maxval - torch.median(mdist).item() 96 | 97 | return offset, conf 98 | 99 | def __crop_audio_window(self, spec, start_frame): 100 | start_frame_num = start_frame 101 | 102 | start_idx = int(80. * (start_frame_num / 25.)) 103 | 104 | end_idx = start_idx + 16 105 | 106 | spec = spec[start_idx:end_idx, :] 107 | 108 | return spec 109 | 110 | def __get_mel(self, wavfile): 111 | try: 112 | wavform, sf = torchaudio.load(wavfile) 113 | 114 | wavform = F.preemphasis(wavform, 0.97) 115 | specgram = torchaudio.transforms.MelSpectrogram(sample_rate=16000, 116 | n_fft=800, 117 | power=1., 118 | hop_length=200, 119 | win_length=800, 120 | f_min=55, 121 | f_max=7600, 122 | n_mels=80, 123 | normalized=True) 124 | orig_mel = specgram(wavform) 125 | orig_mel = F.amplitude_to_DB(orig_mel, multiplier=10., amin=-100, 126 | db_multiplier=20, top_db=100) 127 | orig_mel = torch.mean(orig_mel, dim=0) 128 | orig_mel = orig_mel.t().numpy() 129 | except Exception as e: 130 | print("mel error:".format(e)) 131 | return None 132 | return orig_mel 133 | 134 | def __get_lipwin(self, path, fname): 135 | start_id = fname 136 | seek_id = fname+5 137 | window =[] 138 | for fidx in range(start_id, seek_id): 139 | img_name = path + '/' + '{}.jpg'.format(fidx) 140 | 141 | try: 142 | img_f = cv2.imread(img_name) 143 | img_f = cv2.resize(img_f, (288, 288)) 144 | except Exception as e: 145 | print('image resize error:{}'.format(e)) 146 | img_f = np.zeros((288, 288, 3)) 147 | window.append(img_f) 148 | 149 | x = np.concatenate(window, axis=2) / 255. 
150 | x = x.transpose(2, 0, 1) 151 | x = x[:, x.shape[1] // 2:] 152 | x = torch.tensor(x, dtype=torch.float).unsqueeze(0) 153 | 154 | 155 | return x 156 | 157 | def __get_aud_windows(self, original_mel, fname): 158 | mel = self.__crop_audio_window(original_mel.copy(), fname) 159 | if mel.shape[0] != 16: 160 | return torch.zeros(1,1,80,16) 161 | mel = torch.tensor(np.transpose(mel, (1, 0)), dtype=torch.float).unsqueeze(0).unsqueeze(0) 162 | 163 | return mel 164 | -------------------------------------------------------------------------------- /process_util/VideoExtractor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import shutil 4 | import subprocess 5 | import time 6 | from pathlib import Path 7 | 8 | import cv2 9 | from tqdm import tqdm 10 | 11 | from process_util.FaceDetector import FaceDetector 12 | 13 | 14 | class VideoExtractor(): 15 | face_detector = FaceDetector() 16 | 17 | def __init__(self, tmp_dir, fps): 18 | self.tmp_dir = tmp_dir 19 | self.fps = fps 20 | 21 | def __extract_video(self, video_f): 22 | v_f = Path(video_f) 23 | v_path = v_f.parts[-2] 24 | v_name = v_f.stem 25 | tmp_video_name = v_path + '_' + v_name + '.avi' 26 | output_dir = self.tmp_dir + '/' + v_path + '_' + v_name 27 | extract_dir = self.tmp_dir + '/' + v_path + '_' + v_name + '/frames' 28 | Path(extract_dir).mkdir(exist_ok=True, parents=True) 29 | print('start extract video file {}'.format(v_f.as_posix())) 30 | s = time.time() 31 | # convert video to 25fps 32 | ffmpeg_cmd = "ffmpeg -loglevel error -y -i {0} -qscale:v 2 -async 1 -r {1} {2}/{3}".format(v_f, self.fps, 33 | output_dir, 34 | tmp_video_name) 35 | output = subprocess.call(ffmpeg_cmd, shell=True, stdout=None) 36 | # extract video to jpg image 37 | ffmpeg_cmd = 'ffmpeg -loglevel error -y -i {0} -qscale:v 2 -threads 6 -f image2 {1}/{2}.jpg'.format( 38 | output_dir + '/' + tmp_video_name, extract_dir, '%d') 39 | output = subprocess.call(ffmpeg_cmd, shell=True, stdout=None) 40 | # extract audio file 41 | ffmpeg_cmd = 'ffmpeg -loglevel error -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}/{}'.format( 42 | output_dir + '/' + tmp_video_name, extract_dir, 'audio.wav') 43 | output = subprocess.call(ffmpeg_cmd, shell=True, stdout=None) 44 | os.unlink(output_dir+'/'+tmp_video_name) 45 | st = time.time() 46 | print('process the video file {} cost {:.2f}s'.format(tmp_video_name, st - s)) 47 | return output_dir 48 | 49 | def pipline_video(self, video_f, **kwargs): 50 | data_root = kwargs['data_root'] 51 | video_file = data_root + '/' + video_f 52 | 53 | video_path = self.__extract_video(video_file) 54 | faces = self.__face_crop(video_path) 55 | shutil.rmtree(video_path + '/frames') 56 | 57 | def __face_crop(self, video_path): 58 | frames = [] 59 | 60 | for frame in Path(video_path).glob('**/*.jpg'): 61 | if frame.is_file(): 62 | frames.append(int(frame.stem)) 63 | frames.sort(key=int) 64 | bad_f_l = len(frames) % 25 65 | frames = frames[:-bad_f_l] 66 | sc = 1 67 | j = 0 68 | face_files = [] 69 | start_frame = frames[0] 70 | end_frame = 0 71 | face_flag = 0 72 | probar = tqdm(enumerate(frames), total=len(frames), leave=False) 73 | for idx, frame in probar: 74 | j += 1 75 | frame = video_path + '/frames/' + '{}.jpg'.format(idx + 1) 76 | img = cv2.imread(frame) 77 | y_max = img.shape[0] 78 | x_max = img.shape[1] 79 | face_result = self.face_detector.faceDetec(img) 80 | scores = face_result['scores'] 81 | boxes = face_result['boxes'] 82 | if scores is None or len(scores) == 0: 83 | 
if start_frame < end_frame:
84 | aud_file = self.__write_aud_file(sc, video_path, start_frame, end_frame)
85 | face_files.append('sc_{}'.format(sc))
86 | start_frame = end_frame
87 | if face_flag == 0:
88 | sc += 1
89 | face_flag = 1
90 | j = 0
91 | continue
92 | else:
93 | face_flag = 0
94 | end_frame = idx + 1
95 | idx_s = scores.index(max(scores))
96 | box = boxes[idx_s]
97 | x1, y1, x2, y2 = box
98 | face = img[max(int(y1)-110,0):min(int(y2)+110,y_max),max(int(x1)-110,0):min(int(x2)+110,x_max)]
99 | face_path = video_path + '/' + 'sc_{}'.format(sc)
100 | Path(face_path).mkdir(exist_ok=True,parents=True)
101 | cv2.imwrite(face_path+'/{}.jpg'.format(j),face)
102 | aud_file = self.__write_aud_file(sc, video_path, start_frame, end_frame)
103 | face_files.append('sc_{}'.format(sc))
104 | return face_files
105 |
106 | def __write_aud_file(self, sc, vid_path, start_frame, end_frame):
107 |
108 | aud_start = int(start_frame) / 25
109 | aud_end = int(end_frame) / 25
110 | Path(vid_path).mkdir(exist_ok=True, parents=True)
111 | aud_file = vid_path + '/' + 'sc_{}/audio_sc{}'.format(sc,sc) + '.wav'
112 | ffmpeg_cmd = 'ffmpeg -loglevel error -y -i {} -ss {:.3f} -to {:.3f} {}'.format(vid_path + '/frames/audio.wav',
113 | aud_start, aud_end, aud_file)
114 | output = subprocess.call(ffmpeg_cmd, shell=True, stdout=None)
115 | return aud_file
116 |
-------------------------------------------------------------------------------- /process_util/__init__.py: --------------------------------------------------------------------------------
1 | import logging
2 | logging.basicConfig(level=logging.ERROR)
-------------------------------------------------------------------------------- /requirement.txt: --------------------------------------------------------------------------------
1 | torch
2 | torchaudio
3 | torchvision
4 | pydub
5 | # additional packages imported across the repo
6 | numpy
7 | opencv-python
8 | moviepy
9 | modelscope
10 | tqdm
11 | PyYAML
12 | visualdl
13 | torchinfo
14 | librosa
15 | matplotlib
16 | scipy
-------------------------------------------------------------------------------- /testcase/dataProcessTest.py: --------------------------------------------------------------------------------
1 | import unittest
2 | from collections import Counter
3 | from pathlib import Path
4 |
5 | import cv2
6 | from tqdm import tqdm
7 |
8 | from process_util.DataProcessor import DataProcessor
9 |
10 | class DataProcessTest(unittest.TestCase):
11 | __input_dir__='../data/test_data/outputT'
12 | dataProcessor = DataProcessor()
13 | def testOpenCV(self):
14 | dp = self.dataProcessor
15 | files = []
16 | for file in Path.glob(Path(self.__input_dir__), '**/*.mp4'):
17 | if file.is_file():
18 | files.append(file.as_posix())
19 | files.sort()
20 | for video in tqdm(files):
21 | dp.processVideoFile(video,
22 | processed_data_root='../data/test_data/pr_data')
23 |
24 |
25 |
26 |
27 | def testffmpeg(self):
28 | pass
29 |
30 |
31 |
32 |
33 |
-------------------------------------------------------------------------------- /testcase/discriminatorTest.py: --------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | from torchinfo import summary
4 | from models.Discriminator import Discriminator
5 |
6 |
7 | class TestDiscriminator(unittest.TestCase):
8 | def test_modelNet(self):
9 | disc = Discriminator()
10 | faces = torch.randn(1,3,5,288,288)
11 | summary(disc, input_data=(faces))
12 |
-------------------------------------------------------------------------------- /testcase/faceEncodeTest.py: --------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import cv2
4 | import numpy as np
5 | import
torch 6 | from torch import nn 7 | from torch.utils.data import DataLoader 8 | from torchinfo import summary 9 | 10 | from models.BaseConv2D import BaseConv2D 11 | import matplotlib.pyplot as plt 12 | 13 | from models.Discriminator import Discriminator 14 | from models.FaceCreator import FaceCreator 15 | from wldatasets.FaceDataset import FaceDataset 16 | 17 | 18 | class FaceEncode(unittest.TestCase): 19 | 20 | def testfaceEncode(self): 21 | fe = FaceCreator() 22 | # random_data = torch.randn([1,3,96,96]) 23 | # output = fe.forward(random_data) 24 | audios=torch.randn(5,5,1,80,16) 25 | faces =torch.randn(5,6,5,288,288) 26 | summary(fe,input_data=(audios,faces)) 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /testcase/faceTest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import cv2 4 | from PIL import Image 5 | 6 | from process_util.FaceDetector import FaceDetector 7 | 8 | 9 | class TestFaceDetector(unittest.TestCase): 10 | faceDetector = FaceDetector() 11 | 12 | def test_imgPath(self): 13 | imgPath = '../data/test_data/test.jpg' 14 | 15 | result = self.faceDetector.faceDetec(imgPath) 16 | value = {'scores': [0.9578035473823547, 0.957781195640564, 0.9503945708274841, 0.9378407001495361, 17 | 0.8208725452423096], 18 | 'boxes': [[490.98687744140625, 81.58421325683594, 559.7579345703125, 179.40164184570312], 19 | [205.03994750976562, 62.21476364135742, 273.06390380859375, 151.08944702148438], 20 | [803.9385375976562, 157.96820068359375, 865.4893188476562, 235.7114715576172], 21 | [611.2905883789062, 172.23983764648438, 661.6884155273438, 234.4497833251953], 22 | [329.2081604003906, 179.24513244628906, 380.1270751953125, 244.6616973876953]], 23 | 'keypoints': None} 24 | self.assertEqual(result, value, 'testfiles') 25 | 26 | def test_img(self): 27 | img = cv2.imread('../data/test_data/input/test.jpg') 28 | result = self.faceDetector.faceBatchDetection([img]) 29 | face,coords = result[0] 30 | print(coords) 31 | 32 | 33 | if __name__ == '__main__': 34 | unittest.main() 35 | -------------------------------------------------------------------------------- /testcase/melspecTest.py: -------------------------------------------------------------------------------- 1 | import random 2 | import unittest 3 | 4 | import librosa 5 | import torchaudio.functional as F 6 | import torchaudio 7 | import torchaudio.transforms as T 8 | import torch 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from scipy import signal 12 | 13 | 14 | class MelSpecTest(unittest.TestCase): 15 | def testMelShow(self): 16 | wavfile = "../data/test_data/pr_data/000001/000001_00058_00063/audio.wav" 17 | wavform, sf = torchaudio.load(wavfile) 18 | spec = T.Spectrogram(n_fft=800,hop_length=200,win_length=800,power=1.0)(wavform) 19 | spec = T.AmplitudeToDB(stype='magnitude', 20 | top_db=80.)(spec) 21 | true=0 22 | false1=0 23 | for i in range(1,1000): 24 | if random.choice([True, False]): 25 | true +=1 26 | else: 27 | false1+=1 28 | 29 | print('chosie true:{} flas:{}'.format(true,false1)) 30 | 31 | 32 | asis = 0.97 # filter coefficient. 
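# (annotation) The pre-emphasis below applies the high-pass filter
# y[t] = x[t] - 0.97 * x[t-1] before the mel transform, boosting high
# frequencies in the same way the dataset loaders do via hp.preemphasis.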
33 | wavform = F.preemphasis(wavform,float(asis)) 34 | specgram = torchaudio.transforms.MelSpectrogram(sample_rate=16000, 35 | n_fft=850, 36 | power=2., 37 | hop_length=200, 38 | win_length=850, 39 | f_max=7600, 40 | f_min=55, 41 | norm='slaney', 42 | normalized=True, 43 | n_mels=80, 44 | pad_mode='reflect', 45 | mel_scale='htk' 46 | ) 47 | orig_mel = specgram(wavform) 48 | orig_mel = F.amplitude_to_DB(orig_mel, multiplier=20., amin=-100, 49 | db_multiplier=-20) 50 | orig_mel = torch.mean(orig_mel, dim=0) 51 | mel = orig_mel.t().numpy().copy() 52 | 53 | #num_frames = (T x hop_size * fps) / sample_rate 54 | #np.clip((2 * 4) * ( 55 | # (mel - (-100)) / (-(-100))) - 4, 56 | # -4, 4) 57 | """start_frame_num = 48 58 | 59 | start_idx = int(80. * (start_frame_num / float(25))) # 80.乘出来刚好是frame的长度 60 | 61 | end_idx = start_idx + 16 62 | mel = mel[start_idx:end_idx,:]""" 63 | 64 | 65 | fig, axs = plt.subplots(3, 1) 66 | self.plot_wavform(wavform,sf,title='Original wavform',ax=axs[0]) 67 | self.plot_spectrogram(spec[0],title="spectrogram",ax=axs[1]) 68 | self.plot_spectrogram(np.transpose(mel, (1, 0)), title="Mel-spectrogram",ax=axs[2]) 69 | fig.tight_layout() 70 | 71 | 72 | plt.show() 73 | 74 | 75 | def plot_wavform(self,wavform,sr,title="wavform",ax=None): 76 | waveform=wavform.numpy() 77 | 78 | num_channels,num_frames = waveform.shape 79 | time_axis = torch.arange(0,num_frames)/sr 80 | 81 | if ax is None: 82 | _,ax = plt.subplots(num_channels,1) 83 | ax.plot(time_axis,waveform[0],linewidth=1) 84 | ax.grid(True) 85 | ax.set_xlim([0,time_axis[-1]]) 86 | ax.set_title(title) 87 | 88 | def plot_spectrogram(self,specgram,title=None,ylabel='freq_bin',ax=None): 89 | if ax is None: 90 | _,ax =plt.subplots(1,1) 91 | if title is not None: 92 | ax.set_title(title) 93 | ax.set_ylabel(ylabel) 94 | ax.imshow(specgram,origin='lower',aspect='auto',interpolation='nearest') 95 | 96 | -------------------------------------------------------------------------------- /testcase/paramsUtilsTest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from process_util.ParamsUtil import ParamsUtil 4 | 5 | 6 | class TestParamsUtils(unittest.TestCase): 7 | 8 | def testParams(self): 9 | paramsU = ParamsUtil() 10 | 11 | self.assertEqual(80,paramsU.num_mels) 12 | self.assertEqual(0.9, paramsU.resacling_max) 13 | self.assertEqual('None', paramsU.frame_shift_ms) 14 | self.assertEqual(-100, paramsU.min_level_db) 15 | self.assertEqual(20, paramsU.ref_level_db) 16 | self.assertEqual(288, paramsU.img_size) 17 | -------------------------------------------------------------------------------- /testcase/processorTest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from collections import Counter 3 | from pathlib import Path 4 | 5 | from process_util.PreProcessor import PreProcessor 6 | 7 | 8 | class TestProcessor(unittest.TestCase): 9 | preProcessor = PreProcessor() 10 | __input_dir__ = '../data/test_data/input' 11 | __output_dir__ = '../data/test_data/output' 12 | __outputT_dir__ = '../data/test_data/outputT' 13 | 14 | def testVideosPreProcessByASR(self): 15 | processor = self.preProcessor 16 | processor.videosPreProcessByASR(input_dir=self.__input_dir__, 17 | output_dir=self.__output_dir__, 18 | ext='mp4') 19 | 20 | def testVideoPreProcessByTime(self): 21 | Path(self.__outputT_dir__).mkdir(exist_ok=True) 22 | processor = self.preProcessor 23 | 24 | files = [] 25 | for file in Path.glob(Path(self.__input_dir__), 
'**/*.mp4'): 26 | if file.is_file(): 27 | files.append(file) 28 | files.sort() 29 | for video in files: 30 | v = processor.videosPreProcessByTime(video, 31 | s_time=5, 32 | input_dir=self.__input_dir__, 33 | output_dir=self.__outputT_dir__ 34 | ) 35 | -------------------------------------------------------------------------------- /testcase/syncnetTrainTest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch import optim 6 | from torch.utils.data import DataLoader 7 | 8 | from models.SyncNetModel import SyncNetModel 9 | from process_util.ParamsUtil import ParamsUtil 10 | from trains import syncnet_train 11 | from wldatasets.SyncNetDataset import SyncNetDataset 12 | 13 | 14 | class TestSyncnetTrain(unittest.TestCase): 15 | 16 | def testGPUTrain(self): 17 | device=torch.device('cuda'if torch.cuda.is_available() else'cpu') 18 | data_root='../data/test_data/pr_data' 19 | train_txt = data_root + '/train.txt' 20 | eval_txt = data_root + '/eval.txt' 21 | Path(train_txt).write_text('') 22 | Path(eval_txt).write_text('') 23 | for line in Path.glob(Path(data_root), '*/*'): 24 | if line.is_dir(): 25 | dirs = line.parts 26 | input_line = str(dirs[-2] + '/' + dirs[-1]) 27 | with open(train_txt, 'a') as f: 28 | f.write(input_line + '\n') 29 | with open(eval_txt, 'a') as f: 30 | f.write(input_line + '\n') 31 | 32 | param = ParamsUtil() 33 | train_dataset = SyncNetDataset(data_root, run_type='train', img_size=288) 34 | val_dataset = SyncNetDataset(data_root, run_type='eval', img_size=288) 35 | train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, 36 | num_workers=8,drop_last=True) 37 | val_dataloader = DataLoader(val_dataset, batch_size=2,num_workers=8,drop_last=True) 38 | 39 | model = SyncNetModel().to(device) 40 | print("SyncNet Model's Total trainable params {}".format( 41 | sum(p.numel() for p in model.parameters() if p.requires_grad))) 42 | optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=float(param.syncnet_learning_rate)) 43 | #optimizer = optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=float(param.syncnet_learning_rate)) 44 | start_step = 0 45 | start_epoch = 0 46 | syncnet_train.train(device, model, train_dataloader, val_dataloader, optimizer, '../data/test_data/checkpoint', start_step, start_epoch) 47 | 48 | -------------------------------------------------------------------------------- /testcase/syncnetmodelTest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | from torch.utils.data import DataLoader 8 | from torchinfo import summary 9 | import matplotlib.pyplot as plt 10 | from models.BaseConv2D import BaseConv2D 11 | 12 | from models.SyncNetModel import SyncNetModel 13 | from wldatasets.SyncNetDataset import SyncNetDataset 14 | 15 | class TestSyncnetModel(unittest.TestCase): 16 | 17 | def testModelSummary(self): 18 | model = SyncNetModel() 19 | faces = torch.randn(1,15,144,288) 20 | audios = torch.randn(1,1,80,16) 21 | summary(model,input_data=(audios,faces),) 22 | 23 | def testFaceShape(self): 24 | faces = np.random.randint(255,size=(288,288,3)) 25 | plt.imshow(faces) 26 | plt.show() -------------------------------------------------------------------------------- /testcase/testFaceDataset.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | from collections import Counter 3 | from pathlib import Path 4 | 5 | import librosa.display 6 | import torch 7 | import torchaudio 8 | from torch.utils.data import DataLoader 9 | from wldatasets.FaceDataset import FaceDataset 10 | import matplotlib.pyplot as plt 11 | from process_util.ParamsUtil import ParamsUtil 12 | 13 | 14 | class TestFaceDataset(unittest.TestCase): 15 | 16 | def test_getItem(self): 17 | sData = FaceDataset('../data/test_data/pr_data', img_size=288) 18 | test_loader = DataLoader(sData) 19 | for x, indiv_mels, mel, y in test_loader: 20 | print("matrix x's size:{}".format(x.size())) 21 | print("matrix y size:{}".format(y.size())) 22 | print("matrix mel1's size:{}".format(mel.size())) 23 | print("matrix invid_mes's size:{}".format(indiv_mels.size())) 24 | 25 | 26 | """wavform, sf = torchaudio.load('../data/test_data/pr_data/000001/000001_00000_00006/audio.wav') 27 | print('wav shape is {}'.format(wavform.size())) 28 | specgram = torchaudio.transforms.MelSpectrogram(sample_rate=hp.sample_rate, 29 | n_fft=hp.n_fft, 30 | hop_length=hp.hop_size, 31 | win_length=hp.win_size, 32 | power=2, 33 | f_min=hp.fmin, 34 | f_max=hp.fmax, 35 | n_mels=hp.num_mels, 36 | normalized=hp.signal_normalization)(wavform) 37 | specgram = specgram.mT 38 | print('specgram shape is {}'.format(specgram.size())) 39 | plt.figure() 40 | p = plt.imsave('test.png',specgram.log2()[0,:,:].detach().numpy(),cmap="gray",bbox_inches=None,pad_inches=0)""" 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /testcase/testSyncDataset.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from torch.utils.data import DataLoader 3 | from process_util.ParamsUtil import ParamsUtil 4 | from wldatasets.SyncNetDataset import SyncNetDataset 5 | 6 | 7 | class TestSyncDataset(unittest.TestCase): 8 | 9 | def test_getItem(self): 10 | sData = SyncNetDataset('../data/test_data/pr_data', img_size=288) 11 | 12 | test_loader = DataLoader(sData) 13 | for x,mel,y in test_loader: 14 | print("matrix x's size:{}".format(x.size())) 15 | print("matrix y size:{}".format(y.size())) 16 | print("matrix mel1's size:{}".format(mel.size())) 17 | -------------------------------------------------------------------------------- /testcase/videoExtractorTest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from process_util.VideoExtractor import VideoExtractor 3 | 4 | class MyTestCase(unittest.TestCase): 5 | def test_something(self): 6 | video_e = VideoExtractor('../data/temp',25) 7 | video_e.pipline_video('cctvm0000003/0003.mp4',data_root='../data/original_data') 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /testcase/w2ltrainTest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from pathlib import Path 3 | import torch 4 | from torch import optim 5 | from torch.utils.data import DataLoader 6 | 7 | from models.FaceCreator import FaceCreator 8 | from models.Discriminator import Discriminator 9 | from models.SyncNetModel import SyncNetModel 10 | from process_util.ParamsUtil import ParamsUtil 11 | from trains import wl_train 12 | from wldatasets.SyncNetDataset import SyncNetDataset 13 | from wldatasets.FaceDataset import FaceDataset 14 | 15 | syncnet = SyncNetModel() 16 | for p 
in syncnet.parameters(): 17 | p.requires_grad = False 18 | class W2LtrainTest(unittest.TestCase): 19 | def test_train(self): 20 | param = ParamsUtil() 21 | data_root='../data/test_data/pr_data' 22 | checkpoint_dir = '../data/test_data/checkpoint' 23 | Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) 24 | disc_checkpoint_path = None 25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | syncnet_checkpoint_path='../data/syncnet_checkpoint/sync_checkpoint_step000340000.pth' 27 | 28 | train_dataset = FaceDataset(data_root, run_type='train', img_size=param.img_size) 29 | test_dataset = FaceDataset(data_root, run_type='eval', img_size=param.img_size) 30 | 31 | train_data_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, 32 | num_workers=8) 33 | 34 | test_data_loader = DataLoader(test_dataset, batch_size=4, 35 | num_workers=8) 36 | 37 | 38 | model=FaceCreator() 39 | disc = Discriminator() 40 | 41 | optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], 42 | lr=float(param.init_learning_rate), betas=(0.5, 0.999)) 43 | 44 | disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad], 45 | lr=float(param.disc_initial_learning_rate), betas=(0.5, 0.999)) 46 | 47 | 48 | 49 | start_step = 0 50 | start_epoch = 0 51 | 52 | # 装在sync_net 53 | wl_train.load_checkpoint(syncnet_checkpoint_path, syncnet, None, reset_optimizer=True) 54 | 55 | wl_train.train(model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer, 56 | checkpoint_dir, 0, 0) -------------------------------------------------------------------------------- /trains/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rogerle/wav2lipup/7149a5fb30d52f1af7773dccc991c42acf59ac77/trains/README.md -------------------------------------------------------------------------------- /trains/syncnet_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from torch.optim.lr_scheduler import MultiStepLR 5 | 6 | from models.SyncNetModel import SyncNetModel 7 | from tqdm import tqdm 8 | from wldatasets.SyncNetDataset import SyncNetDataset 9 | 10 | import torch 11 | from torch import nn 12 | from torch import optim 13 | from torch.utils.data import DataLoader 14 | from torch.nn import functional as F 15 | from process_util.ParamsUtil import ParamsUtil 16 | import torch.multiprocessing 17 | import argparse 18 | from visualdl import LogWriter 19 | 20 | param = ParamsUtil() 21 | logloss = nn.BCELoss() 22 | #logloss = nn.BCEWithLogitsLoss() 23 | 24 | class MyDataParallel(nn.DataParallel): 25 | def __getattr__(self, name): 26 | try: 27 | return super().__getattr__(name) 28 | except AttributeError: 29 | return getattr(self.module, name) 30 | # 判断是否使用gpu 31 | def load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False): 32 | print("load checkpoint from: {}".format(checkpoint_path)) 33 | if torch.cuda.is_available(): 34 | checkpoint = torch.load(checkpoint_path) 35 | else: 36 | checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) 37 | s = checkpoint["state_dict"] 38 | model.load_state_dict(s) 39 | if not reset_optimizer: 40 | optimizer_state = checkpoint["optimizer"] 41 | if optimizer_state is not None: 42 | print("Load optimizer state from {}".format(checkpoint_path)) 43 | optimizer.load_state_dict(checkpoint["optimizer"]) 44 | step = checkpoint["global_step"] 45 | 
epoch = checkpoint["global_epoch"] 46 | 47 | return model, step, epoch,optimizer 48 | 49 | 50 | def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch): 51 | path = Path(checkpoint_dir + '/sync_checkpoint_step{:09d}.pth'.format(step)) 52 | optimizer_state = optimizer.state_dict() if param.save_optimizer_state else None 53 | torch.save({ 54 | "state_dict": model.state_dict(), 55 | "optimizer": optimizer_state, 56 | "global_step": step, 57 | "global_epoch": epoch, 58 | }, path) 59 | print("save the checkpoint step {}".format(path)) 60 | 61 | 62 | def eval_model(val_dataloader, global_step, device, model): 63 | eval_steps = 1400 64 | losses = [] 65 | print('Evaluating for {} steps'.format(eval_steps)) 66 | while 1: 67 | for vstep, (x, mel, y) in enumerate(val_dataloader): 68 | model.eval() 69 | 70 | x = x.to(device) 71 | mel = mel.to(device) 72 | y = y.to(device) 73 | 74 | a, v = model(mel, x) 75 | 76 | d = F.cosine_similarity(a, v) 77 | loss = logloss(d.unsqueeze(1), y) 78 | 79 | losses.append(loss.item()) 80 | 81 | if vstep > eval_steps: break 82 | averaged_loss = sum(losses) / len(losses) 83 | print('The evaluating loss:{}'.format(averaged_loss)) 84 | return averaged_loss 85 | 86 | 87 | def train(device, model, train_dataloader, val_dataloader, optimizer, checkpoint_dir, start_step, start_epoch): 88 | global_step = start_step 89 | epoch = start_epoch 90 | numepochs = param.epochs 91 | checkpoint_interval = param.syncnet_checkpoint_interval 92 | eval_interval = param.syncnet_eval_interval 93 | #scheduler = MultiStepLR(optimizer,milestones=[int(param.syncnet_min),int(param.syncnet_med),int(param.syncnet_max)],gamma=0.1) 94 | 95 | with LogWriter(logdir="../logs/syncnet_train/train") as writer: 96 | while epoch < numepochs: 97 | running_loss = 0 98 | prog_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False) 99 | lr = optimizer.state_dict()['param_groups'][0]['lr'] 100 | for step, (x, mel, y) in prog_bar: 101 | model.train() 102 | optimizer.zero_grad() 103 | 104 | # transform data to cuda 105 | x = x.to(device) 106 | mel = mel.to(device) 107 | a, v = model(mel, x) 108 | y = y.to(device) 109 | 110 | # 计算loss 111 | d = F.cosine_similarity(a, v) 112 | loss = logloss(d.unsqueeze(1), y) 113 | loss.backward() 114 | 115 | optimizer.step() 116 | 117 | global_step = global_step + 1 118 | running_loss += loss.item() 119 | if global_step % checkpoint_interval == 0: 120 | save_checkpoint(model, optimizer, global_step, checkpoint_dir, epoch) 121 | 122 | if global_step % eval_interval == 0: 123 | with torch.no_grad(): 124 | eval_loss=eval_model(val_dataloader, global_step, device, model) 125 | writer.add_scalar(tag='sync_train/eval_loss', step=global_step, value=eval_loss) 126 | 127 | prog_bar.set_description('Syncnet Train Epoch [{0}/{1}]'.format(epoch, numepochs)) 128 | prog_bar.set_postfix(train_loss=running_loss / (step + 1), step=step + 1, gloab_step=global_step,lr=lr) 129 | writer.add_scalar(tag='sync_train/step_loss', step=global_step, value=running_loss / (step + 1)) 130 | #自动调整lr 131 | # scheduler.step() 132 | epoch += 1 133 | 134 | 135 | 136 | def main(): 137 | args = parse_args() 138 | 139 | checkpoint_dir = args.checkpoint_dir 140 | checkpoint_path = args.checkpoint_path 141 | train_type = args.train_type 142 | 143 | Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) 144 | 145 | train_dataset = SyncNetDataset(args.data_root, run_type=train_type) 146 | val_dataset = SyncNetDataset(args.data_root, run_type='eval') 147 | 148 | train_dataloader = 
DataLoader(train_dataset, batch_size=param.syncnet_batch_size, shuffle=True, 149 | num_workers=param.num_works) 150 | val_dataloader = DataLoader(val_dataset, batch_size=param.syncnet_batch_size, 151 | num_workers=param.num_works) 152 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 153 | model = SyncNetModel() 154 | cuda_ids = [int(d_id) for d_id in os.environ.get('CUDA_VISIBLE_DEVICES').split(',')] 155 | print('cuda ids:{}'.format(cuda_ids)) 156 | model = MyDataParallel(model, device_ids=cuda_ids) 157 | model.to(device) 158 | 159 | 160 | print("SyncNet Model's Total trainable params {}".format( 161 | sum(p.numel() for p in model.parameters() if p.requires_grad))) 162 | 163 | optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=float(param.syncnet_learning_rate)) 164 | # optimizer = optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=param.syncnet_learning_rate) 165 | 166 | start_step = 0 167 | start_epoch = 0 168 | 169 | if checkpoint_path is not None: 170 | model, start_step, start_epoch,optimizer = load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False) 171 | 172 | 173 | train(device, model, train_dataloader, val_dataloader, optimizer, checkpoint_dir, start_step, start_epoch) 174 | 175 | 176 | def parse_args(): 177 | # parse args and config 178 | parser = argparse.ArgumentParser( 179 | description="train the sync_net model for wav2lip") 180 | parser.add_argument("--data_root", help='Root folder of the preprocessed dataset', required=True) 181 | parser.add_argument("--checkpoint_dir", help='Save checkpoints to this directory', required=True, type=str) 182 | parser.add_argument("--checkpoint_path", help='Resume from this checkpoint', default=None, type=str) 183 | parser.add_argument('--config_file', help='The train config file', default='../configs/train_config.yaml', 184 | required=True, type=str) 185 | parser.add_argument('--train_type', help='Resume qulity disc from this checkpoint', default='train', type=str) 186 | 187 | args = parser.parse_args() 188 | 189 | return args 190 | 191 | 192 | if __name__ == '__main__': 193 | main() 194 | -------------------------------------------------------------------------------- /trains/wl_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from torch import optim, nn 8 | from torch.nn import functional as F 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | from visualdl import LogWriter 12 | 13 | from models.Discriminator import Discriminator 14 | from models.FaceCreator import FaceCreator 15 | from models.SyncNetModel import SyncNetModel 16 | from process_util.ParamsUtil import ParamsUtil 17 | from wldatasets.FaceDataset import FaceDataset 18 | 19 | 20 | class MyDataParallel(nn.DataParallel): 21 | def __getattr__(self, name): 22 | try: 23 | return super().__getattr__(name) 24 | except AttributeError: 25 | return getattr(self.module, name) 26 | # 判断是否使用gpu 27 | import os 28 | 29 | param = ParamsUtil() 30 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 31 | # device = torch.device("cpu") 32 | 33 | syncnet = SyncNetModel().to(device) 34 | for p in syncnet.parameters(): 35 | p.requires_grad = False 36 | 37 | logloss = nn.BCEWithLogitsLoss() 38 | # logloss = nn.BCELoss() 39 | recon_loss = nn.L1Loss() 40 | 41 | 42 | def load_checkpoint(checkpoint_path, model, optimizer, 
reset_optimizer=False): 43 | print("Load checkpoint from: {}".format(checkpoint_path)) 44 | 45 | if torch.cuda.is_available(): 46 | checkpoint = torch.load(checkpoint_path) 47 | else: 48 | checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) 49 | 50 | s = checkpoint["state_dict"] 51 | new_s = {} 52 | for k, v in s.items(): 53 | new_s[k.replace('module.', '')] = v 54 | model.load_state_dict(new_s) 55 | if not reset_optimizer: 56 | optimizer_state = checkpoint["optimizer"] 57 | if optimizer_state is not None: 58 | print("Load optimizer state from {}".format(checkpoint_path)) 59 | optimizer.load_state_dict(checkpoint["optimizer"]) 60 | step = checkpoint["global_step"] 61 | epoch = checkpoint["global_epoch"] 62 | 63 | return model, step, epoch 64 | 65 | 66 | def cosine_loss(a, v, y): 67 | d = nn.functional.cosine_similarity(a, v) 68 | loss = logloss(d.unsqueeze(1), y) 69 | return loss 70 | 71 | 72 | def get_sync_loss(mel, g): 73 | g = g[:, :, :, g.size(3) // 2:] 74 | g = torch.cat([g[:, :, i] for i in range(param.syncnet_T)], dim=1) 75 | # B, 3 * T, H//2, W 76 | a, v = syncnet(mel, g) 77 | y = torch.ones(g.size(0), 1, dtype=torch.float).to(device) 78 | return cosine_loss(a, v, y) 79 | 80 | 81 | def save_sample_images(x, g, gt, global_step, checkpoint_dir): 82 | x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8) 83 | g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8) 84 | gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8) 85 | 86 | refs, inps = x[..., 3:], x[..., :3] 87 | folder = checkpoint_dir + "/samples_step_{:09d}".format(global_step) 88 | Path(folder).mkdir(parents=True, exist_ok=True) 89 | collage = np.concatenate((refs, inps, g, gt), axis=-2) 90 | for batch_idx, c in enumerate(collage): 91 | for t in range(len(c)): 92 | cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t]) 93 | return collage 94 | 95 | 96 | def save_checkpoint(model, optimizer, global_step, checkpoint_dir, epoch, prefix=''): 97 | checkpoint_path = checkpoint_dir + "/{}checkpoint_step{:09d}.pth".format(prefix, global_step) 98 | optimizer_state = optimizer.state_dict() if param.save_optimizer_state else None 99 | torch.save({ 100 | "state_dict": model.state_dict(), 101 | "optimizer": optimizer_state, 102 | "global_step": global_step, 103 | "global_epoch": epoch, 104 | }, checkpoint_path) 105 | print("Saved checkpoint:", checkpoint_path) 106 | 107 | 108 | def eval_model(test_data_loader, model, disc): 109 | eval_steps = 300 110 | print('Evaluating for {} steps:'.format(eval_steps)) 111 | running_sync_loss, running_l1_loss, running_disc_real_loss, running_disc_fake_loss, running_perceptual_loss = [], [], [], [], [] 112 | while 1: 113 | for step, (x, indiv_mels, mel, gt) in enumerate(test_data_loader): 114 | model.eval() 115 | disc.eval() 116 | 117 | x = x.to(device) 118 | mel = mel.to(device) 119 | indiv_mels = indiv_mels.to(device) 120 | gt = gt.to(device) 121 | 122 | pred = disc(gt) 123 | disc_real_loss = F.binary_cross_entropy(pred, torch.ones((pred.size(0), 1)).to(device)) 124 | 125 | g = model(indiv_mels, x) 126 | pred = disc(g) 127 | disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((pred.size(0), 1)).to(device)) 128 | 129 | running_disc_real_loss.append(disc_real_loss.item()) 130 | running_disc_fake_loss.append(disc_fake_loss.item()) 131 | 132 | sync_loss = get_sync_loss(mel=mel, g=g) 133 | 134 | if param.disc_wt > 0.: 135 | perceptual_loss = disc.perceptual_forward(g) 136 | else: 
137 | perceptual_loss = 0. 138 | 139 | l1loss = recon_loss(g, gt) 140 | loss = float(param.syncnet_wt) * sync_loss + float(param.disc_wt) * perceptual_loss + \ 141 | (1. - float(param.syncnet_wt) - float(param.disc_wt)) * l1loss 142 | running_l1_loss.append(loss.item()) 143 | running_sync_loss.append(sync_loss.item()) 144 | 145 | if param.disc_wt > 0.: 146 | running_perceptual_loss.append(perceptual_loss.item()) 147 | else: 148 | running_perceptual_loss.append(0.) 149 | 150 | if step > eval_steps: break 151 | 152 | print('L1: {}, \n Sync: {}, \n Percep: {} | Fake: {}, Real: {}'.format( 153 | sum(running_l1_loss) / len(running_l1_loss), 154 | sum(running_sync_loss) / len( 155 | running_sync_loss), 156 | sum(running_perceptual_loss) / len( 157 | running_perceptual_loss), 158 | sum(running_disc_fake_loss) / len( 159 | running_disc_fake_loss), 160 | sum(running_disc_real_loss) / len( 161 | running_disc_real_loss))) 162 | eval_loss = sum(running_sync_loss) / len(running_sync_loss) 163 | 164 | return eval_loss 165 | 166 | 167 | def train(model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,checkpoint_dir, 168 | start_step, start_epoch): 169 | global_step = start_step 170 | epoch = start_epoch 171 | num_epochs = param.epochs 172 | checkpoint_interval = param.checkpoint_interval 173 | eval_interval = param.eval_interval 174 | #scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[param.m_min,param.m_med,param.m_max],gamma=0.1) 175 | lr = optimizer.state_dict()['param_groups'][0]['lr'] 176 | with LogWriter(logdir="../logs/wav2lip/train") as writer: 177 | while epoch < num_epochs: 178 | running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = 0., 0., 0., 0. 179 | running_disc_real_loss, running_disc_fake_loss = 0., 0. 180 | prog_bar = tqdm(enumerate(train_data_loader), total=len(train_data_loader), leave=False) 181 | 182 | for step, (x, indiv_mels, mel, gt) in prog_bar: 183 | disc.train() 184 | model.train() 185 | 186 | x = x.to(device) 187 | mel = mel.to(device) 188 | indiv_mels = indiv_mels.to(device) 189 | gt = gt.to(device) 190 | 191 | ### Train generator now. Remove ALL grads. 192 | optimizer.zero_grad() 193 | disc_optimizer.zero_grad() 194 | 195 | g = model(indiv_mels, x) 196 | 197 | if float(param.syncnet_wt) > 0.: 198 | sync_loss = get_sync_loss(mel, g) 199 | else: 200 | sync_loss = torch.tensor(0, dtype=torch.float) 201 | 202 | if float(param.disc_wt) > 0.: 203 | perceptual_loss = disc.perceptual_forward(g) 204 | else: 205 | perceptual_loss = torch.tensor(0, dtype=torch.float) 206 | # print ("g:{}|gt:{}".format(g.shape,gt.shape)) 207 | l1loss = recon_loss(g, gt) 208 | 209 | loss = float(param.syncnet_wt) * sync_loss + float(param.disc_wt) * perceptual_loss + \ 210 | (1. 
- float(param.syncnet_wt) - float(param.disc_wt)) * l1loss 211 | 212 | loss.backward() 213 | optimizer.step() 214 | 215 | 216 | # Remove all gradients before Training disc 217 | disc_optimizer.zero_grad() 218 | 219 | pred = disc(gt) 220 | disc_real_loss = F.binary_cross_entropy(pred, torch.ones(pred.size(0), 1).to(device)) 221 | disc_real_loss.backward() 222 | 223 | pred = disc(g.detach()) 224 | disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros(pred.size(0), 1).to(device)) 225 | disc_fake_loss.backward() 226 | 227 | disc_optimizer.step() 228 | 229 | running_disc_real_loss += disc_real_loss.item() 230 | running_disc_fake_loss += disc_fake_loss.item() 231 | 232 | if global_step % checkpoint_interval == 0: 233 | collage = save_sample_images(x, g, gt, global_step, checkpoint_dir) 234 | for batch_idx, c in enumerate(collage): 235 | for t in range(0,c.shape[0]-1): 236 | x = cv2.cvtColor(c[t], cv2.COLOR_RGB2BGR) 237 | writer.add_image(tag='train/sample', img=x/ 255., step=global_step) 238 | 239 | global_step += 1 240 | 241 | running_l1_loss += l1loss.item() 242 | if float(param.syncnet_wt) > 0.: 243 | running_sync_loss += sync_loss.item() 244 | else: 245 | running_sync_loss += 0. 246 | 247 | if param.disc_wt > 0.: 248 | running_perceptual_loss += perceptual_loss.item() 249 | else: 250 | running_perceptual_loss += 0. 251 | 252 | if global_step == 1 or global_step % checkpoint_interval == 0: 253 | save_checkpoint(model, optimizer, global_step, checkpoint_dir, epoch) 254 | save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir, epoch, prefix='disc_') 255 | 256 | if global_step % eval_interval == 0: 257 | with torch.no_grad(): 258 | average_sync_loss = eval_model(test_data_loader, model, disc) 259 | writer.add_scalar(tag='train/eval_loss', step=global_step, value=average_sync_loss) 260 | if average_sync_loss < .75: 261 | param.set_param('syncnet_wt', 0.01) 262 | prog_bar.set_description('Syncnet Train Epoch [{0}/{1}]'.format(epoch, num_epochs)) 263 | prog_bar.set_postfix(Step=global_step, 264 | L1=running_l1_loss / (step + 1), 265 | lr=lr, 266 | Sync=running_sync_loss / (step + 1), 267 | Percep=running_perceptual_loss / (step + 1), 268 | Fake=running_disc_fake_loss / (step + 1), 269 | Real=running_disc_real_loss / (step + 1)) 270 | writer.add_scalar(tag='train/L1_loss', step=global_step, value=running_l1_loss / (step + 1)) 271 | writer.add_scalar(tag='train/Sync_loss', step=global_step, value=running_sync_loss / (step + 1)) 272 | writer.add_scalar(tag='train/Percep_loss', step=global_step, value=running_perceptual_loss / (step + 1)) 273 | writer.add_scalar(tag='train/Real_loss', step=global_step, value=running_disc_real_loss / (step + 1)) 274 | writer.add_scalar(tag='train/Fake_loss', step=global_step, value=running_disc_fake_loss / (step + 1)) 275 | #scheduler.step() 276 | #lr = optimizer.state_dict()['param_groups'][0]['lr'] 277 | epoch += 1 278 | 279 | 280 | def main(): 281 | args = parse_args() 282 | 283 | # 创建checkpoint目录 284 | checkpoint_dir = args.checkpoint_dir 285 | Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) 286 | 287 | checkpoint_path = args.checkpoint_path 288 | disc_checkpoint_path = args.disc_checkpoint_path 289 | syncnet_checkpoint_path = args.syncnet_checkpoint_path 290 | train_type = args.train_type 291 | 292 | train_dataset = FaceDataset(args.data_root, run_type=train_type, img_size=param.img_size) 293 | test_dataset = FaceDataset(args.data_root, run_type='eval', img_size=param.img_size) 294 | 295 | train_data_loader = DataLoader(train_dataset, 
batch_size=param.batch_size, shuffle=True,
296 | num_workers=param.num_works, drop_last=True)
297 |
298 | test_data_loader = DataLoader(test_dataset, batch_size=param.batch_size,
299 | num_workers=param.num_works, drop_last=True)
300 |
301 | model = FaceCreator()
302 | disc = Discriminator()
303 |
304 | print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
305 | print('total DISC trainable params {}'.format(sum(p.numel() for p in disc.parameters() if p.requires_grad)))
306 |
307 | optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
308 | lr=float(param.init_learning_rate), betas=(0.5, 0.999))
309 |
310 | disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad],
311 | lr=float(param.disc_initial_learning_rate), betas=(0.5, 0.999))
312 | start_step = 0
313 | start_epoch = 0
314 |
315 | if checkpoint_path is not None:
316 | model, start_step, start_epoch = load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False)
317 | if disc_checkpoint_path is not None:
318 | disc, start_step, start_epoch = load_checkpoint(disc_checkpoint_path, disc, disc_optimizer,
319 | reset_optimizer=False)
320 | cuda_ids = [int(d_id) for d_id in os.environ.get('CUDA_VISIBLE_DEVICES').split(',')]
321 | model = MyDataParallel(model,device_ids=cuda_ids)
322 | model.to(device)
323 |
324 | disc = MyDataParallel(disc,device_ids=cuda_ids)
325 | disc.to(device)
326 |
327 |
328 | # load the pre-trained expert sync_net
329 | load_checkpoint(syncnet_checkpoint_path, syncnet, None, reset_optimizer=True)
330 |
331 | train(model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
332 | checkpoint_dir=checkpoint_dir, start_step=start_step, start_epoch=start_epoch)
333 |
334 |
335 | def parse_args():
336 | # parse args and config
337 |
338 | parser = argparse.ArgumentParser(description='code to train the wav2lip model with a visual quality discriminator')
339 | parser.add_argument('--data_root', help='Root folder of the preprocessed datasets', required=True, type=str)
340 | parser.add_argument('--checkpoint_dir', help='checkpoint files will be saved to this directory', required=True,
341 | type=str)
342 | parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained expert discriminator (SyncNet)', required=True,
343 | type=str)
344 | parser.add_argument('--checkpoint_path', help='Load the pre-trained generator checkpoint', required=False,
345 | type=str)
346 | parser.add_argument('--checkpoint', help='Resume generator from this checkpoint', default=None, type=str)
347 | parser.add_argument('--disc_checkpoint_path', help='Resume quality disc from this checkpoint', default=None,
348 | type=str)
349 | parser.add_argument('--train_type', help='the train type: train or eval', default='train',
350 | type=str)
351 | args = parser.parse_args()
352 | return args
353 |
354 |
355 | if __name__ == "__main__":
356 | main()
357 |
-------------------------------------------------------------------------------- /wldatasets/FaceDataset.py: --------------------------------------------------------------------------------
1 | import random
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from pathlib import Path
7 |
8 | import torchaudio
9 | import torchaudio.functional as F
10 | from torch.utils.data import Dataset
11 |
12 | from process_util.ParamsUtil import ParamsUtil
13 |
14 | hp = ParamsUtil()
15 |
16 |
17 | class FaceDataset(Dataset):
18 |
19 | def __init__(self, data_dir,
20 | run_type: str = 'train',
21 | **kwargs):
22 | """
23 | :param data_dir: 数据文件的根目录
24 |
:param video_info: 所有视频文件的目录信息,一般放在train.txt文件中。 25 | """ 26 | self.data_dir = data_dir 27 | self.type = run_type 28 | self.img_size = kwargs['img_size'] 29 | self.dirlist = self.__get_split_video_list() 30 | 31 | def __getitem__(self, idx): 32 | """ 33 | 循环去一段视频和一段错误视频进行网络对抗,这里就是取两个不同视频的方法 34 | :param idx: the index of item 35 | :return: image 36 | """ 37 | img_dir = self.dirlist[idx] 38 | # print('img dir:{}'.format(img_dir)) 39 | while 1: 40 | # 随机抽取一个帧作为起始帧进行处理 41 | image_names = self.__get_imgs(img_dir) 42 | if image_names is None or len(image_names) <= 3 * hp.syncnet_T: 43 | continue 44 | 45 | # 获取连续5张脸,正确和错误的 46 | img_name, wrong_img_name = self.__get_choosen(image_names) 47 | window_fnames = self.__get_window(img_name, img_dir) 48 | wrong_window_fnames = self.__get_window(wrong_img_name, img_dir) 49 | if window_fnames is None or wrong_window_fnames is None: 50 | continue 51 | if len(window_fnames) < hp.syncnet_T or len(wrong_window_fnames) < hp.syncnet_T: 52 | continue 53 | 54 | window = self.__read_window(window_fnames) 55 | if window is None: 56 | continue 57 | wrong_window = self.__read_window(wrong_window_fnames) 58 | if wrong_window is None: 59 | continue 60 | 61 | # 对音频进行mel图谱化,并进行对应。 62 | orginal_mel = self.__get_orginal_mel(img_dir) 63 | if orginal_mel is None: 64 | continue 65 | 66 | mel = self.__crop_audio_window(orginal_mel.copy(), int(img_name)) 67 | if mel is None or mel.shape[0] != hp.syncnet_mel_step_size: 68 | continue 69 | 70 | indiv_mels = self.__get_segmented_mels(orginal_mel.copy(), img_name) 71 | 72 | if indiv_mels is None: 73 | continue 74 | # 对window进行范围缩小到0-1之间的array的处理 75 | window = self.__prepare_window(window) 76 | y = window.copy() 77 | # 把图片的上半部分抹去 78 | window[:, :, window.shape[2] // 2:] = 0. 79 | 80 | wrong_window = self.__prepare_window(wrong_window) 81 | x = np.concatenate([window, wrong_window], axis=0) 82 | 83 | x = torch.tensor(x, dtype=torch.float) 84 | y = torch.tensor(y, dtype=torch.float) 85 | mel = torch.tensor(np.transpose(mel, (1, 0)), dtype=torch.float).unsqueeze(0) 86 | indiv_mels = torch.tensor(indiv_mels, dtype=torch.float).unsqueeze(1) 87 | # print('img_dir: {}|window start: {}|wrong window:{}|indiv_mels size: {}|mel size:{}'.format(img_dir,window_fnames[0],wrong_window_fnames[0],len(indiv_mels),mel.size())) 88 | return x, indiv_mels, mel, y 89 | 90 | def __len__(self): 91 | return len(self.dirlist) 92 | 93 | """ 94 | 下面的方法都是内部的方法,用于数据装载时,对数据的处理 95 | """ 96 | 97 | def __get_choosen(self, image_names): 98 | img_name = random.choice(image_names) 99 | wrong_img_name = random.choice(image_names) 100 | while wrong_img_name == img_name: 101 | wrong_img_name = random.choice(image_names) 102 | 103 | return img_name, wrong_img_name 104 | 105 | def __get_split_video_list(self): 106 | load_file = self.data_dir + '/{}.txt'.format(self.type) 107 | dirlist = [] 108 | with open(load_file, 'r') as f: 109 | for line in f: 110 | line = line.strip() 111 | dirlist.append(line) 112 | 113 | return dirlist 114 | 115 | def __get_imgs(self, img_dir): 116 | img_names = [] 117 | for img in Path(self.data_dir + '/' + img_dir).glob('**/*.jpg'): 118 | img = img.stem 119 | img_names.append(img) 120 | img_names.sort(key=int) 121 | return img_names 122 | 123 | def __get_window(self, img_name, img_dir): 124 | start_id = int(img_name) 125 | seek_id = start_id + int(hp.syncnet_T) 126 | vidPath = self.data_dir + '/' + img_dir 127 | window_frames = [] 128 | for frame_id in range(start_id, seek_id): 129 | frame = vidPath + '/{}.jpg'.format(frame_id) 130 | if not 
Path(frame).exists(): 131 | return None 132 | window_frames.append(frame) 133 | return window_frames 134 | 135 | def __read_window(self, window_fnames): 136 | window = [] 137 | for f_name in window_fnames: 138 | try: 139 | img = cv2.imread(f_name) 140 | img = cv2.resize(img, (self.img_size, self.img_size)) 141 | except Exception as e: 142 | print('Resize the face image error: {}'.format(e)) 143 | return None 144 | window.append(img) 145 | return window 146 | 147 | def __prepare_window(self, window): 148 | # 数组转换3xTxHxW 149 | wa = np.asarray(window) / 255. 150 | wa = np.transpose(wa, (3, 0, 1, 2)) 151 | 152 | return wa 153 | 154 | def __crop_audio_window(self, spec, start_frame): 155 | mel_step_size = hp.syncnet_mel_step_size 156 | fps = hp.fps 157 | start_frame_num = start_frame 158 | start_idx = int(80. * (start_frame_num / float(fps))) 159 | end_idx = start_idx + mel_step_size 160 | 161 | spec = spec[start_idx:end_idx, :] 162 | 163 | return spec 164 | 165 | def __get_segmented_mels(self, spec, image_name): 166 | mels = [] 167 | syncnet_T = 5 168 | mel_step_size = hp.syncnet_mel_step_size 169 | start_frame_num = int(image_name) + 1 170 | if start_frame_num - 2 < 0: 171 | return None 172 | for i in range(start_frame_num, start_frame_num + syncnet_T): 173 | m = self.__crop_audio_window(spec, i - 2) 174 | if m.shape[0] != mel_step_size: 175 | return None 176 | mels.append(m.T) 177 | mels = np.asarray(mels) 178 | 179 | return mels 180 | 181 | def __get_orginal_mel(self, img_dir): 182 | wavfile = self.data_dir + '/' + img_dir + '/audio.wav' 183 | try: 184 | wavform, sf = torchaudio.load(wavfile) 185 | resample = torchaudio.transforms.Resample(sf, 16000) 186 | wavform = resample(wavform) 187 | wavform = F.preemphasis(wavform, hp.preemphasis) 188 | specgram = torchaudio.transforms.MelSpectrogram(sample_rate=hp.sample_rate, 189 | n_fft=hp.n_fft, 190 | power=1., 191 | hop_length=hp.hop_size, 192 | win_length=hp.win_size, 193 | f_min=hp.fmin, 194 | f_max=hp.fmax, 195 | n_mels=hp.num_mels, 196 | normalized=hp.signal_normalization) 197 | orig_mel = specgram(wavform) 198 | orig_mel = F.amplitude_to_DB(orig_mel, multiplier=10., amin=hp.min_level_db, 199 | db_multiplier=hp.ref_level_db,top_db=100) 200 | orig_mel = torch.mean(orig_mel, dim=0) 201 | orig_mel = orig_mel.t().numpy() 202 | except Exception as e: 203 | orig_mel = None 204 | 205 | return orig_mel 206 | -------------------------------------------------------------------------------- /wldatasets/SyncNetDataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from pathlib import Path 7 | 8 | import torchaudio 9 | import torchaudio.functional as F 10 | import torchaudio.transforms as T 11 | from torch.utils.data import Dataset 12 | 13 | from process_util.ParamsUtil import ParamsUtil 14 | 15 | 16 | class SyncNetDataset(Dataset): 17 | hp = ParamsUtil() 18 | def __init__(self, data_dir, 19 | run_type: str = 'train', 20 | **kwargs): 21 | self.data_dir = data_dir 22 | self.run_type = run_type 23 | self.dirlist = self.__get_split_video_list() 24 | 25 | def __getitem__(self, idx): 26 | img_dir = self.dirlist[idx] 27 | #print('process imgs v_dir is {}'.format(img_dir)) 28 | while 1: 29 | image_names = self.__get_imgs(img_dir) 30 | if image_names is None or len(image_names)==0: 31 | print('v_dir is {} {} is empty'.format(idx,img_dir)) 32 | continue 33 | #取图片进行训练 34 | 35 | img_name,choosen,y = self.__get_choosen(image_names) 36 | window = 
-------------------------------------------------------------------------------- /wldatasets/SyncNetDataset.py: --------------------------------------------------------------------------------
import random

import cv2
import numpy as np
import torch
from pathlib import Path

import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
from torch.utils.data import Dataset

from process_util.ParamsUtil import ParamsUtil


class SyncNetDataset(Dataset):
    hp = ParamsUtil()

    def __init__(self, data_dir,
                 run_type: str = 'train',
                 **kwargs):
        self.data_dir = data_dir
        self.run_type = run_type
        self.dirlist = self.__get_split_video_list()

    def __getitem__(self, idx):
        img_dir = self.dirlist[idx]
        # print('process imgs v_dir is {}'.format(img_dir))
        while True:
            image_names = self.__get_imgs(img_dir)
            if image_names is None or len(image_names) == 0:
                print('video dir {} ({}) is empty'.format(img_dir, idx))
                continue

            # Pick the frames for this sample: either the in-sync window or a mismatched one.
            img_name, choosen, y = self.__get_choosen(image_names)
            window = self.__get_window(choosen, img_dir)
            if window is None or len(window) < 5:
                continue

            # Stack the T frames along the channel axis and keep only the lower half of the face.
            x = np.concatenate(window, axis=2) / 255.
            x = x.transpose(2, 0, 1)
            x = x[:, x.shape[1] // 2:]

            # The mel window is always cropped at img_name, so y tells whether the
            # chosen frames actually match the audio.
            mel = self.__get_segment_mel(img_dir, img_name)
            if mel is None or mel.shape[0] != int(self.hp.syncnet_mel_step_size):
                # print("mel's shape is {}, v_dir is {} {}, rechoose!!!".format(mel.shape[0], img_dir, img_name))
                continue

            x = torch.tensor(x, dtype=torch.float)
            mel = torch.tensor(np.transpose(mel, (1, 0)), dtype=torch.float).unsqueeze(0)
            return x, mel, y

    def __len__(self):
        return len(self.dirlist)

    def __get_split_video_list(self):
        load_file = self.data_dir + '/{}.txt'.format(self.run_type)
        dirlist = []
        with open(load_file, 'r') as f:
            for line in f:
                line = line.strip()
                dirlist.append(line)
        return dirlist

    def __get_imgs(self, img_dir):
        img_names = []
        for img in Path(self.data_dir + '/' + img_dir).glob('**/*.jpg'):
            img = img.stem
            img_names.append(img)
        img_names.sort(key=int)
        return img_names

    def __get_window(self, img_name, img_dir):
        start_id = int(img_name)
        seek_id = start_id + int(self.hp.syncnet_T)
        vidPath = self.data_dir + '/' + img_dir
        window_frames = []
        for frame_id in range(start_id, seek_id):
            frame = vidPath + '/{}.jpg'.format(frame_id)
            if not Path(frame).exists():
                return None
            try:
                img = cv2.imread(frame)
                img = cv2.resize(img, (self.hp.img_size, self.hp.img_size))
            except Exception as e:
                return None
            window_frames.append(img)
        return window_frames

    def __crop_audio_window(self, spec, start_frame):
        # num_frames = (T x hop_size * fps) / sample_rate
        start_frame_num = start_frame
        # 80 mel frames per second: maps the video frame index to a mel-spectrogram index.
        start_idx = int(80. * (start_frame_num / float(self.hp.fps)))

        end_idx = start_idx + int(self.hp.syncnet_mel_step_size)

        spec = spec[start_idx:end_idx, :]
        return spec

    def __get_choosen(self, image_names):
        img_name = random.choice(image_names)
        wrong_img_name = random.choice(image_names)
        while wrong_img_name == img_name:
            wrong_img_name = random.choice(image_names)

        # Half of the samples are positive (in-sync) pairs, half are negative ones.
        if random.choice([True, False]):
            y = torch.ones(1, dtype=torch.float)
            choosen = img_name
        else:
            y = torch.zeros(1, dtype=torch.float)
            choosen = wrong_img_name
        return img_name, choosen, y

    def __get_segment_mel(self, img_dir, choosen):
        wavfile = self.data_dir + '/' + img_dir + '/audio.wav'
        try:
            wavform, sf = torchaudio.load(wavfile)
            resample = torchaudio.transforms.Resample(sf, 16000)
            wavform = resample(wavform)
            wavform = F.preemphasis(wavform, self.hp.preemphasis)
            specgram = torchaudio.transforms.MelSpectrogram(sample_rate=self.hp.sample_rate,
                                                            n_fft=self.hp.n_fft,
                                                            power=1.,
                                                            hop_length=self.hp.hop_size,
                                                            win_length=self.hp.win_size,
                                                            f_min=self.hp.fmin,
                                                            f_max=self.hp.fmax,
                                                            n_mels=self.hp.num_mels,
                                                            normalized=self.hp.signal_normalization)
            orig_mel = specgram(wavform)
            orig_mel = F.amplitude_to_DB(orig_mel, multiplier=10., amin=self.hp.min_level_db,
                                         db_multiplier=self.hp.ref_level_db, top_db=100)
            orig_mel = torch.mean(orig_mel, dim=0)
            orig_mel = orig_mel.t().numpy()
            spec = self.__crop_audio_window(orig_mel.copy(), int(choosen))
            # spec = self.__normalization(spec)
        except Exception as e:
            print("Mel transform exception: {}".format(e))
            spec = None

        return spec
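
The hard-coded factor of 80. in __crop_audio_window is the number of mel frames per second:
the audio is resampled to 16 kHz in the dataset code, and with a Wav2Lip-style hop_size of
200 samples (an assumption here; the actual value comes from ParamsUtil / the training config)
that gives 16000 / 200 = 80 mel frames per second. A quick sanity check of the indexing, with
fps = 25 and syncnet_mel_step_size = 16 also assumed:

    sample_rate = 16000      # fixed by the Resample call in the dataset
    hop_size = 200           # assumed
    fps = 25                 # assumed
    mel_step_size = 16       # assumed syncnet_mel_step_size

    mels_per_second = sample_rate / hop_size                # 80.0
    start_frame = 30                                        # e.g. video frame 30
    start_idx = int(mels_per_second * start_frame / fps)    # 96  -> mel frame at t = 1.2 s
    end_idx = start_idx + mel_step_size                     # 112 -> ~0.2 s of audio context

    print(mels_per_second, start_idx, end_idx)              # 80.0 96 112

So each sample pairs syncnet_T face crops with roughly 0.2 seconds of mel spectrogram starting
at the same timestamp; the label y from __get_choosen says whether the crops really come from
that timestamp (y = 1) or from a randomly shifted position in the clip (y = 0).
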
-------------------------------------------------------------------------------- /wldatasets/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/rogerle/wav2lipup/7149a5fb30d52f1af7773dccc991c42acf59ac77/wldatasets/__init__.py
--------------------------------------------------------------------------------
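
Both datasets plug into a standard PyTorch DataLoader. The snippet below is only an illustrative
sketch: the data_dir path, the batch sizes, img_size = 96, and the shape comments (which assume
num_mels = 80 and syncnet_mel_step_size = 16) are assumptions, the class names are taken from the
module names, and the actual training loops live in trains/syncnet_train.py and trains/wl_train.py.

    from torch.utils.data import DataLoader

    from wldatasets.FaceDataset import FaceDataset
    from wldatasets.SyncNetDataset import SyncNetDataset

    # data_dir is expected to hold train.txt / eval.txt plus one directory per clip,
    # each containing numbered face crops (0.jpg, 1.jpg, ...) and an audio.wav file.
    data_dir = 'data/preProcessed_data'   # assumed path, matching the repo layout

    face_ds = FaceDataset(data_dir, run_type='train', img_size=96)
    sync_ds = SyncNetDataset(data_dir, run_type='train')

    face_loader = DataLoader(face_ds, batch_size=8, shuffle=True, num_workers=2)
    sync_loader = DataLoader(sync_ds, batch_size=16, shuffle=True, num_workers=2)

    x, indiv_mels, mel, y = next(iter(face_loader))   # (B, 6, T, H, W), (B, T, 1, 80, 16), (B, 1, 80, 16), (B, 3, T, H, W)
    xs, mels, ys = next(iter(sync_loader))            # (B, 3*T, H/2, W), (B, 1, 80, 16), (B, 1)
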