├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── cog.yaml ├── debug_api.py ├── examples ├── driven_audio │ ├── RD_Radio31_000.wav │ ├── RD_Radio34_002.wav │ ├── RD_Radio36_000.wav │ ├── RD_Radio40_000.wav │ ├── bus_chinese.wav │ ├── chinese_news.wav │ ├── chinese_poem1.wav │ ├── chinese_poem2.wav │ ├── deyu.wav │ ├── eluosi.wav │ ├── fayu.wav │ ├── imagine.wav │ ├── intro.wav │ ├── itosinger1.wav │ ├── japanese.wav │ └── short.wav ├── ref_video │ ├── WDA_AlexandriaOcasioCortez_000.mp4 │ └── WDA_KatieHill_000.mp4 └── source_image │ ├── art_0.png │ ├── art_1.png │ ├── art_10.png │ ├── art_11.png │ ├── art_12.png │ ├── art_13.png │ ├── art_14.png │ ├── art_15.png │ ├── art_16.png │ ├── art_17.png │ ├── art_18.png │ ├── art_19.png │ ├── art_2.png │ ├── art_20.png │ ├── art_3.png │ ├── art_4.png │ ├── art_5.png │ ├── art_6.png │ ├── art_7.png │ ├── art_8.png │ ├── art_9.png │ ├── full3.png │ ├── full4.jpeg │ ├── full_body_1.png │ ├── full_body_2.png │ ├── happy.png │ ├── happy1.png │ ├── people_0.png │ ├── sad.png │ ├── sad1.png │ └── test.png ├── inference.py ├── interpolate.py ├── launcher.py ├── predict.py ├── quick_demo.ipynb ├── req.txt ├── sadtalker_api.py ├── scripts ├── download_models.sh ├── extension.py └── test.sh ├── silent.wav ├── src ├── audio2exp_models │ ├── audio2exp.py │ └── networks.py ├── audio2pose_models │ ├── audio2pose.py │ ├── audio_encoder.py │ ├── cvae.py │ ├── discriminator.py │ ├── networks.py │ └── res_unet.py ├── config │ ├── auido2exp.yaml │ ├── auido2pose.yaml │ ├── facerender.yaml │ ├── facerender_still.yaml │ └── similarity_Lm3D_all.mat ├── face3d │ ├── data │ │ ├── __init__.py │ │ ├── base_dataset.py │ │ ├── flist_dataset.py │ │ ├── image_folder.py │ │ └── template_dataset.py │ ├── extract_kp_videos.py │ ├── extract_kp_videos_safe.py │ ├── models │ │ ├── __init__.py │ │ ├── arcface_torch │ │ │ ├── README.md │ │ │ ├── backbones │ │ │ │ ├── __init__.py │ │ │ │ ├── iresnet.py │ │ │ │ ├── iresnet2060.py │ │ │ │ └── mobilefacenet.py │ │ │ ├── configs │ │ │ │ ├── 3millions.py │ │ │ │ ├── 3millions_pfc.py │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── glint360k_mbf.py │ │ │ │ ├── glint360k_r100.py │ │ │ │ ├── glint360k_r18.py │ │ │ │ ├── glint360k_r34.py │ │ │ │ ├── glint360k_r50.py │ │ │ │ ├── ms1mv3_mbf.py │ │ │ │ ├── ms1mv3_r18.py │ │ │ │ ├── ms1mv3_r2060.py │ │ │ │ ├── ms1mv3_r34.py │ │ │ │ ├── ms1mv3_r50.py │ │ │ │ └── speed.py │ │ │ ├── dataset.py │ │ │ ├── docs │ │ │ │ ├── eval.md │ │ │ │ ├── install.md │ │ │ │ ├── modelzoo.md │ │ │ │ └── speed_benchmark.md │ │ │ ├── eval │ │ │ │ ├── __init__.py │ │ │ │ └── verification.py │ │ │ ├── eval_ijbc.py │ │ │ ├── inference.py │ │ │ ├── losses.py │ │ │ ├── onnx_helper.py │ │ │ ├── onnx_ijbc.py │ │ │ ├── partial_fc.py │ │ │ ├── requirement.txt │ │ │ ├── run.sh │ │ │ ├── torch2onnx.py │ │ │ ├── train.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── plot.py │ │ │ │ ├── utils_amp.py │ │ │ │ ├── utils_callbacks.py │ │ │ │ ├── utils_config.py │ │ │ │ ├── utils_logging.py │ │ │ │ └── utils_os.py │ │ ├── base_model.py │ │ ├── bfm.py │ │ ├── facerecon_model.py │ │ ├── losses.py │ │ ├── networks.py │ │ └── template_model.py │ ├── options │ │ ├── __init__.py │ │ ├── base_options.py │ │ ├── inference_options.py │ │ ├── test_options.py │ │ └── train_options.py │ ├── util │ │ ├── BBRegressorParam_r.mat │ │ ├── __init__.py │ │ ├── detect_lm68.py │ │ ├── generate_list.py │ │ ├── html.py │ │ ├── load_mats.py │ │ ├── nvdiffrast.py │ │ ├── preprocess.py │ │ ├── skin_mask.py │ │ ├── test_mean_face.txt │ │ ├── 
util.py │ │ └── visualizer.py │ └── visualize.py ├── facerender │ ├── animate.py │ ├── modules │ │ ├── dense_motion.py │ │ ├── discriminator.py │ │ ├── generator.py │ │ ├── keypoint_detector.py │ │ ├── make_animation.py │ │ ├── mapping.py │ │ └── util.py │ └── sync_batchnorm │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py ├── generate_batch.py ├── generate_facerender_batch.py ├── gradio_demo.py ├── test_audio2coeff copy.py ├── test_audio2coeff.py └── utils │ ├── append_audio.py │ ├── audio.py │ ├── croper.py │ ├── face_enhancer.py │ ├── hparams.py │ ├── init_path.py │ ├── model2safetensor.py │ ├── paste_pic.py │ ├── preprocess.py │ ├── safetensor_helper.py │ ├── text2speech.py │ └── videoio.py └── test_curl.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 
109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | examples/results/* 163 | gfpgan/* 164 | checkpoints/* 165 | assets/* 166 | results/* 167 | Dockerfile 168 | start_docker.sh 169 | start.sh 170 | InterpolatedVideo/* 171 | 172 | # Mac 173 | .DS_Store 174 | 175 | # VsCode 176 | .vscode 177 | __pycache__ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Tencent AI Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | gpu: true 3 | cuda: "11.3" 4 | python_version: "3.8" 5 | system_packages: 6 | - "ffmpeg" 7 | - "libgl1-mesa-glx" 8 | - "libglib2.0-0" 9 | python_packages: 10 | - "torch==1.12.1" 11 | - "torchvision==0.13.1" 12 | - "torchaudio==0.12.1" 13 | - "joblib==1.1.0" 14 | - "scikit-image==0.19.3" 15 | - "basicsr==1.4.2" 16 | - "facexlib==0.3.0" 17 | - "resampy==0.3.1" 18 | - "pydub==0.25.1" 19 | - "scipy==1.10.1" 20 | - "kornia==0.6.8" 21 | - "face_alignment==1.3.5" 22 | - "imageio==2.19.3" 23 | - "imageio-ffmpeg==0.4.7" 24 | - "librosa==0.9.2" # 25 | - "tqdm==4.65.0" 26 | - "yacs==0.1.8" 27 | - "gfpgan==1.3.8" 28 | - "dlib-bin==19.24.1" 29 | - "av==10.0.0" 30 | - "trimesh==3.9.20" 31 | run: 32 | - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/s3fd-619a316812.pth" "https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth" 33 | - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/2DFAN4-cd938726ad.zip" "https://www.adrianbulat.com/downloads/python-fan/2DFAN4-cd938726ad.zip" 34 | 35 | predict: "predict.py:Predictor" 36 | -------------------------------------------------------------------------------- /examples/driven_audio/RD_Radio31_000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/RD_Radio31_000.wav -------------------------------------------------------------------------------- /examples/driven_audio/RD_Radio34_002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/RD_Radio34_002.wav -------------------------------------------------------------------------------- /examples/driven_audio/RD_Radio36_000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/RD_Radio36_000.wav -------------------------------------------------------------------------------- /examples/driven_audio/RD_Radio40_000.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/RD_Radio40_000.wav -------------------------------------------------------------------------------- /examples/driven_audio/bus_chinese.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/bus_chinese.wav -------------------------------------------------------------------------------- /examples/driven_audio/chinese_news.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/chinese_news.wav -------------------------------------------------------------------------------- /examples/driven_audio/chinese_poem1.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/chinese_poem1.wav -------------------------------------------------------------------------------- /examples/driven_audio/chinese_poem2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/chinese_poem2.wav -------------------------------------------------------------------------------- /examples/driven_audio/deyu.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/deyu.wav -------------------------------------------------------------------------------- /examples/driven_audio/eluosi.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/eluosi.wav -------------------------------------------------------------------------------- /examples/driven_audio/fayu.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/fayu.wav -------------------------------------------------------------------------------- /examples/driven_audio/imagine.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/imagine.wav -------------------------------------------------------------------------------- /examples/driven_audio/intro.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/intro.wav -------------------------------------------------------------------------------- /examples/driven_audio/itosinger1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/itosinger1.wav -------------------------------------------------------------------------------- /examples/driven_audio/japanese.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/japanese.wav -------------------------------------------------------------------------------- /examples/driven_audio/short.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/driven_audio/short.wav -------------------------------------------------------------------------------- /examples/ref_video/WDA_AlexandriaOcasioCortez_000.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/ref_video/WDA_AlexandriaOcasioCortez_000.mp4 -------------------------------------------------------------------------------- /examples/ref_video/WDA_KatieHill_000.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/ref_video/WDA_KatieHill_000.mp4 -------------------------------------------------------------------------------- /examples/source_image/art_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_0.png -------------------------------------------------------------------------------- /examples/source_image/art_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_1.png -------------------------------------------------------------------------------- /examples/source_image/art_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_10.png -------------------------------------------------------------------------------- /examples/source_image/art_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_11.png -------------------------------------------------------------------------------- /examples/source_image/art_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_12.png -------------------------------------------------------------------------------- /examples/source_image/art_13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_13.png -------------------------------------------------------------------------------- /examples/source_image/art_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_14.png -------------------------------------------------------------------------------- /examples/source_image/art_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_15.png -------------------------------------------------------------------------------- /examples/source_image/art_16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_16.png 
-------------------------------------------------------------------------------- /examples/source_image/art_17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_17.png -------------------------------------------------------------------------------- /examples/source_image/art_18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_18.png -------------------------------------------------------------------------------- /examples/source_image/art_19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_19.png -------------------------------------------------------------------------------- /examples/source_image/art_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_2.png -------------------------------------------------------------------------------- /examples/source_image/art_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_20.png -------------------------------------------------------------------------------- /examples/source_image/art_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_3.png -------------------------------------------------------------------------------- /examples/source_image/art_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_4.png -------------------------------------------------------------------------------- /examples/source_image/art_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_5.png -------------------------------------------------------------------------------- /examples/source_image/art_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_6.png -------------------------------------------------------------------------------- /examples/source_image/art_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_7.png -------------------------------------------------------------------------------- /examples/source_image/art_8.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_8.png -------------------------------------------------------------------------------- /examples/source_image/art_9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/art_9.png -------------------------------------------------------------------------------- /examples/source_image/full3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/full3.png -------------------------------------------------------------------------------- /examples/source_image/full4.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/full4.jpeg -------------------------------------------------------------------------------- /examples/source_image/full_body_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/full_body_1.png -------------------------------------------------------------------------------- /examples/source_image/full_body_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/full_body_2.png -------------------------------------------------------------------------------- /examples/source_image/happy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/happy.png -------------------------------------------------------------------------------- /examples/source_image/happy1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/happy1.png -------------------------------------------------------------------------------- /examples/source_image/people_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/people_0.png -------------------------------------------------------------------------------- /examples/source_image/sad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/sad.png -------------------------------------------------------------------------------- /examples/source_image/sad1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/sad1.png -------------------------------------------------------------------------------- 
/examples/source_image/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/examples/source_image/test.png -------------------------------------------------------------------------------- /interpolate.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from moviepy.editor import VideoFileClip, ImageSequenceClip, concatenate_videoclips 4 | import time 5 | import subprocess 6 | 7 | def interpolate_frames(prev_frame, next_frame, factor=2): 8 | t = np.linspace(0, 1, factor+2)[1:-1] 9 | interpolated_frames = [] 10 | for i in range(factor): 11 | alpha = t[i] 12 | beta = 1 - alpha 13 | interpolated_frame = cv2.addWeighted(prev_frame, beta, next_frame, alpha, 0) # blend the two frames with cv2.addWeighted 14 | interpolated_frames.append(interpolated_frame) 15 | return interpolated_frames 16 | 17 | def interpolate_video(video_path, output_path, factor=2): 18 | clip = VideoFileClip(video_path) 19 | fps = clip.fps 20 | width, height = clip.size 21 | 22 | new_frames = [] 23 | prev_frame = None 24 | for frame in clip.iter_frames(dtype='uint8'): 25 | if prev_frame is not None: 26 | interpolated_frames = interpolate_frames(prev_frame, frame, factor) 27 | new_frames.extend(interpolated_frames) 28 | 29 | prev_frame = frame 30 | 31 | new_clip = ImageSequenceClip(new_frames, fps=factor * fps) 32 | new_clip = new_clip.resize((width, height)) 33 | new_clip = new_clip.set_audio(clip.audio) 34 | 35 | new_clip.write_videofile(output_path, codec='libx264', audio_codec='aac') 36 | 37 | def optical_flow_interpolation(input_path, output_path, factor): # optical_flow method 38 | cap = cv2.VideoCapture(input_path) 39 | if not cap.isOpened(): 40 | print("Error: Unable to open video file.") 41 | return 42 | 43 | fps = cap.get(cv2.CAP_PROP_FPS) 44 | new_fps = int(fps * factor) 45 | frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 46 | frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 47 | 48 | frames = [] # List to store interpolated frames 49 | 50 | prev_frame = None 51 | 52 | while True: 53 | ret, frame = cap.read() 54 | if not ret: 55 | break 56 | 57 | 58 | if prev_frame is not None: 59 | prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY) 60 | gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 61 | flow = cv2.calcOpticalFlowFarneback(prev_gray, gray, None, 0.5, 3, 15, 3, 5, 1.2, 0) 62 | 63 | h, w = gray.shape[:2] 64 | x, y = np.meshgrid(np.arange(w), np.arange(h)) 65 | x_flow = x + flow[..., 0] 66 | y_flow = y + flow[..., 1] 67 | 68 | for i in range(new_fps // int(fps)): 69 | alpha = (i + 1) / (new_fps // int(fps) + 1) 70 | 71 | x_interpolated = x + alpha * flow[..., 0] 72 | y_interpolated = y + alpha * flow[..., 1] 73 | 74 | # Interpolate the frame using remap 75 | interpolated_frame = cv2.remap(prev_frame, x_interpolated.astype(np.float32), y_interpolated.astype(np.float32), cv2.INTER_LINEAR) 76 | 77 | # Convert the frame back to BGR color 78 | interpolated_frame_bgr = cv2.cvtColor(interpolated_frame, cv2.COLOR_RGB2BGR) 79 | 80 | # Append the interpolated frame to the list of frames 81 | frames.append(interpolated_frame_bgr) 82 | 83 | prev_frame = frame 84 | 85 | cap.release() 86 | cv2.destroyAllWindows() 87 | 88 | # Create an ImageSequenceClip from the list of interpolated frames 89 | clip = ImageSequenceClip(frames, fps=new_fps) 90 | 91 | # Load the original audio 92 | audio_clip = VideoFileClip(input_path).audio 93
| 94 | # Set the audio for the final video clip 95 | final_clip = clip.set_audio(audio_clip) 96 | 97 | # Write the final video to the output file 98 | final_clip.write_videofile(output_path, codec='libx264', audio_codec='aac', remove_temp=True) 99 | 100 | 101 | def interpolate_frames_by_ffmpeg(video_path, output_path, factor=2): 102 | clip = VideoFileClip(video_path) 103 | fps = clip.fps 104 | out_fps = fps * factor 105 | command = f"""ffmpeg -i {video_path} -filter_complex "minterpolate='fps={out_fps}'" {output_path}""" 106 | subprocess.run(command, shell=True) 107 | 108 | 109 | if __name__ == "__main__": 110 | fps = 8 111 | factor = 4 112 | video_path = "./videos/fps_{}.mp4".format(fps) 113 | output_path = "./TVFIResults/fps_{}_{}X.mp4".format(fps, factor) 114 | start = time.time() 115 | interpolate_video(video_path, output_path, factor=factor) 116 | end = time.time() 117 | print('TVFI Time: {}'.format(end - start)) -------------------------------------------------------------------------------- /req.txt: -------------------------------------------------------------------------------- 1 | llvmlite==0.38.1 2 | numpy==1.21.6 3 | face_alignment==1.3.5 4 | imageio==2.19.3 5 | imageio-ffmpeg==0.4.7 6 | librosa==0.10.0.post2 7 | numba==0.55.1 8 | resampy==0.3.1 9 | pydub==0.25.1 10 | scipy==1.10.1 11 | kornia==0.6.8 12 | tqdm 13 | yacs==0.1.8 14 | pyyaml 15 | joblib==1.1.0 16 | scikit-image==0.19.3 17 | basicsr==1.4.2 18 | facexlib==0.3.0 19 | gradio 20 | gfpgan 21 | av 22 | safetensors 23 | -------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | mkdir ./checkpoints 2 | 3 | # legacy download links 4 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth 5 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth 6 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth 7 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar 8 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat 9 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth 10 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar 11 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar 12 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip 13 | # unzip -n ./checkpoints/hub.zip -d ./checkpoints/ 14 | 15 | 16 | #### download the new links.
17 | wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar 18 | wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar 19 | wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_256.safetensors -O ./checkpoints/SadTalker_V0.0.2_256.safetensors 20 | wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_512.safetensors -O ./checkpoints/SadTalker_V0.0.2_512.safetensors 21 | 22 | 23 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip 24 | # unzip -n ./checkpoints/BFM_Fitting.zip -d ./checkpoints/ 25 | 26 | ### enhancer 27 | mkdir -p ./gfpgan/weights 28 | wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth -O ./gfpgan/weights/alignment_WFLW_4HG.pth 29 | wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth -O ./gfpgan/weights/detection_Resnet50_Final.pth 30 | wget -nc https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -O ./gfpgan/weights/GFPGANv1.4.pth 31 | wget -nc https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth -O ./gfpgan/weights/parsing_parsenet.pth 32 | 33 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | # ### some test command before commit. 2 | # python inference.py --preprocess crop --size 256 3 | # python inference.py --preprocess crop --size 512 4 | 5 | # python inference.py --preprocess extcrop --size 256 6 | # python inference.py --preprocess extcrop --size 512 7 | 8 | # python inference.py --preprocess resize --size 256 9 | # python inference.py --preprocess resize --size 512 10 | 11 | # python inference.py --preprocess full --size 256 12 | # python inference.py --preprocess full --size 512 13 | 14 | # python inference.py --preprocess extfull --size 256 15 | # python inference.py --preprocess extfull --size 512 16 | 17 | python inference.py --preprocess full --size 256 --enhancer gfpgan 18 | python inference.py --preprocess full --size 512 --enhancer gfpgan 19 | 20 | python inference.py --preprocess full --size 256 --enhancer gfpgan --still 21 | python inference.py --preprocess full --size 512 --enhancer gfpgan --still 22 | -------------------------------------------------------------------------------- /silent.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/silent.wav -------------------------------------------------------------------------------- /src/audio2exp_models/audio2exp.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class Audio2Exp(nn.Module): 7 | def __init__(self, netG, cfg, device, prepare_training_loss=False): 8 | super(Audio2Exp, self).__init__() 9 | self.cfg = cfg 10 | self.device = device 11 | self.netG = netG.to(device) 12 | 13 | def test(self, batch): 14 | 15 | mel_input = batch['indiv_mels'] # bs T 1 80 16 16 | bs = mel_input.shape[0] 17 | T = mel_input.shape[1] 18 | 19 | exp_coeff_pred = [] 
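# Note on the loop below: the mel-spectrogram is processed in 10-frame windows.
# Each window is reshaped to (bs*10, 1, 80, 16) and fed through netG together with
# the matching slice of the reference expression (first 64 coefficients) and the
# ratio_gt signal; the per-window predictions are concatenated along the time axis.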
20 | 21 | for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames 22 | 23 | current_mel_input = mel_input[:,i:i+10] 24 | 25 | #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64 26 | ref = batch['ref'][:, :, :64][:, i:i+10] 27 | ratio = batch['ratio_gt'][:, i:i+10] #bs T 28 | 29 | audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16 30 | 31 | curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64 32 | 33 | exp_coeff_pred += [curr_exp_coeff_pred] 34 | 35 | # BS x T x 64 36 | results_dict = { 37 | 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1) 38 | } 39 | return results_dict 40 | 41 | 42 | -------------------------------------------------------------------------------- /src/audio2exp_models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | class Conv2d(nn.Module): 6 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | self.conv_block = nn.Sequential( 9 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 10 | nn.BatchNorm2d(cout) 11 | ) 12 | self.act = nn.ReLU() 13 | self.residual = residual 14 | self.use_act = use_act 15 | 16 | def forward(self, x): 17 | out = self.conv_block(x) 18 | if self.residual: 19 | out += x 20 | 21 | if self.use_act: 22 | return self.act(out) 23 | else: 24 | return out 25 | 26 | class SimpleWrapperV2(nn.Module): 27 | def __init__(self) -> None: 28 | super().__init__() 29 | self.audio_encoder = nn.Sequential( 30 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 31 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 32 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 33 | 34 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 35 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 36 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 37 | 38 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1), 39 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 40 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 41 | 42 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 43 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 44 | 45 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0), 46 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0), 47 | ) 48 | 49 | #### load the pre-trained audio_encoder 50 | #self.audio_encoder = self.audio_encoder.to(device) 51 | ''' 52 | wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict'] 53 | state_dict = self.audio_encoder.state_dict() 54 | 55 | for k,v in wav2lip_state_dict.items(): 56 | if 'audio_encoder' in k: 57 | print('init:', k) 58 | state_dict[k.replace('module.audio_encoder.', '')] = v 59 | self.audio_encoder.load_state_dict(state_dict) 60 | ''' 61 | 62 | self.mapping1 = nn.Linear(512+64+1, 64) 63 | #self.mapping2 = nn.Linear(30, 64) 64 | #nn.init.constant_(self.mapping1.weight, 0.) 65 | nn.init.constant_(self.mapping1.bias, 0.) 
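# mapping1 projects the concatenated (512-d audio feature, 64-d reference expression,
# 1-d ratio) vector down to 64 expression coefficients per frame; see forward() below.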
66 | 67 | def forward(self, x, ref, ratio): 68 | # x = self.audio_encoder(x).view(x.size(0), -1) 69 | # ref_reshape = ref.reshape(x.size(0), -1) 70 | # ratio = ratio.reshape(x.size(0), -1) 71 | 72 | # y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1)) 73 | # out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial 74 | # return out 75 | x1 = self.audio_encoder(x) 76 | x = x1.view(x1.size(0), -1) 77 | ref_reshape = ref.reshape(x.size(0), -1) 78 | ratio = ratio.reshape(x.size(0), -1) 79 | cat = torch.cat([x, ref_reshape, ratio], dim=1) 80 | y = self.mapping1(cat) 81 | out = y.reshape(ref.shape[0], ref.shape[1], -1) 82 | return out 83 | -------------------------------------------------------------------------------- /src/audio2pose_models/audio2pose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from src.audio2pose_models.cvae import CVAE 4 | from src.audio2pose_models.discriminator import PoseSequenceDiscriminator 5 | from src.audio2pose_models.audio_encoder import AudioEncoder 6 | 7 | class Audio2Pose(nn.Module): 8 | def __init__(self, cfg, wav2lip_checkpoint, device='cuda'): 9 | super().__init__() 10 | self.cfg = cfg 11 | self.seq_len = cfg.MODEL.CVAE.SEQ_LEN 12 | self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE 13 | self.device = device 14 | 15 | self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device) 16 | self.audio_encoder.eval() 17 | for param in self.audio_encoder.parameters(): 18 | param.requires_grad = False 19 | 20 | self.netG = CVAE(cfg) 21 | self.netD_motion = PoseSequenceDiscriminator(cfg) 22 | 23 | 24 | def forward(self, x): 25 | 26 | batch = {} 27 | coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73 28 | batch['pose_motion_gt'] = coeff_gt[:, 1:, -9:-3] - coeff_gt[:, :1, -9:-3] #bs frame_len 6 29 | batch['ref'] = coeff_gt[:, 0, -9:-3] #bs 6 30 | batch['class'] = x['class'].squeeze(0).cuda() # bs 31 | indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16 32 | 33 | # forward 34 | audio_emb_list = [] 35 | audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512 36 | batch['audio_emb'] = audio_emb 37 | batch = self.netG(batch) 38 | 39 | pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6 40 | pose_gt = coeff_gt[:, 1:, -9:-3].clone() # bs frame_len 6 41 | pose_pred = coeff_gt[:, :1, -9:-3] + pose_motion_pred # bs frame_len 6 42 | 43 | batch['pose_pred'] = pose_pred 44 | batch['pose_gt'] = pose_gt 45 | 46 | return batch 47 | 48 | def test(self, x): 49 | 50 | batch = {} 51 | ref = x['ref'] #bs 1 70 52 | batch['ref'] = x['ref'][:,0,-6:] 53 | batch['class'] = x['class'] 54 | bs = ref.shape[0] 55 | 56 | indiv_mels= x['indiv_mels'] # bs T 1 80 16 57 | indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame 58 | num_frames = x['num_frames'] 59 | num_frames = int(num_frames) - 1 60 | 61 | # 62 | div = num_frames//self.seq_len 63 | re = num_frames%self.seq_len 64 | audio_emb_list = [] 65 | pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype, 66 | device=batch['ref'].device)] 67 | 68 | for i in range(div): 69 | z = torch.randn(bs, self.latent_dim).to(ref.device) 70 | batch['z'] = z 71 | audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512 72 | batch['audio_emb'] = audio_emb 73 | batch = self.netG.test(batch) 74 | pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6 75 | 76 | if re != 0: 77 | z = 
torch.randn(bs, self.latent_dim).to(ref.device) 78 | batch['z'] = z 79 | audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512 80 | if audio_emb.shape[1] != self.seq_len: 81 | pad_dim = self.seq_len-audio_emb.shape[1] 82 | pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1) 83 | audio_emb = torch.cat([pad_audio_emb, audio_emb], 1) 84 | batch['audio_emb'] = audio_emb 85 | batch = self.netG.test(batch) 86 | pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:]) 87 | 88 | pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1) 89 | batch['pose_motion_pred'] = pose_motion_pred 90 | 91 | pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6 92 | 93 | batch['pose_pred'] = pose_pred 94 | return batch 95 | -------------------------------------------------------------------------------- /src/audio2pose_models/audio_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | class Conv2d(nn.Module): 6 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | self.conv_block = nn.Sequential( 9 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 10 | nn.BatchNorm2d(cout) 11 | ) 12 | self.act = nn.ReLU() 13 | self.residual = residual 14 | 15 | def forward(self, x): 16 | out = self.conv_block(x) 17 | if self.residual: 18 | out += x 19 | return self.act(out) 20 | 21 | class AudioEncoder(nn.Module): 22 | def __init__(self, wav2lip_checkpoint, device): 23 | super(AudioEncoder, self).__init__() 24 | 25 | self.audio_encoder = nn.Sequential( 26 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 27 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 28 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 29 | 30 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 31 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 32 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 33 | 34 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1), 35 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 36 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 37 | 38 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 39 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 40 | 41 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0), 42 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 43 | 44 | #### load the pre-trained audio_encoder, we do not need to load wav2lip model here. 
45 | # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict'] 46 | # state_dict = self.audio_encoder.state_dict() 47 | 48 | # for k,v in wav2lip_state_dict.items(): 49 | # if 'audio_encoder' in k: 50 | # state_dict[k.replace('module.audio_encoder.', '')] = v 51 | # self.audio_encoder.load_state_dict(state_dict) 52 | 53 | 54 | def forward(self, audio_sequences): 55 | # audio_sequences = (B, T, 1, 80, 16) 56 | B = audio_sequences.size(0) 57 | 58 | audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) 59 | 60 | audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1 61 | dim = audio_embedding.shape[1] 62 | audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1)) 63 | 64 | return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512 65 | -------------------------------------------------------------------------------- /src/audio2pose_models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | class ConvNormRelu(nn.Module): 6 | def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False, 7 | kernel_size=None, stride=None, padding=None, norm='BN', leaky=False): 8 | super().__init__() 9 | if kernel_size is None: 10 | if downsample: 11 | kernel_size, stride, padding = 4, 2, 1 12 | else: 13 | kernel_size, stride, padding = 3, 1, 1 14 | 15 | if conv_type == '2d': 16 | self.conv = nn.Conv2d( 17 | in_channels, 18 | out_channels, 19 | kernel_size, 20 | stride, 21 | padding, 22 | bias=False, 23 | ) 24 | if norm == 'BN': 25 | self.norm = nn.BatchNorm2d(out_channels) 26 | elif norm == 'IN': 27 | self.norm = nn.InstanceNorm2d(out_channels) 28 | else: 29 | raise NotImplementedError 30 | elif conv_type == '1d': 31 | self.conv = nn.Conv1d( 32 | in_channels, 33 | out_channels, 34 | kernel_size, 35 | stride, 36 | padding, 37 | bias=False, 38 | ) 39 | if norm == 'BN': 40 | self.norm = nn.BatchNorm1d(out_channels) 41 | elif norm == 'IN': 42 | self.norm = nn.InstanceNorm1d(out_channels) 43 | else: 44 | raise NotImplementedError 45 | nn.init.kaiming_normal_(self.conv.weight) 46 | 47 | self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True) 48 | 49 | def forward(self, x): 50 | x = self.conv(x) 51 | if isinstance(self.norm, nn.InstanceNorm1d): 52 | x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C] 53 | else: 54 | x = self.norm(x) 55 | x = self.act(x) 56 | return x 57 | 58 | 59 | class PoseSequenceDiscriminator(nn.Module): 60 | def __init__(self, cfg): 61 | super().__init__() 62 | self.cfg = cfg 63 | leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU 64 | 65 | self.seq = nn.Sequential( 66 | ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64 67 | ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32 68 | ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16 69 | nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16 70 | ) 71 | 72 | def forward(self, x): 73 | x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2) 74 | x = self.seq(x) 75 | x = x.squeeze(1) 76 | return x -------------------------------------------------------------------------------- /src/audio2pose_models/networks.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class ResidualConv(nn.Module): 6 | def __init__(self, input_dim, output_dim, stride, padding): 7 | super(ResidualConv, self).__init__() 8 | 9 | self.conv_block = nn.Sequential( 10 | nn.BatchNorm2d(input_dim), 11 | nn.ReLU(), 12 | nn.Conv2d( 13 | input_dim, output_dim, kernel_size=3, stride=stride, padding=padding 14 | ), 15 | nn.BatchNorm2d(output_dim), 16 | nn.ReLU(), 17 | nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), 18 | ) 19 | self.conv_skip = nn.Sequential( 20 | nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), 21 | nn.BatchNorm2d(output_dim), 22 | ) 23 | 24 | def forward(self, x): 25 | 26 | return self.conv_block(x) + self.conv_skip(x) 27 | 28 | 29 | class Upsample(nn.Module): 30 | def __init__(self, input_dim, output_dim, kernel, stride): 31 | super(Upsample, self).__init__() 32 | 33 | self.upsample = nn.ConvTranspose2d( 34 | input_dim, output_dim, kernel_size=kernel, stride=stride 35 | ) 36 | 37 | def forward(self, x): 38 | return self.upsample(x) 39 | 40 | 41 | class Squeeze_Excite_Block(nn.Module): 42 | def __init__(self, channel, reduction=16): 43 | super(Squeeze_Excite_Block, self).__init__() 44 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 45 | self.fc = nn.Sequential( 46 | nn.Linear(channel, channel // reduction, bias=False), 47 | nn.ReLU(inplace=True), 48 | nn.Linear(channel // reduction, channel, bias=False), 49 | nn.Sigmoid(), 50 | ) 51 | 52 | def forward(self, x): 53 | b, c, _, _ = x.size() 54 | y = self.avg_pool(x).view(b, c) 55 | y = self.fc(y).view(b, c, 1, 1) 56 | return x * y.expand_as(x) 57 | 58 | 59 | class ASPP(nn.Module): 60 | def __init__(self, in_dims, out_dims, rate=[6, 12, 18]): 61 | super(ASPP, self).__init__() 62 | 63 | self.aspp_block1 = nn.Sequential( 64 | nn.Conv2d( 65 | in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0] 66 | ), 67 | nn.ReLU(inplace=True), 68 | nn.BatchNorm2d(out_dims), 69 | ) 70 | self.aspp_block2 = nn.Sequential( 71 | nn.Conv2d( 72 | in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1] 73 | ), 74 | nn.ReLU(inplace=True), 75 | nn.BatchNorm2d(out_dims), 76 | ) 77 | self.aspp_block3 = nn.Sequential( 78 | nn.Conv2d( 79 | in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2] 80 | ), 81 | nn.ReLU(inplace=True), 82 | nn.BatchNorm2d(out_dims), 83 | ) 84 | 85 | self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1) 86 | self._init_weights() 87 | 88 | def forward(self, x): 89 | x1 = self.aspp_block1(x) 90 | x2 = self.aspp_block2(x) 91 | x3 = self.aspp_block3(x) 92 | out = torch.cat([x1, x2, x3], dim=1) 93 | return self.output(out) 94 | 95 | def _init_weights(self): 96 | for m in self.modules(): 97 | if isinstance(m, nn.Conv2d): 98 | nn.init.kaiming_normal_(m.weight) 99 | elif isinstance(m, nn.BatchNorm2d): 100 | m.weight.data.fill_(1) 101 | m.bias.data.zero_() 102 | 103 | 104 | class Upsample_(nn.Module): 105 | def __init__(self, scale=2): 106 | super(Upsample_, self).__init__() 107 | 108 | self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale) 109 | 110 | def forward(self, x): 111 | return self.upsample(x) 112 | 113 | 114 | class AttentionBlock(nn.Module): 115 | def __init__(self, input_encoder, input_decoder, output_dim): 116 | super(AttentionBlock, self).__init__() 117 | 118 | self.conv_encoder = nn.Sequential( 119 | nn.BatchNorm2d(input_encoder), 120 | nn.ReLU(), 121 | nn.Conv2d(input_encoder, output_dim, 3, padding=1), 
122 | nn.MaxPool2d(2, 2), 123 | ) 124 | 125 | self.conv_decoder = nn.Sequential( 126 | nn.BatchNorm2d(input_decoder), 127 | nn.ReLU(), 128 | nn.Conv2d(input_decoder, output_dim, 3, padding=1), 129 | ) 130 | 131 | self.conv_attn = nn.Sequential( 132 | nn.BatchNorm2d(output_dim), 133 | nn.ReLU(), 134 | nn.Conv2d(output_dim, 1, 1), 135 | ) 136 | 137 | def forward(self, x1, x2): 138 | out = self.conv_encoder(x1) + self.conv_decoder(x2) 139 | out = self.conv_attn(out) 140 | return out * x2 -------------------------------------------------------------------------------- /src/audio2pose_models/res_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from src.audio2pose_models.networks import ResidualConv, Upsample 4 | 5 | 6 | class ResUnet(nn.Module): 7 | def __init__(self, channel=1, filters=[32, 64, 128, 256]): 8 | super(ResUnet, self).__init__() 9 | 10 | self.input_layer = nn.Sequential( 11 | nn.Conv2d(channel, filters[0], kernel_size=3, padding=1), 12 | nn.BatchNorm2d(filters[0]), 13 | nn.ReLU(), 14 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), 15 | ) 16 | self.input_skip = nn.Sequential( 17 | nn.Conv2d(channel, filters[0], kernel_size=3, padding=1) 18 | ) 19 | 20 | self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1) 21 | self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1) 22 | 23 | self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1) 24 | 25 | self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1)) 26 | self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1) 27 | 28 | self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1)) 29 | self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1) 30 | 31 | self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1)) 32 | self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1) 33 | 34 | self.output_layer = nn.Sequential( 35 | nn.Conv2d(filters[0], 1, 1, 1), 36 | nn.Sigmoid(), 37 | ) 38 | 39 | def forward(self, x): 40 | # Encode 41 | # x1 = self.input_layer(x) + self.input_skip(x) 42 | x_l = self.input_layer(x) 43 | x_s = self.input_skip(x) 44 | x1 = x_l + x_s 45 | x2 = self.residual_conv_1(x1) 46 | x3 = self.residual_conv_2(x2) 47 | # Bridge 48 | x4 = self.bridge(x3) 49 | 50 | # Decode 51 | x4 = self.upsample_1(x4) 52 | x5 = torch.cat([x4, x3], dim=1) 53 | 54 | x6 = self.up_residual_conv1(x5) 55 | 56 | x6 = self.upsample_2(x6) 57 | x7 = torch.cat([x6, x2], dim=1) 58 | 59 | x8 = self.up_residual_conv2(x7) 60 | 61 | x8 = self.upsample_3(x8) 62 | x9 = torch.cat([x8, x1], dim=1) 63 | 64 | x10 = self.up_residual_conv3(x9) 65 | 66 | output = self.output_layer(x10) 67 | 68 | return output -------------------------------------------------------------------------------- /src/config/auido2exp.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt 3 | EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt 4 | TRAIN_BATCH_SIZE: 32 5 | EVAL_BATCH_SIZE: 32 6 | EXP: True 7 | EXP_DIM: 64 8 | FRAME_LEN: 32 9 | COEFF_LEN: 73 10 | NUM_CLASSES: 46 11 | AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav 12 | COEFF_ROOT_PATH: 
/apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm 13 | LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb 14 | DEBUG: True 15 | NUM_REPEATS: 2 16 | T: 40 17 | 18 | 19 | MODEL: 20 | FRAMEWORK: V2 21 | AUDIOENCODER: 22 | LEAKY_RELU: True 23 | NORM: 'IN' 24 | DISCRIMINATOR: 25 | LEAKY_RELU: False 26 | INPUT_CHANNELS: 6 27 | CVAE: 28 | AUDIO_EMB_IN_SIZE: 512 29 | AUDIO_EMB_OUT_SIZE: 128 30 | SEQ_LEN: 32 31 | LATENT_SIZE: 256 32 | ENCODER_LAYER_SIZES: [192, 1024] 33 | DECODER_LAYER_SIZES: [1024, 192] 34 | 35 | 36 | TRAIN: 37 | MAX_EPOCH: 300 38 | GENERATOR: 39 | LR: 2.0e-5 40 | DISCRIMINATOR: 41 | LR: 1.0e-5 42 | LOSS: 43 | W_FEAT: 0 44 | W_COEFF_EXP: 2 45 | W_LM: 1.0e-2 46 | W_LM_MOUTH: 0 47 | W_REG: 0 48 | W_SYNC: 0 49 | W_COLOR: 0 50 | W_EXPRESSION: 0 51 | W_LIPREADING: 0.01 52 | W_LIPREADING_VV: 0 53 | W_EYE_BLINK: 4 54 | 55 | TAG: 56 | NAME: small_dataset 57 | 58 | 59 | -------------------------------------------------------------------------------- /src/config/auido2pose.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt 3 | EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt 4 | TRAIN_BATCH_SIZE: 64 5 | EVAL_BATCH_SIZE: 1 6 | EXP: True 7 | EXP_DIM: 64 8 | FRAME_LEN: 32 9 | COEFF_LEN: 73 10 | NUM_CLASSES: 46 11 | AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav 12 | COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb 13 | DEBUG: True 14 | 15 | 16 | MODEL: 17 | AUDIOENCODER: 18 | LEAKY_RELU: True 19 | NORM: 'IN' 20 | DISCRIMINATOR: 21 | LEAKY_RELU: False 22 | INPUT_CHANNELS: 6 23 | CVAE: 24 | AUDIO_EMB_IN_SIZE: 512 25 | AUDIO_EMB_OUT_SIZE: 6 26 | SEQ_LEN: 32 27 | LATENT_SIZE: 64 28 | ENCODER_LAYER_SIZES: [192, 128] 29 | DECODER_LAYER_SIZES: [128, 192] 30 | 31 | 32 | TRAIN: 33 | MAX_EPOCH: 150 34 | GENERATOR: 35 | LR: 1.0e-4 36 | DISCRIMINATOR: 37 | LR: 1.0e-4 38 | LOSS: 39 | LAMBDA_REG: 1 40 | LAMBDA_LANDMARKS: 0 41 | LAMBDA_VERTICES: 0 42 | LAMBDA_GAN_MOTION: 0.7 43 | LAMBDA_GAN_COEFF: 0 44 | LAMBDA_KL: 1 45 | 46 | TAG: 47 | NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder 48 | 49 | 50 | -------------------------------------------------------------------------------- /src/config/facerender.yaml: -------------------------------------------------------------------------------- 1 | model_params: 2 | common_params: 3 | num_kp: 15 4 | image_channel: 3 5 | feature_channel: 32 6 | estimate_jacobian: False # True 7 | kp_detector_params: 8 | temperature: 0.1 9 | block_expansion: 32 10 | max_features: 1024 11 | scale_factor: 0.25 # 0.25 12 | num_blocks: 5 13 | reshape_channel: 16384 # 16384 = 1024 * 16 14 | reshape_depth: 16 15 | he_estimator_params: 16 | block_expansion: 64 17 | max_features: 2048 18 | num_bins: 66 19 | generator_params: 20 | block_expansion: 64 21 | max_features: 512 22 | num_down_blocks: 2 23 | reshape_channel: 32 24 | reshape_depth: 16 # 512 = 32 * 16 25 | num_resblocks: 6 26 | estimate_occlusion_map: True 27 | dense_motion_params: 28 | block_expansion: 32 29 | max_features: 1024 30 | num_blocks: 5 31 | reshape_depth: 16 32 | compress: 4 33 | discriminator_params: 34 | scales: [1] 35 | block_expansion: 32 36 | max_features: 512 37 | num_blocks: 4 38 | sn: True 39 | mapping_params: 40 | coeff_nc: 70 41 | descriptor_nc: 1024 42 | layer: 3 43 | num_kp: 15 44 | num_bins: 66 45 | 46 | 
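# Note: the 'still' variant below (facerender_still.yaml) is identical to this config except that mapping_params.coeff_nc is 73 rather than 70, i.e. its mapping network presumably consumes a slightly longer 3DMM coefficient vector.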
-------------------------------------------------------------------------------- /src/config/facerender_still.yaml: -------------------------------------------------------------------------------- 1 | model_params: 2 | common_params: 3 | num_kp: 15 4 | image_channel: 3 5 | feature_channel: 32 6 | estimate_jacobian: False # True 7 | kp_detector_params: 8 | temperature: 0.1 9 | block_expansion: 32 10 | max_features: 1024 11 | scale_factor: 0.25 # 0.25 12 | num_blocks: 5 13 | reshape_channel: 16384 # 16384 = 1024 * 16 14 | reshape_depth: 16 15 | he_estimator_params: 16 | block_expansion: 64 17 | max_features: 2048 18 | num_bins: 66 19 | generator_params: 20 | block_expansion: 64 21 | max_features: 512 22 | num_down_blocks: 2 23 | reshape_channel: 32 24 | reshape_depth: 16 # 512 = 32 * 16 25 | num_resblocks: 6 26 | estimate_occlusion_map: True 27 | dense_motion_params: 28 | block_expansion: 32 29 | max_features: 1024 30 | num_blocks: 5 31 | reshape_depth: 16 32 | compress: 4 33 | discriminator_params: 34 | scales: [1] 35 | block_expansion: 32 36 | max_features: 512 37 | num_blocks: 4 38 | sn: True 39 | mapping_params: 40 | coeff_nc: 73 41 | descriptor_nc: 1024 42 | layer: 3 43 | num_kp: 15 44 | num_bins: 66 45 | 46 | -------------------------------------------------------------------------------- /src/config/similarity_Lm3D_all.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/src/config/similarity_Lm3D_all.mat -------------------------------------------------------------------------------- /src/face3d/data/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- : (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 12 | """ 13 | import numpy as np 14 | import importlib 15 | import torch.utils.data 16 | from face3d.data.base_dataset import BaseDataset 17 | 18 | 19 | def find_dataset_using_name(dataset_name): 20 | """Import the module "data/[dataset_name]_dataset.py". 21 | 22 | In the file, the class called DatasetNameDataset() will 23 | be instantiated. It has to be a subclass of BaseDataset, 24 | and it is case-insensitive. 25 | """ 26 | dataset_filename = "data." + dataset_name + "_dataset" 27 | datasetlib = importlib.import_module(dataset_filename) 28 | 29 | dataset = None 30 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 31 | for name, cls in datasetlib.__dict__.items(): 32 | if name.lower() == target_dataset_name.lower() \ 33 | and issubclass(cls, BaseDataset): 34 | dataset = cls 35 | 36 | if dataset is None: 37 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." 
% (dataset_filename, target_dataset_name)) 38 | 39 | return dataset 40 | 41 | 42 | def get_option_setter(dataset_name): 43 | """Return the static method of the dataset class.""" 44 | dataset_class = find_dataset_using_name(dataset_name) 45 | return dataset_class.modify_commandline_options 46 | 47 | 48 | def create_dataset(opt, rank=0): 49 | """Create a dataset given the option. 50 | 51 | This function wraps the class CustomDatasetDataLoader. 52 | This is the main interface between this package and 'train.py'/'test.py' 53 | 54 | Example: 55 | >>> from data import create_dataset 56 | >>> dataset = create_dataset(opt) 57 | """ 58 | data_loader = CustomDatasetDataLoader(opt, rank=rank) 59 | dataset = data_loader.load_data() 60 | return dataset 61 | 62 | class CustomDatasetDataLoader(): 63 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 64 | 65 | def __init__(self, opt, rank=0): 66 | """Initialize this class 67 | 68 | Step 1: create a dataset instance given the name [dataset_mode] 69 | Step 2: create a multi-threaded data loader. 70 | """ 71 | self.opt = opt 72 | dataset_class = find_dataset_using_name(opt.dataset_mode) 73 | self.dataset = dataset_class(opt) 74 | self.sampler = None 75 | print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__)) 76 | if opt.use_ddp and opt.isTrain: 77 | world_size = opt.world_size 78 | self.sampler = torch.utils.data.distributed.DistributedSampler( 79 | self.dataset, 80 | num_replicas=world_size, 81 | rank=rank, 82 | shuffle=not opt.serial_batches 83 | ) 84 | self.dataloader = torch.utils.data.DataLoader( 85 | self.dataset, 86 | sampler=self.sampler, 87 | num_workers=int(opt.num_threads / world_size), 88 | batch_size=int(opt.batch_size / world_size), 89 | drop_last=True) 90 | else: 91 | self.dataloader = torch.utils.data.DataLoader( 92 | self.dataset, 93 | batch_size=opt.batch_size, 94 | shuffle=(not opt.serial_batches) and opt.isTrain, 95 | num_workers=int(opt.num_threads), 96 | drop_last=True 97 | ) 98 | 99 | def set_epoch(self, epoch): 100 | self.dataset.current_epoch = epoch 101 | if self.sampler is not None: 102 | self.sampler.set_epoch(epoch) 103 | 104 | def load_data(self): 105 | return self 106 | 107 | def __len__(self): 108 | """Return the number of data in the dataset""" 109 | return min(len(self.dataset), self.opt.max_dataset_size) 110 | 111 | def __iter__(self): 112 | """Return a batch of data""" 113 | for i, data in enumerate(self.dataloader): 114 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 115 | break 116 | yield data 117 | -------------------------------------------------------------------------------- /src/face3d/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 4 | """ 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from abc import ABC, abstractmethod 11 | 12 | 13 | class BaseDataset(data.Dataset, ABC): 14 | """This class is an abstract base class (ABC) for datasets. 15 | 16 | To create a subclass, you need to implement the following four functions: 17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 18 | -- <__len__>: return the size of dataset. 
19 | -- <__getitem__>: get a data point. 20 | -- : (optionally) add dataset-specific options and set default options. 21 | """ 22 | 23 | def __init__(self, opt): 24 | """Initialize the class; save the options in the class 25 | 26 | Parameters: 27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 28 | """ 29 | self.opt = opt 30 | # self.root = opt.dataroot 31 | self.current_epoch = 0 32 | 33 | @staticmethod 34 | def modify_commandline_options(parser, is_train): 35 | """Add new dataset-specific options, and rewrite default values for existing options. 36 | 37 | Parameters: 38 | parser -- original option parser 39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 40 | 41 | Returns: 42 | the modified parser. 43 | """ 44 | return parser 45 | 46 | @abstractmethod 47 | def __len__(self): 48 | """Return the total number of images in the dataset.""" 49 | return 0 50 | 51 | @abstractmethod 52 | def __getitem__(self, index): 53 | """Return a data point and its metadata information. 54 | 55 | Parameters: 56 | index - - a random integer for data indexing 57 | 58 | Returns: 59 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 60 | """ 61 | pass 62 | 63 | 64 | def get_transform(grayscale=False): 65 | transform_list = [] 66 | if grayscale: 67 | transform_list.append(transforms.Grayscale(1)) 68 | transform_list += [transforms.ToTensor()] 69 | return transforms.Compose(transform_list) 70 | 71 | def get_affine_mat(opt, size): 72 | shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False 73 | w, h = size 74 | 75 | if 'shift' in opt.preprocess: 76 | shift_pixs = int(opt.shift_pixs) 77 | shift_x = random.randint(-shift_pixs, shift_pixs) 78 | shift_y = random.randint(-shift_pixs, shift_pixs) 79 | if 'scale' in opt.preprocess: 80 | scale = 1 + opt.scale_delta * (2 * random.random() - 1) 81 | if 'rot' in opt.preprocess: 82 | rot_angle = opt.rot_angle * (2 * random.random() - 1) 83 | rot_rad = -rot_angle * np.pi/180 84 | if 'flip' in opt.preprocess: 85 | flip = random.random() > 0.5 86 | 87 | shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3]) 88 | flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3]) 89 | shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3]) 90 | rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3]) 91 | scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3]) 92 | shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3]) 93 | 94 | affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin 95 | affine_inv = np.linalg.inv(affine) 96 | return affine, affine_inv, flip 97 | 98 | def apply_img_affine(img, affine_inv, method=Image.BICUBIC): 99 | return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC) 100 | 101 | def apply_lm_affine(landmark, affine, flip, size): 102 | _, h = size 103 | lm = landmark.copy() 104 | lm[:, 1] = h - 1 - lm[:, 1] 105 | lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1) 106 | lm = lm @ np.transpose(affine) 107 | lm[:, :2] = lm[:, :2] / lm[:, 2:] 108 | lm = lm[:, :2] 109 | lm[:, 1] = h - 1 - lm[:, 1] 110 | if flip: 111 | lm_ = lm.copy() 112 | lm_[:17] = lm[16::-1] 113 | lm_[17:22] = lm[26:21:-1] 114 | lm_[22:27] = lm[21:16:-1] 115 | 
lm_[31:36] = lm[35:30:-1] 116 | lm_[36:40] = lm[45:41:-1] 117 | lm_[40:42] = lm[47:45:-1] 118 | lm_[42:46] = lm[39:35:-1] 119 | lm_[46:48] = lm[41:39:-1] 120 | lm_[48:55] = lm[54:47:-1] 121 | lm_[55:60] = lm[59:54:-1] 122 | lm_[60:65] = lm[64:59:-1] 123 | lm_[65:68] = lm[67:64:-1] 124 | lm = lm_ 125 | return lm 126 | -------------------------------------------------------------------------------- /src/face3d/data/flist_dataset.py: -------------------------------------------------------------------------------- 1 | """This script defines the custom dataset for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import os.path 5 | from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine 6 | from data.image_folder import make_dataset 7 | from PIL import Image 8 | import random 9 | import util.util as util 10 | import numpy as np 11 | import json 12 | import torch 13 | from scipy.io import loadmat, savemat 14 | import pickle 15 | from util.preprocess import align_img, estimate_norm 16 | from util.load_mats import load_lm3d 17 | 18 | 19 | def default_flist_reader(flist): 20 | """ 21 | flist format: impath label\nimpath label\n ...(same to caffe's filelist) 22 | """ 23 | imlist = [] 24 | with open(flist, 'r') as rf: 25 | for line in rf.readlines(): 26 | impath = line.strip() 27 | imlist.append(impath) 28 | 29 | return imlist 30 | 31 | def jason_flist_reader(flist): 32 | with open(flist, 'r') as fp: 33 | info = json.load(fp) 34 | return info 35 | 36 | def parse_label(label): 37 | return torch.tensor(np.array(label).astype(np.float32)) 38 | 39 | 40 | class FlistDataset(BaseDataset): 41 | """ 42 | It requires one directories to host training images '/path/to/data/train' 43 | You can train the model with the dataset flag '--dataroot /path/to/data'. 44 | """ 45 | 46 | def __init__(self, opt): 47 | """Initialize this dataset class. 48 | 49 | Parameters: 50 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 51 | """ 52 | BaseDataset.__init__(self, opt) 53 | 54 | self.lm3d_std = load_lm3d(opt.bfm_folder) 55 | 56 | msk_names = default_flist_reader(opt.flist) 57 | self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names] 58 | 59 | self.size = len(self.msk_paths) 60 | self.opt = opt 61 | 62 | self.name = 'train' if opt.isTrain else 'val' 63 | if '_' in opt.flist: 64 | self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0] 65 | 66 | 67 | def __getitem__(self, index): 68 | """Return a data point and its metadata information. 
69 | 70 | Parameters: 71 | index (int) -- a random integer for data indexing 72 | 73 | Returns a dictionary that contains A, B, A_paths and B_paths 74 | img (tensor) -- an image in the input domain 75 | msk (tensor) -- its corresponding attention mask 76 | lm (tensor) -- its corresponding 3d landmarks 77 | im_paths (str) -- image paths 78 | aug_flag (bool) -- a flag used to tell whether its raw or augmented 79 | """ 80 | msk_path = self.msk_paths[index % self.size] # make sure index is within then range 81 | img_path = msk_path.replace('mask/', '') 82 | lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt' 83 | 84 | raw_img = Image.open(img_path).convert('RGB') 85 | raw_msk = Image.open(msk_path).convert('RGB') 86 | raw_lm = np.loadtxt(lm_path).astype(np.float32) 87 | 88 | _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk) 89 | 90 | aug_flag = self.opt.use_aug and self.opt.isTrain 91 | if aug_flag: 92 | img, lm, msk = self._augmentation(img, lm, self.opt, msk) 93 | 94 | _, H = img.size 95 | M = estimate_norm(lm, H) 96 | transform = get_transform() 97 | img_tensor = transform(img) 98 | msk_tensor = transform(msk)[:1, ...] 99 | lm_tensor = parse_label(lm) 100 | M_tensor = parse_label(M) 101 | 102 | 103 | return {'imgs': img_tensor, 104 | 'lms': lm_tensor, 105 | 'msks': msk_tensor, 106 | 'M': M_tensor, 107 | 'im_paths': img_path, 108 | 'aug_flag': aug_flag, 109 | 'dataset': self.name} 110 | 111 | def _augmentation(self, img, lm, opt, msk=None): 112 | affine, affine_inv, flip = get_affine_mat(opt, img.size) 113 | img = apply_img_affine(img, affine_inv) 114 | lm = apply_lm_affine(lm, affine, flip, img.size) 115 | if msk is not None: 116 | msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR) 117 | return img, lm, msk 118 | 119 | 120 | 121 | 122 | def __len__(self): 123 | """Return the total number of images in the dataset. 124 | """ 125 | return self.size 126 | -------------------------------------------------------------------------------- /src/face3d/data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 
5 | """ 6 | import numpy as np 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | import os 11 | import os.path 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 16 | '.tif', '.TIF', '.tiff', '.TIFF', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir, max_dataset_size=float("inf")): 25 | images = [] 26 | assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir, followlinks=True)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | return images[:min(max_dataset_size, len(images))] 34 | 35 | 36 | def default_loader(path): 37 | return Image.open(path).convert('RGB') 38 | 39 | 40 | class ImageFolder(data.Dataset): 41 | 42 | def __init__(self, root, transform=None, return_paths=False, 43 | loader=default_loader): 44 | imgs = make_dataset(root) 45 | if len(imgs) == 0: 46 | raise(RuntimeError("Found 0 images in: " + root + "\n" 47 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 48 | 49 | self.root = root 50 | self.imgs = imgs 51 | self.transform = transform 52 | self.return_paths = return_paths 53 | self.loader = loader 54 | 55 | def __getitem__(self, index): 56 | path = self.imgs[index] 57 | img = self.loader(path) 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | if self.return_paths: 61 | return img, path 62 | else: 63 | return img 64 | 65 | def __len__(self): 66 | return len(self.imgs) 67 | -------------------------------------------------------------------------------- /src/face3d/data/template_dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset class template 2 | 3 | This module provides a template for users to implement custom datasets. 4 | You can specify '--dataset_mode template' to use this dataset. 5 | The class name should be consistent with both the filename and its dataset_mode option. 6 | The filename should be _dataset.py 7 | The class name should be Dataset.py 8 | You need to implement the following functions: 9 | -- : Add dataset-specific options and rewrite default values for existing options. 10 | -- <__init__>: Initialize this dataset class. 11 | -- <__getitem__>: Return a data point and its metadata information. 12 | -- <__len__>: Return the number of images. 13 | """ 14 | from data.base_dataset import BaseDataset, get_transform 15 | # from data.image_folder import make_dataset 16 | # from PIL import Image 17 | 18 | 19 | class TemplateDataset(BaseDataset): 20 | """A template dataset class for you to implement custom datasets.""" 21 | @staticmethod 22 | def modify_commandline_options(parser, is_train): 23 | """Add new dataset-specific options, and rewrite default values for existing options. 24 | 25 | Parameters: 26 | parser -- original option parser 27 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 28 | 29 | Returns: 30 | the modified parser. 
31 | """ 32 | parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option') 33 | parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values 34 | return parser 35 | 36 | def __init__(self, opt): 37 | """Initialize this dataset class. 38 | 39 | Parameters: 40 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 41 | 42 | A few things can be done here. 43 | - save the options (have been done in BaseDataset) 44 | - get image paths and meta information of the dataset. 45 | - define the image transformation. 46 | """ 47 | # save the option and dataset root 48 | BaseDataset.__init__(self, opt) 49 | # get the image paths of your dataset; 50 | self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root 51 | # define the default transform function. You can use ; You can also define your custom transform function 52 | self.transform = get_transform(opt) 53 | 54 | def __getitem__(self, index): 55 | """Return a data point and its metadata information. 56 | 57 | Parameters: 58 | index -- a random integer for data indexing 59 | 60 | Returns: 61 | a dictionary of data with their names. It usually contains the data itself and its metadata information. 62 | 63 | Step 1: get a random image path: e.g., path = self.image_paths[index] 64 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). 65 | Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image) 66 | Step 4: return a data point as a dictionary. 67 | """ 68 | path = 'temp' # needs to be a string 69 | data_A = None # needs to be a tensor 70 | data_B = None # needs to be a tensor 71 | return {'data_A': data_A, 'data_B': data_B, 'path': path} 72 | 73 | def __len__(self): 74 | """Return the total number of images.""" 75 | return len(self.image_paths) 76 | -------------------------------------------------------------------------------- /src/face3d/extract_kp_videos.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import glob 5 | import argparse 6 | import face_alignment 7 | import numpy as np 8 | from PIL import Image 9 | from tqdm import tqdm 10 | from itertools import cycle 11 | 12 | from torch.multiprocessing import Pool, Process, set_start_method 13 | 14 | class KeypointExtractor(): 15 | def __init__(self, device): 16 | self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, 17 | device=device) 18 | 19 | def extract_keypoint(self, images, name=None, info=True): 20 | if isinstance(images, list): 21 | keypoints = [] 22 | if info: 23 | i_range = tqdm(images,desc='landmark Det:') 24 | else: 25 | i_range = images 26 | 27 | for image in i_range: 28 | current_kp = self.extract_keypoint(image) 29 | if np.mean(current_kp) == -1 and keypoints: 30 | keypoints.append(keypoints[-1]) 31 | else: 32 | keypoints.append(current_kp[None]) 33 | 34 | keypoints = np.concatenate(keypoints, 0) 35 | np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) 36 | return keypoints 37 | else: 38 | while True: 39 | try: 40 | keypoints = self.detector.get_landmarks_from_image(np.array(images))[0] 41 | break 42 | except RuntimeError as e: 43 | if str(e).startswith('CUDA'): 44 | print("Warning: out of memory, sleep for 1s") 45 | time.sleep(1) 46 | else: 47 
| print(e) 48 | break 49 | except TypeError: 50 | print('No face detected in this image') 51 | shape = [68, 2] 52 | keypoints = -1. * np.ones(shape) 53 | break 54 | if name is not None: 55 | np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) 56 | return keypoints 57 | 58 | def read_video(filename): 59 | frames = [] 60 | cap = cv2.VideoCapture(filename) 61 | while cap.isOpened(): 62 | ret, frame = cap.read() 63 | if ret: 64 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 65 | frame = Image.fromarray(frame) 66 | frames.append(frame) 67 | else: 68 | break 69 | cap.release() 70 | return frames 71 | 72 | def run(data): 73 | filename, opt, device = data 74 | os.environ['CUDA_VISIBLE_DEVICES'] = device 75 | kp_extractor = KeypointExtractor(device='cuda')  # KeypointExtractor.__init__ requires a device; CUDA_VISIBLE_DEVICES above pins which GPU 'cuda' refers to 76 | images = read_video(filename) 77 | name = filename.split('/')[-2:] 78 | os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) 79 | kp_extractor.extract_keypoint( 80 | images, 81 | name=os.path.join(opt.output_dir, name[-2], name[-1]) 82 | ) 83 | 84 | if __name__ == '__main__': 85 | set_start_method('spawn') 86 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 87 | parser.add_argument('--input_dir', type=str, help='the folder of the input files') 88 | parser.add_argument('--output_dir', type=str, help='the folder of the output files') 89 | parser.add_argument('--device_ids', type=str, default='0,1') 90 | parser.add_argument('--workers', type=int, default=4) 91 | 92 | opt = parser.parse_args() 93 | filenames = list() 94 | VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} 95 | VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) 96 | extensions = VIDEO_EXTENSIONS 97 | 98 | for ext in extensions: 99 | os.listdir(f'{opt.input_dir}') 100 | print(f'{opt.input_dir}/*.{ext}') 101 | filenames += sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))  # accumulate across extensions rather than overwriting the list 102 | print('Total number of videos:', len(filenames)) 103 | pool = Pool(opt.workers) 104 | args_list = cycle([opt]) 105 | device_ids = opt.device_ids.split(",") 106 | device_ids = cycle(device_ids) 107 | for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): 108 | pass 109 | -------------------------------------------------------------------------------- /src/face3d/extract_kp_videos_safe.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import glob 5 | import argparse 6 | import numpy as np 7 | from PIL import Image 8 | import torch 9 | from tqdm import tqdm 10 | from itertools import cycle 11 | from facexlib.alignment import init_alignment_model, landmark_98_to_68 12 | from facexlib.detection import init_detection_model 13 | from torch.multiprocessing import Pool, Process, set_start_method 14 | 15 | 16 | class KeypointExtractor(): 17 | def __init__(self, device='cuda'): 18 | 19 | ### gfpgan/weights 20 | try: 21 | import webui # in webui 22 | root_path = 'extensions/SadTalker/gfpgan/weights' 23 | 24 | except ImportError: 25 | root_path = 'gfpgan/weights' 26 | 27 | self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path) 28 | self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path) 29 | 30 | def extract_keypoint(self, images, name=None, info=True): 31 | if isinstance(images, list): 32 | keypoints = [] 33 | if info: 34 | i_range = tqdm(images,desc='landmark Det:') 35 | else: 36 | i_range = images 37 | 38 | for image in i_range: 39 |
current_kp = self.extract_keypoint(image) 40 | # current_kp = self.detector.get_landmarks(np.array(image)) 41 | if np.mean(current_kp) == -1 and keypoints: 42 | keypoints.append(keypoints[-1]) 43 | else: 44 | keypoints.append(current_kp[None]) 45 | 46 | keypoints = np.concatenate(keypoints, 0) 47 | np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) 48 | return keypoints 49 | else: 50 | while True: 51 | try: 52 | with torch.no_grad(): 53 | # face detection -> face alignment. 54 | img = np.array(images) 55 | bboxes = self.det_net.detect_faces(images, 0.97) 56 | 57 | bboxes = bboxes[0] 58 | img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :] 59 | 60 | keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0] 61 | 62 | #### keypoints to the original location 63 | keypoints[:,0] += int(bboxes[0]) 64 | keypoints[:,1] += int(bboxes[1]) 65 | 66 | break 67 | except RuntimeError as e: 68 | if str(e).startswith('CUDA'): 69 | print("Warning: out of memory, sleep for 1s") 70 | time.sleep(1) 71 | else: 72 | print(e) 73 | break 74 | except TypeError: 75 | print('No face detected in this image') 76 | shape = [68, 2] 77 | keypoints = -1. * np.ones(shape) 78 | break 79 | if name is not None: 80 | np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) 81 | return keypoints 82 | 83 | def read_video(filename): 84 | frames = [] 85 | cap = cv2.VideoCapture(filename) 86 | while cap.isOpened(): 87 | ret, frame = cap.read() 88 | if ret: 89 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 90 | frame = Image.fromarray(frame) 91 | frames.append(frame) 92 | else: 93 | break 94 | cap.release() 95 | return frames 96 | 97 | def run(data): 98 | filename, opt, device = data 99 | os.environ['CUDA_VISIBLE_DEVICES'] = device 100 | kp_extractor = KeypointExtractor() 101 | images = read_video(filename) 102 | name = filename.split('/')[-2:] 103 | os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) 104 | kp_extractor.extract_keypoint( 105 | images, 106 | name=os.path.join(opt.output_dir, name[-2], name[-1]) 107 | ) 108 | 109 | if __name__ == '__main__': 110 | set_start_method('spawn') 111 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 112 | parser.add_argument('--input_dir', type=str, help='the folder of the input files') 113 | parser.add_argument('--output_dir', type=str, help='the folder of the output files') 114 | parser.add_argument('--device_ids', type=str, default='0,1') 115 | parser.add_argument('--workers', type=int, default=4) 116 | 117 | opt = parser.parse_args() 118 | filenames = list() 119 | VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} 120 | VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) 121 | extensions = VIDEO_EXTENSIONS 122 | 123 | for ext in extensions: 124 | os.listdir(f'{opt.input_dir}') 125 | print(f'{opt.input_dir}/*.{ext}') 126 | filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}')) 127 | print('Total number of videos:', len(filenames)) 128 | pool = Pool(opt.workers) 129 | args_list = cycle([opt]) 130 | device_ids = opt.device_ids.split(",") 131 | device_ids = cycle(device_ids) 132 | for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): 133 | None 134 | -------------------------------------------------------------------------------- /src/face3d/models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, 
optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- <set_input>: unpack data from dataset and apply preprocessing. 7 | -- <forward>: produce intermediate results. 8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights. 9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from src.face3d.models.base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called ModelNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "face3d.models." + model_name + "_model" 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace('_', '') + 'model' 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() \ 38 | and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 43 | exit(0) 44 | 45 | return model 46 | 47 | 48 | def get_option_setter(model_name): 49 | """Return the static method of the model class.""" 50 | model_class = find_model_using_name(model_name) 51 | return model_class.modify_commandline_options 52 | 53 | 54 | def create_model(opt): 55 | """Create a model given the option. 56 | 57 | This function instantiates the model class specified by opt.model.
58 | This is the main interface between this package and 'train.py'/'test.py' 59 | 60 | Example: 61 | >>> from models import create_model 62 | >>> model = create_model(opt) 63 | """ 64 | model = find_model_using_name(opt.model) 65 | instance = model(opt) 66 | print("model [%s] was created" % type(instance).__name__) 67 | return instance 68 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 2 | from .mobilefacenet import get_mbf 3 | 4 | 5 | def get_model(name, **kwargs): 6 | # resnet 7 | if name == "r18": 8 | return iresnet18(False, **kwargs) 9 | elif name == "r34": 10 | return iresnet34(False, **kwargs) 11 | elif name == "r50": 12 | return iresnet50(False, **kwargs) 13 | elif name == "r100": 14 | return iresnet100(False, **kwargs) 15 | elif name == "r200": 16 | return iresnet200(False, **kwargs) 17 | elif name == "r2060": 18 | from .iresnet2060 import iresnet2060 19 | return iresnet2060(False, **kwargs) 20 | elif name == "mbf": 21 | fp16 = kwargs.get("fp16", False) 22 | num_features = kwargs.get("num_features", 512) 23 | return get_mbf(fp16=fp16, num_features=num_features) 24 | else: 25 | raise ValueError() -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/backbones/mobilefacenet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py 3 | Original author cavalleria 4 | ''' 5 | 6 | import torch.nn as nn 7 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module 8 | import torch 9 | 10 | 11 | class Flatten(Module): 12 | def forward(self, x): 13 | return x.view(x.size(0), -1) 14 | 15 | 16 | class ConvBlock(Module): 17 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 18 | super(ConvBlock, self).__init__() 19 | self.layers = nn.Sequential( 20 | Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), 21 | BatchNorm2d(num_features=out_c), 22 | PReLU(num_parameters=out_c) 23 | ) 24 | 25 | def forward(self, x): 26 | return self.layers(x) 27 | 28 | 29 | class LinearBlock(Module): 30 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 31 | super(LinearBlock, self).__init__() 32 | self.layers = nn.Sequential( 33 | Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), 34 | BatchNorm2d(num_features=out_c) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.layers(x) 39 | 40 | 41 | class DepthWise(Module): 42 | def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): 43 | super(DepthWise, self).__init__() 44 | self.residual = residual 45 | self.layers = nn.Sequential( 46 | ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), 47 | ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), 48 | LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) 49 | ) 50 | 51 | def forward(self, x): 52 | short_cut = None 53 | if self.residual: 54 | short_cut = x 55 | x = self.layers(x) 56 | if self.residual: 57 | output = short_cut + x 58 | else: 59 | output = x 60 | return output 61 | 62 | 
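# Residual (defined just below) stacks `num_block` DepthWise units with residual=True and stride (1, 1), so each unit's input is added back to its output; DepthWise itself is a pointwise-expand -> depthwise -> linear-pointwise bottleneck.
# A minimal shape-check sketch (assuming the standard 112x112 ArcFace crop as input), kept as comments so importing this module stays side-effect free:
#   net = get_mbf(fp16=False, num_features=512)  # factory defined at the bottom of this file
#   emb = net(torch.randn(2, 3, 112, 112))       # expected embedding shape: (2, 512)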
63 | class Residual(Module): 64 | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): 65 | super(Residual, self).__init__() 66 | modules = [] 67 | for _ in range(num_block): 68 | modules.append(DepthWise(c, c, True, kernel, stride, padding, groups)) 69 | self.layers = Sequential(*modules) 70 | 71 | def forward(self, x): 72 | return self.layers(x) 73 | 74 | 75 | class GDC(Module): 76 | def __init__(self, embedding_size): 77 | super(GDC, self).__init__() 78 | self.layers = nn.Sequential( 79 | LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), 80 | Flatten(), 81 | Linear(512, embedding_size, bias=False), 82 | BatchNorm1d(embedding_size)) 83 | 84 | def forward(self, x): 85 | return self.layers(x) 86 | 87 | 88 | class MobileFaceNet(Module): 89 | def __init__(self, fp16=False, num_features=512): 90 | super(MobileFaceNet, self).__init__() 91 | scale = 2 92 | self.fp16 = fp16 93 | self.layers = nn.Sequential( 94 | ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)), 95 | ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64), 96 | DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), 97 | Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 98 | DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), 99 | Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 100 | DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512), 101 | Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 102 | ) 103 | self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) 104 | self.features = GDC(num_features) 105 | self._initialize_weights() 106 | 107 | def _initialize_weights(self): 108 | for m in self.modules(): 109 | if isinstance(m, nn.Conv2d): 110 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 111 | if m.bias is not None: 112 | m.bias.data.zero_() 113 | elif isinstance(m, nn.BatchNorm2d): 114 | m.weight.data.fill_(1) 115 | m.bias.data.zero_() 116 | elif isinstance(m, nn.Linear): 117 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 118 | if m.bias is not None: 119 | m.bias.data.zero_() 120 | 121 | def forward(self, x): 122 | with torch.cuda.amp.autocast(self.fp16): 123 | x = self.layers(x) 124 | x = self.conv_sep(x.float() if self.fp16 else x) 125 | x = self.features(x) 126 | return x 127 | 128 | 129 | def get_mbf(fp16, num_features): 130 | return MobileFaceNet(fp16, num_features) -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/configs/3millions.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.loss = "arcface" 7 | config.network = "r50" 8 | config.resume = False 9 | config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 1.0 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 128 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 300 * 10000 20 | config.num_epoch = 30 21 | config.warmup_epoch = -1 22 | config.decay_epoch = [10, 16, 22] 23 | 
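# this synthetic speed-test config evaluates no verification sets, hence the empty target list below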
config.val_targets = [] 24 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/configs/3millions_pfc.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.loss = "arcface" 7 | config.network = "r50" 8 | config.resume = False 9 | config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 0.1 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 128 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 300 * 10000 20 | config.num_epoch = 30 21 | config.warmup_epoch = -1 22 | config.decay_epoch = [10, 16, 22] 23 | config.val_targets = [] 24 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/src/face3d/models/arcface_torch/configs/__init__.py -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/configs/base.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = "ms1mv3_arcface_r50" 12 | 13 | config.dataset = "ms1m-retinaface-t1" 14 | config.embedding_size = 512 15 | config.sample_rate = 1 16 | config.fp16 = False 17 | config.momentum = 0.9 18 | config.weight_decay = 5e-4 19 | config.batch_size = 128 20 | config.lr = 0.1 # batch size is 512 21 | 22 | if config.dataset == "emore": 23 | config.rec = "/train_tmp/faces_emore" 24 | config.num_classes = 85742 25 | config.num_image = 5822653 26 | config.num_epoch = 16 27 | config.warmup_epoch = -1 28 | config.decay_epoch = [8, 14, ] 29 | config.val_targets = ["lfw", ] 30 | 31 | elif config.dataset == "ms1m-retinaface-t1": 32 | config.rec = "/train_tmp/ms1m-retinaface-t1" 33 | config.num_classes = 93431 34 | config.num_image = 5179510 35 | config.num_epoch = 25 36 | config.warmup_epoch = -1 37 | config.decay_epoch = [11, 17, 22] 38 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 39 | 40 | elif config.dataset == "glint360k": 41 | config.rec = "/train_tmp/glint360k" 42 | config.num_classes = 360232 43 | config.num_image = 17091657 44 | config.num_epoch = 20 45 | config.warmup_epoch = -1 46 | config.decay_epoch = [8, 12, 15, 18] 47 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 48 | 49 | elif config.dataset == "webface": 50 | config.rec = "/train_tmp/faces_webface_112x112" 51 | config.num_classes = 10572 52 | config.num_image = "forget" 53 | config.num_epoch = 34 54 | config.warmup_epoch = -1 55 | config.decay_epoch = [20, 28, 32] 56 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 57 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/configs/glint360k_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 
256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.1 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 2e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/configs/glint360k_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/configs/glint360k_r18.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r18" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/configs/glint360k_r34.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r34" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | 
config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/configs/glint360k_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 2e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 30 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 20, 25] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/configs/ms1mv3_r18.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r18" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r2060" 10 | config.resume = False 11 | 
config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 64 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/configs/ms1mv3_r34.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r34" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/configs/ms1mv3_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/configs/speed.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.loss = "arcface" 7 | config.network = "r50" 8 | config.resume = False 9 | config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 1.0 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 128 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 100 * 10000 20 | config.num_epoch = 30 21 | config.warmup_epoch = -1 22 | config.decay_epoch = [10, 16, 22] 23 | config.val_targets = [] 24 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/dataset.py: -------------------------------------------------------------------------------- 1 | import 
numbers 2 | import os 3 | import queue as Queue 4 | import threading 5 | 6 | import mxnet as mx 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import DataLoader, Dataset 10 | from torchvision import transforms 11 | 12 | 13 | class BackgroundGenerator(threading.Thread): 14 | def __init__(self, generator, local_rank, max_prefetch=6): 15 | super(BackgroundGenerator, self).__init__() 16 | self.queue = Queue.Queue(max_prefetch) 17 | self.generator = generator 18 | self.local_rank = local_rank 19 | self.daemon = True 20 | self.start() 21 | 22 | def run(self): 23 | torch.cuda.set_device(self.local_rank) 24 | for item in self.generator: 25 | self.queue.put(item) 26 | self.queue.put(None) 27 | 28 | def next(self): 29 | next_item = self.queue.get() 30 | if next_item is None: 31 | raise StopIteration 32 | return next_item 33 | 34 | def __next__(self): 35 | return self.next() 36 | 37 | def __iter__(self): 38 | return self 39 | 40 | 41 | class DataLoaderX(DataLoader): 42 | 43 | def __init__(self, local_rank, **kwargs): 44 | super(DataLoaderX, self).__init__(**kwargs) 45 | self.stream = torch.cuda.Stream(local_rank) 46 | self.local_rank = local_rank 47 | 48 | def __iter__(self): 49 | self.iter = super(DataLoaderX, self).__iter__() 50 | self.iter = BackgroundGenerator(self.iter, self.local_rank) 51 | self.preload() 52 | return self 53 | 54 | def preload(self): 55 | self.batch = next(self.iter, None) 56 | if self.batch is None: 57 | return None 58 | with torch.cuda.stream(self.stream): 59 | for k in range(len(self.batch)): 60 | self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True) 61 | 62 | def __next__(self): 63 | torch.cuda.current_stream().wait_stream(self.stream) 64 | batch = self.batch 65 | if batch is None: 66 | raise StopIteration 67 | self.preload() 68 | return batch 69 | 70 | 71 | class MXFaceDataset(Dataset): 72 | def __init__(self, root_dir, local_rank): 73 | super(MXFaceDataset, self).__init__() 74 | self.transform = transforms.Compose( 75 | [transforms.ToPILImage(), 76 | transforms.RandomHorizontalFlip(), 77 | transforms.ToTensor(), 78 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 79 | ]) 80 | self.root_dir = root_dir 81 | self.local_rank = local_rank 82 | path_imgrec = os.path.join(root_dir, 'train.rec') 83 | path_imgidx = os.path.join(root_dir, 'train.idx') 84 | self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') 85 | s = self.imgrec.read_idx(0) 86 | header, _ = mx.recordio.unpack(s) 87 | if header.flag > 0: 88 | self.header0 = (int(header.label[0]), int(header.label[1])) 89 | self.imgidx = np.array(range(1, int(header.label[0]))) 90 | else: 91 | self.imgidx = np.array(list(self.imgrec.keys)) 92 | 93 | def __getitem__(self, index): 94 | idx = self.imgidx[index] 95 | s = self.imgrec.read_idx(idx) 96 | header, img = mx.recordio.unpack(s) 97 | label = header.label 98 | if not isinstance(label, numbers.Number): 99 | label = label[0] 100 | label = torch.tensor(label, dtype=torch.long) 101 | sample = mx.image.imdecode(img).asnumpy() 102 | if self.transform is not None: 103 | sample = self.transform(sample) 104 | return sample, label 105 | 106 | def __len__(self): 107 | return len(self.imgidx) 108 | 109 | 110 | class SyntheticDataset(Dataset): 111 | def __init__(self, local_rank): 112 | super(SyntheticDataset, self).__init__() 113 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) 114 | img = np.transpose(img, (2, 0, 1)) 115 | img = torch.from_numpy(img).squeeze(0).float() 116 | img = ((img / 
255) - 0.5) / 0.5 117 | self.img = img 118 | self.label = 1 119 | 120 | def __getitem__(self, index): 121 | return self.img, self.label 122 | 123 | def __len__(self): 124 | return 1000000 125 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/docs/eval.md: -------------------------------------------------------------------------------- 1 | ## Eval on ICCV2021-MFR 2 | 3 | coming soon. 4 | 5 | 6 | ## Eval IJBC 7 | You can eval ijbc with pytorch or onnx. 8 | 9 | 10 | 1. Eval IJBC With Onnx 11 | ```shell 12 | CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50 13 | ``` 14 | 15 | 2. Eval IJBC With Pytorch 16 | ```shell 17 | CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \ 18 | --model-prefix ms1mv3_arcface_r50/backbone.pth \ 19 | --image-path IJB_release/IJBC \ 20 | --result-dir ms1mv3_arcface_r50 \ 21 | --batch-size 128 \ 22 | --job ms1mv3_arcface_r50 \ 23 | --target IJBC \ 24 | --network iresnet50 25 | ``` 26 | 27 | ## Inference 28 | 29 | ```shell 30 | python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50 31 | ``` 32 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/docs/install.md: -------------------------------------------------------------------------------- 1 | ## v1.8.0 2 | ### Linux and Windows 3 | ```shell 4 | # CUDA 11.0 5 | pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html 6 | 7 | # CUDA 10.2 8 | pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 9 | 10 | # CPU only 11 | pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html 12 | 13 | ``` 14 | 15 | 16 | ## v1.7.1 17 | ### Linux and Windows 18 | ```shell 19 | # CUDA 11.0 20 | pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 21 | 22 | # CUDA 10.2 23 | pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 24 | 25 | # CUDA 10.1 26 | pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 27 | 28 | # CUDA 9.2 29 | pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 30 | 31 | # CPU only 32 | pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 33 | ``` 34 | 35 | 36 | ## v1.6.0 37 | 38 | ### Linux and Windows 39 | ```shell 40 | # CUDA 10.2 41 | pip install torch==1.6.0 torchvision==0.7.0 42 | 43 | # CUDA 10.1 44 | pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html 45 | 46 | # CUDA 9.2 47 | pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html 48 | 49 | # CPU only 50 | pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html 51 | ``` -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/docs/modelzoo.md: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/src/face3d/models/arcface_torch/docs/modelzoo.md -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/src/face3d/models/arcface_torch/eval/__init__.py -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | from backbones import get_model 8 | 9 | 10 | @torch.no_grad() 11 | def inference(weight, name, img): 12 | if img is None: 13 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) 14 | else: 15 | img = cv2.imread(img) 16 | img = cv2.resize(img, (112, 112)) 17 | 18 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 19 | img = np.transpose(img, (2, 0, 1)) 20 | img = torch.from_numpy(img).unsqueeze(0).float() 21 | img.div_(255).sub_(0.5).div_(0.5) 22 | net = get_model(name, fp16=False) 23 | net.load_state_dict(torch.load(weight)) 24 | net.eval() 25 | feat = net(img).numpy() 26 | print(feat) 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') 31 | parser.add_argument('--network', type=str, default='r50', help='backbone network') 32 | parser.add_argument('--weight', type=str, default='') 33 | parser.add_argument('--img', type=str, default=None) 34 | args = parser.parse_args() 35 | inference(args.weight, args.network, args.img) 36 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def get_loss(name): 6 | if name == "cosface": 7 | return CosFace() 8 | elif name == "arcface": 9 | return ArcFace() 10 | else: 11 | raise ValueError() 12 | 13 | 14 | class CosFace(nn.Module): 15 | def __init__(self, s=64.0, m=0.40): 16 | super(CosFace, self).__init__() 17 | self.s = s 18 | self.m = m 19 | 20 | def forward(self, cosine, label): 21 | index = torch.where(label != -1)[0] 22 | m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) 23 | m_hot.scatter_(1, label[index, None], self.m) 24 | cosine[index] -= m_hot 25 | ret = cosine * self.s 26 | return ret 27 | 28 | 29 | class ArcFace(nn.Module): 30 | def __init__(self, s=64.0, m=0.5): 31 | super(ArcFace, self).__init__() 32 | self.s = s 33 | self.m = m 34 | 35 | def forward(self, cosine: torch.Tensor, label): 36 | index = torch.where(label != -1)[0] 37 | m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) 38 | m_hot.scatter_(1, label[index, None], self.m) 39 | cosine.acos_() 40 | cosine[index] += m_hot 41 | cosine.cos_().mul_(self.s) 42 | return cosine 43 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/requirement.txt: -------------------------------------------------------------------------------- 1 | tensorboard 2 | easydict 3 | mxnet 4 | onnx 5 | sklearn 6 | -------------------------------------------------------------------------------- 
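The losses.py above (src/face3d/models/arcface_torch/losses.py) defines the ArcFace and CosFace margin heads used for ArcFace training. A rough, minimal usage sketch follows — it is not part of the repository, assumes it is run from inside src/face3d/models/arcface_torch/ so that `losses` is importable, and uses random cosine logits and labels purely as stand-ins for real network outputs:

```python
import torch
import torch.nn.functional as F

from losses import get_loss  # losses.py shown above; assumes the arcface_torch directory is the CWD

# Toy stand-ins: normalized cosine logits cos(theta) for 8 samples over 10 classes, plus class labels.
cosine = torch.randn(8, 10).clamp(-1.0, 1.0)
label = torch.randint(0, 10, (8,))

margin_softmax = get_loss("arcface")    # returns ArcFace(s=64.0, m=0.5) as defined above
logits = margin_softmax(cosine, label)  # adds the angular margin m to each target-class angle, then rescales by s
loss = F.cross_entropy(logits, label)   # softmax cross-entropy on the margined, scaled logits
print(loss.item())
```

In actual training these margin modules are applied to real cosine similarities (see partial_fc.py and train.py) rather than to random tensors.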
/src/face3d/models/arcface_torch/run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 2 | ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh 3 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/torch2onnx.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | import torch 4 | 5 | 6 | def convert_onnx(net, path_module, output, opset=11, simplify=False): 7 | assert isinstance(net, torch.nn.Module) 8 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) 9 | img = img.astype(np.float) 10 | img = (img / 255. - 0.5) / 0.5 # torch style norm 11 | img = img.transpose((2, 0, 1)) 12 | img = torch.from_numpy(img).unsqueeze(0).float() 13 | 14 | weight = torch.load(path_module) 15 | net.load_state_dict(weight) 16 | net.eval() 17 | torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset) 18 | model = onnx.load(output) 19 | graph = model.graph 20 | graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' 21 | if simplify: 22 | from onnxsim import simplify 23 | model, check = simplify(model) 24 | assert check, "Simplified ONNX model could not be validated" 25 | onnx.save(model, output) 26 | 27 | 28 | if __name__ == '__main__': 29 | import os 30 | import argparse 31 | from backbones import get_model 32 | 33 | parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx') 34 | parser.add_argument('input', type=str, help='input backbone.pth file or path') 35 | parser.add_argument('--output', type=str, default=None, help='output onnx path') 36 | parser.add_argument('--network', type=str, default=None, help='backbone network') 37 | parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify') 38 | args = parser.parse_args() 39 | input_file = args.input 40 | if os.path.isdir(input_file): 41 | input_file = os.path.join(input_file, "backbone.pth") 42 | assert os.path.exists(input_file) 43 | model_name = os.path.basename(os.path.dirname(input_file)).lower() 44 | params = model_name.split("_") 45 | if len(params) >= 3 and params[1] in ('arcface', 'cosface'): 46 | if args.network is None: 47 | args.network = params[2] 48 | assert args.network is not None 49 | print(args) 50 | backbone_onnx = get_model(args.network, dropout=0) 51 | 52 | output_path = args.output 53 | if output_path is None: 54 | output_path = os.path.join(os.path.dirname(__file__), 'onnx') 55 | if not os.path.exists(output_path): 56 | os.makedirs(output_path) 57 | assert os.path.isdir(output_path) 58 | output_file = os.path.join(output_path, "%s.onnx" % model_name) 59 | convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify) 60 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/src/face3d/models/arcface_torch/utils/__init__.py -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/utils/plot.py: 
-------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | from pathlib import Path 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pandas as pd 9 | from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap 10 | from prettytable import PrettyTable 11 | from sklearn.metrics import roc_curve, auc 12 | 13 | image_path = "/data/anxiang/IJB_release/IJBC" 14 | files = [ 15 | "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy" 16 | ] 17 | 18 | 19 | def read_template_pair_list(path): 20 | pairs = pd.read_csv(path, sep=' ', header=None).values 21 | t1 = pairs[:, 0].astype(np.int) 22 | t2 = pairs[:, 1].astype(np.int) 23 | label = pairs[:, 2].astype(np.int) 24 | return t1, t2, label 25 | 26 | 27 | p1, p2, label = read_template_pair_list( 28 | os.path.join('%s/meta' % image_path, 29 | '%s_template_pair_label.txt' % 'ijbc')) 30 | 31 | methods = [] 32 | scores = [] 33 | for file in files: 34 | methods.append(file.split('/')[-2]) 35 | scores.append(np.load(file)) 36 | 37 | methods = np.array(methods) 38 | scores = dict(zip(methods, scores)) 39 | colours = dict( 40 | zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) 41 | x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] 42 | tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) 43 | fig = plt.figure() 44 | for method in methods: 45 | fpr, tpr, _ = roc_curve(label, scores[method]) 46 | roc_auc = auc(fpr, tpr) 47 | fpr = np.flipud(fpr) 48 | tpr = np.flipud(tpr) # select largest tpr at same fpr 49 | plt.plot(fpr, 50 | tpr, 51 | color=colours[method], 52 | lw=1, 53 | label=('[%s (AUC = %0.4f %%)]' % 54 | (method.split('-')[-1], roc_auc * 100))) 55 | tpr_fpr_row = [] 56 | tpr_fpr_row.append("%s-%s" % (method, "IJBC")) 57 | for fpr_iter in np.arange(len(x_labels)): 58 | _, min_index = min( 59 | list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) 60 | tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) 61 | tpr_fpr_table.add_row(tpr_fpr_row) 62 | plt.xlim([10 ** -6, 0.1]) 63 | plt.ylim([0.3, 1.0]) 64 | plt.grid(linestyle='--', linewidth=1) 65 | plt.xticks(x_labels) 66 | plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) 67 | plt.xscale('log') 68 | plt.xlabel('False Positive Rate') 69 | plt.ylabel('True Positive Rate') 70 | plt.title('ROC on IJB') 71 | plt.legend(loc="lower right") 72 | print(tpr_fpr_table) 73 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/utils/utils_amp.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import torch 4 | 5 | if torch.__version__ < '1.9': 6 | Iterable = torch._six.container_abcs.Iterable 7 | else: 8 | import collections 9 | 10 | Iterable = collections.abc.Iterable 11 | from torch.cuda.amp import GradScaler 12 | 13 | 14 | class _MultiDeviceReplicator(object): 15 | """ 16 | Lazily serves copies of a tensor to requested devices. Copies are cached per-device. 
17 | """ 18 | 19 | def __init__(self, master_tensor: torch.Tensor) -> None: 20 | assert master_tensor.is_cuda 21 | self.master = master_tensor 22 | self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} 23 | 24 | def get(self, device) -> torch.Tensor: 25 | retval = self._per_device_tensors.get(device, None) 26 | if retval is None: 27 | retval = self.master.to(device=device, non_blocking=True, copy=True) 28 | self._per_device_tensors[device] = retval 29 | return retval 30 | 31 | 32 | class MaxClipGradScaler(GradScaler): 33 | def __init__(self, init_scale, max_scale: float, growth_interval=100): 34 | GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval) 35 | self.max_scale = max_scale 36 | 37 | def scale_clip(self): 38 | if self.get_scale() == self.max_scale: 39 | self.set_growth_factor(1) 40 | elif self.get_scale() < self.max_scale: 41 | self.set_growth_factor(2) 42 | elif self.get_scale() > self.max_scale: 43 | self._scale.fill_(self.max_scale) 44 | self.set_growth_factor(1) 45 | 46 | def scale(self, outputs): 47 | """ 48 | Multiplies ('scales') a tensor or list of tensors by the scale factor. 49 | 50 | Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned 51 | unmodified. 52 | 53 | Arguments: 54 | outputs (Tensor or iterable of Tensors): Outputs to scale. 55 | """ 56 | if not self._enabled: 57 | return outputs 58 | self.scale_clip() 59 | # Short-circuit for the common case. 60 | if isinstance(outputs, torch.Tensor): 61 | assert outputs.is_cuda 62 | if self._scale is None: 63 | self._lazy_init_scale_growth_tracker(outputs.device) 64 | assert self._scale is not None 65 | return outputs * self._scale.to(device=outputs.device, non_blocking=True) 66 | 67 | # Invoke the more complex machinery only if we're treating multiple outputs. 
68 | stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale 69 | 70 | def apply_scale(val): 71 | if isinstance(val, torch.Tensor): 72 | assert val.is_cuda 73 | if len(stash) == 0: 74 | if self._scale is None: 75 | self._lazy_init_scale_growth_tracker(val.device) 76 | assert self._scale is not None 77 | stash.append(_MultiDeviceReplicator(self._scale)) 78 | return val * stash[0].get(val.device) 79 | elif isinstance(val, Iterable): 80 | iterable = map(apply_scale, val) 81 | if isinstance(val, list) or isinstance(val, tuple): 82 | return type(val)(iterable) 83 | else: 84 | return iterable 85 | else: 86 | raise ValueError("outputs must be a Tensor or an iterable of Tensors") 87 | 88 | return apply_scale(outputs) 89 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/utils/utils_callbacks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from typing import List 5 | 6 | import torch 7 | 8 | from eval import verification 9 | from utils.utils_logging import AverageMeter 10 | 11 | 12 | class CallBackVerification(object): 13 | def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)): 14 | self.frequent: int = frequent 15 | self.rank: int = rank 16 | self.highest_acc: float = 0.0 17 | self.highest_acc_list: List[float] = [0.0] * len(val_targets) 18 | self.ver_list: List[object] = [] 19 | self.ver_name_list: List[str] = [] 20 | if self.rank is 0: 21 | self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size) 22 | 23 | def ver_test(self, backbone: torch.nn.Module, global_step: int): 24 | results = [] 25 | for i in range(len(self.ver_list)): 26 | acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test( 27 | self.ver_list[i], backbone, 10, 10) 28 | logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm)) 29 | logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2)) 30 | if acc2 > self.highest_acc_list[i]: 31 | self.highest_acc_list[i] = acc2 32 | logging.info( 33 | '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i])) 34 | results.append(acc2) 35 | 36 | def init_dataset(self, val_targets, data_dir, image_size): 37 | for name in val_targets: 38 | path = os.path.join(data_dir, name + ".bin") 39 | if os.path.exists(path): 40 | data_set = verification.load_bin(path, image_size) 41 | self.ver_list.append(data_set) 42 | self.ver_name_list.append(name) 43 | 44 | def __call__(self, num_update, backbone: torch.nn.Module): 45 | if self.rank is 0 and num_update > 0 and num_update % self.frequent == 0: 46 | backbone.eval() 47 | self.ver_test(backbone, num_update) 48 | backbone.train() 49 | 50 | 51 | class CallBackLogging(object): 52 | def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None): 53 | self.frequent: int = frequent 54 | self.rank: int = rank 55 | self.time_start = time.time() 56 | self.total_step: int = total_step 57 | self.batch_size: int = batch_size 58 | self.world_size: int = world_size 59 | self.writer = writer 60 | 61 | self.init = False 62 | self.tic = 0 63 | 64 | def __call__(self, 65 | global_step: int, 66 | loss: AverageMeter, 67 | epoch: int, 68 | fp16: bool, 69 | learning_rate: float, 70 | grad_scaler: torch.cuda.amp.GradScaler): 71 | if self.rank == 0 and global_step > 0 and global_step % 
self.frequent == 0: 72 | if self.init: 73 | try: 74 | speed: float = self.frequent * self.batch_size / (time.time() - self.tic) 75 | speed_total = speed * self.world_size 76 | except ZeroDivisionError: 77 | speed_total = float('inf') 78 | 79 | time_now = (time.time() - self.time_start) / 3600 80 | time_total = time_now / ((global_step + 1) / self.total_step) 81 | time_for_end = time_total - time_now 82 | if self.writer is not None: 83 | self.writer.add_scalar('time_for_end', time_for_end, global_step) 84 | self.writer.add_scalar('learning_rate', learning_rate, global_step) 85 | self.writer.add_scalar('loss', loss.avg, global_step) 86 | if fp16: 87 | msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ 88 | "Fp16 Grad Scale: %2.f Required: %1.f hours" % ( 89 | speed_total, loss.avg, learning_rate, epoch, global_step, 90 | grad_scaler.get_scale(), time_for_end 91 | ) 92 | else: 93 | msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ 94 | "Required: %1.f hours" % ( 95 | speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end 96 | ) 97 | logging.info(msg) 98 | loss.reset() 99 | self.tic = time.time() 100 | else: 101 | self.init = True 102 | self.tic = time.time() 103 | 104 | 105 | class CallBackModelCheckpoint(object): 106 | def __init__(self, rank, output="./"): 107 | self.rank: int = rank 108 | self.output: str = output 109 | 110 | def __call__(self, global_step, backbone, partial_fc, ): 111 | if global_step > 100 and self.rank == 0: 112 | path_module = os.path.join(self.output, "backbone.pth") 113 | torch.save(backbone.module.state_dict(), path_module) 114 | logging.info("Pytorch Model Saved in '{}'".format(path_module)) 115 | 116 | if global_step > 100 and partial_fc is not None: 117 | partial_fc.save_params() 118 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/utils/utils_config.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os.path as osp 3 | 4 | 5 | def get_config(config_file): 6 | assert config_file.startswith('configs/'), 'config file setting must start with configs/' 7 | temp_config_name = osp.basename(config_file) 8 | temp_module_name = osp.splitext(temp_config_name)[0] 9 | config = importlib.import_module("configs.base") 10 | cfg = config.config 11 | config = importlib.import_module("configs.%s" % temp_module_name) 12 | job_cfg = config.config 13 | cfg.update(job_cfg) 14 | if cfg.output is None: 15 | cfg.output = osp.join('work_dirs', temp_module_name) 16 | return cfg -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/utils/utils_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value 8 | """ 9 | 10 | def __init__(self): 11 | self.val = None 12 | self.avg = None 13 | self.sum = None 14 | self.count = None 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0 19 | self.avg = 0 20 | self.sum = 0 21 | self.count = 0 22 | 23 | def update(self, val, n=1): 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | def init_logging(rank, models_root): 31 | if rank == 0: 32 | log_root = logging.getLogger() 33 | log_root.setLevel(logging.INFO) 34 | formatter = 
logging.Formatter("Training: %(asctime)s-%(message)s") 35 | handler_file = logging.FileHandler(os.path.join(models_root, "training.log")) 36 | handler_stream = logging.StreamHandler(sys.stdout) 37 | handler_file.setFormatter(formatter) 38 | handler_stream.setFormatter(formatter) 39 | log_root.addHandler(handler_file) 40 | log_root.addHandler(handler_stream) 41 | log_root.info('rank_id: %d' % rank) 42 | -------------------------------------------------------------------------------- /src/face3d/models/arcface_torch/utils/utils_os.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/src/face3d/models/arcface_torch/utils/utils_os.py -------------------------------------------------------------------------------- /src/face3d/models/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from kornia.geometry import warp_affine 5 | import torch.nn.functional as F 6 | 7 | def resize_n_crop(image, M, dsize=112): 8 | # image: (b, c, h, w) 9 | # M : (b, 2, 3) 10 | return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True) 11 | 12 | ### perceptual level loss 13 | class PerceptualLoss(nn.Module): 14 | def __init__(self, recog_net, input_size=112): 15 | super(PerceptualLoss, self).__init__() 16 | self.recog_net = recog_net 17 | self.preprocess = lambda x: 2 * x - 1 18 | self.input_size=input_size 19 | def forward(imageA, imageB, M): 20 | """ 21 | 1 - cosine distance 22 | Parameters: 23 | imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order 24 | imageB --same as imageA 25 | """ 26 | 27 | imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size)) 28 | imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size)) 29 | 30 | # freeze bn 31 | self.recog_net.eval() 32 | 33 | id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2) 34 | id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2) 35 | cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) 36 | # assert torch.sum((cosine_d > 1).float()) == 0 37 | return torch.sum(1 - cosine_d) / cosine_d.shape[0] 38 | 39 | def perceptual_loss(id_featureA, id_featureB): 40 | cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) 41 | # assert torch.sum((cosine_d > 1).float()) == 0 42 | return torch.sum(1 - cosine_d) / cosine_d.shape[0] 43 | 44 | ### image level loss 45 | def photo_loss(imageA, imageB, mask, eps=1e-6): 46 | """ 47 | l2 norm (with sqrt, to ensure backward stabililty, use eps, otherwise Nan may occur) 48 | Parameters: 49 | imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order 50 | imageB --same as imageA 51 | """ 52 | loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask 53 | loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device)) 54 | return loss 55 | 56 | def landmark_loss(predict_lm, gt_lm, weight=None): 57 | """ 58 | weighted mse loss 59 | Parameters: 60 | predict_lm --torch.tensor (B, 68, 2) 61 | gt_lm --torch.tensor (B, 68, 2) 62 | weight --numpy.array (1, 68) 63 | """ 64 | if not weight: 65 | weight = np.ones([68]) 66 | weight[28:31] = 20 67 | weight[-8:] = 20 68 | weight = np.expand_dims(weight, 0) 69 | weight = torch.tensor(weight).to(predict_lm.device) 70 | loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight 71 | loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1]) 
72 | return loss 73 | 74 | 75 | ### regularization 76 | def reg_loss(coeffs_dict, opt=None): 77 | """ 78 | l2 norm without the sqrt, from yu's implementation (mse) 79 | tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss 80 | Parameters: 81 | coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans 82 | 83 | """ 84 | # coefficient regularization to ensure plausible 3d faces 85 | if opt: 86 | w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex 87 | else: 88 | w_id, w_exp, w_tex = 1, 1, 1 89 | creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \ 90 | w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \ 91 | w_tex * torch.sum(coeffs_dict['tex'] ** 2) 92 | creg_loss = creg_loss / coeffs_dict['id'].shape[0] 93 | 94 | # gamma regularization to ensure a nearly-monochromatic light 95 | gamma = coeffs_dict['gamma'].reshape([-1, 3, 9]) 96 | gamma_mean = torch.mean(gamma, dim=1, keepdims=True) 97 | gamma_loss = torch.mean((gamma - gamma_mean) ** 2) 98 | 99 | return creg_loss, gamma_loss 100 | 101 | def reflectance_loss(texture, mask): 102 | """ 103 | minimize texture variance (mse), albedo regularization to ensure a uniform skin albedo 104 | Parameters: 105 | texture --torch.tensor, (B, N, 3) 106 | mask --torch.tensor, (N), 1 or 0 107 | 108 | """ 109 | mask = mask.reshape([1, mask.shape[0], 1]) 110 | texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask) 111 | loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask)) 112 | return loss 113 | 114 | -------------------------------------------------------------------------------- /src/face3d/options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes option modules: training options, test options, and basic options (used in both training and test).""" 2 | -------------------------------------------------------------------------------- /src/face3d/options/inference_options.py: -------------------------------------------------------------------------------- 1 | from face3d.options.base_options import BaseOptions 2 | 3 | 4 | class InferenceOptions(BaseOptions): 5 | """This class includes test options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) # define shared options 12 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 13 | parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') 14 | 15 | parser.add_argument('--input_dir', type=str, help='the folder of the input files') 16 | parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files') 17 | parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients') 18 | parser.add_argument('--save_split_files', action='store_true', help='save split files or not') 19 | parser.add_argument('--inference_batch_size', type=int, default=8) 20 | 21 | # Dropout and Batchnorm have different behavior during training and test.
22 | self.isTrain = False 23 | return parser 24 | -------------------------------------------------------------------------------- /src/face3d/options/test_options.py: -------------------------------------------------------------------------------- 1 | """This script contains the test options for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | from .base_options import BaseOptions 5 | 6 | 7 | class TestOptions(BaseOptions): 8 | """This class includes test options. 9 | 10 | It also includes shared options defined in BaseOptions. 11 | """ 12 | 13 | def initialize(self, parser): 14 | parser = BaseOptions.initialize(self, parser) # define shared options 15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 16 | parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') 17 | parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.') 18 | 19 | # Dropout and Batchnorm has different behavior during training and test. 20 | self.isTrain = False 21 | return parser 22 | -------------------------------------------------------------------------------- /src/face3d/options/train_options.py: -------------------------------------------------------------------------------- 1 | """This script contains the training options for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | from .base_options import BaseOptions 5 | from util import util 6 | 7 | class TrainOptions(BaseOptions): 8 | """This class includes training options. 9 | 10 | It also includes shared options defined in BaseOptions. 11 | """ 12 | 13 | def initialize(self, parser): 14 | parser = BaseOptions.initialize(self, parser) 15 | # dataset parameters 16 | # for train 17 | parser.add_argument('--data_root', type=str, default='./', help='dataset root') 18 | parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set') 19 | parser.add_argument('--batch_size', type=int, default=32) 20 | parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]') 21 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 22 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 23 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 24 | parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]') 25 | parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation') 26 | 27 | # for val 28 | parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set') 29 | parser.add_argument('--batch_size_val', type=int, default=32) 30 | 31 | 32 | # visualization parameters 33 | parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen') 34 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 35 | 36 | # network saving and loading parameters 37 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 38 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') 39 | parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq') 40 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') 41 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 42 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 43 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 44 | parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint') 45 | 46 | # training parameters 47 | parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate') 48 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') 49 | parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. 
[linear | step | plateau | cosine]') 50 | parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches') 51 | 52 | self.isTrain = True 53 | return parser 54 | -------------------------------------------------------------------------------- /src/face3d/util/BBRegressorParam_r.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wanggang1987/fast_sadtalker/b17d8a2f5f14b9e1274f391ddf8ec081d708139c/src/face3d/util/BBRegressorParam_r.mat -------------------------------------------------------------------------------- /src/face3d/util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | from src.face3d.util import * 3 | 4 | -------------------------------------------------------------------------------- /src/face3d/util/detect_lm68.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from scipy.io import loadmat 5 | import tensorflow as tf 6 | from util.preprocess import align_for_lm 7 | from shutil import move 8 | 9 | mean_face = np.loadtxt('util/test_mean_face.txt') 10 | mean_face = mean_face.reshape([68, 2]) 11 | 12 | def save_label(labels, save_path): 13 | np.savetxt(save_path, labels) 14 | 15 | def draw_landmarks(img, landmark, save_name): 16 | landmark = landmark 17 | lm_img = np.zeros([img.shape[0], img.shape[1], 3]) 18 | lm_img[:] = img.astype(np.float32) 19 | landmark = np.round(landmark).astype(np.int32) 20 | 21 | for i in range(len(landmark)): 22 | for j in range(-1, 1): 23 | for k in range(-1, 1): 24 | if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \ 25 | img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \ 26 | landmark[i, 0]+k > 0 and \ 27 | landmark[i, 0]+k < img.shape[1]: 28 | lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k, 29 | :] = np.array([0, 0, 255]) 30 | lm_img = lm_img.astype(np.uint8) 31 | 32 | cv2.imwrite(save_name, lm_img) 33 | 34 | 35 | def load_data(img_name, txt_name): 36 | return cv2.imread(img_name), np.loadtxt(txt_name) 37 | 38 | # create tensorflow graph for landmark detector 39 | def load_lm_graph(graph_filename): 40 | with tf.gfile.GFile(graph_filename, 'rb') as f: 41 | graph_def = tf.GraphDef() 42 | graph_def.ParseFromString(f.read()) 43 | 44 | with tf.Graph().as_default() as graph: 45 | tf.import_graph_def(graph_def, name='net') 46 | img_224 = graph.get_tensor_by_name('net/input_imgs:0') 47 | output_lm = graph.get_tensor_by_name('net/lm:0') 48 | lm_sess = tf.Session(graph=graph) 49 | 50 | return lm_sess,img_224,output_lm 51 | 52 | # landmark detection 53 | def detect_68p(img_path,sess,input_op,output_op): 54 | print('detecting landmarks......') 55 | names = [i for i in sorted(os.listdir( 56 | img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] 57 | vis_path = os.path.join(img_path, 'vis') 58 | remove_path = os.path.join(img_path, 'remove') 59 | save_path = os.path.join(img_path, 'landmarks') 60 | if not os.path.isdir(vis_path): 61 | os.makedirs(vis_path) 62 | if not os.path.isdir(remove_path): 63 | os.makedirs(remove_path) 64 | if not os.path.isdir(save_path): 65 | os.makedirs(save_path) 66 | 67 | for i in range(0, len(names)): 68 | name = names[i] 69 | print('%05d' % (i), ' ', name) 70 | full_image_name = os.path.join(img_path, name) 71 | txt_name = '.'.join(name.split('.')[:-1]) + 
'.txt' 72 | full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image 73 | 74 | # if an image does not have detected 5 facial landmarks, remove it from the training list 75 | if not os.path.isfile(full_txt_name): 76 | move(full_image_name, os.path.join(remove_path, name)) 77 | continue 78 | 79 | # load data 80 | img, five_points = load_data(full_image_name, full_txt_name) 81 | input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection 82 | 83 | # if the alignment fails, remove corresponding image from the training list 84 | if scale == 0: 85 | move(full_txt_name, os.path.join( 86 | remove_path, txt_name)) 87 | move(full_image_name, os.path.join(remove_path, name)) 88 | continue 89 | 90 | # detect landmarks 91 | input_img = np.reshape( 92 | input_img, [1, 224, 224, 3]).astype(np.float32) 93 | landmark = sess.run( 94 | output_op, feed_dict={input_op: input_img}) 95 | 96 | # transform back to original image coordinate 97 | landmark = landmark.reshape([68, 2]) + mean_face 98 | landmark[:, 1] = 223 - landmark[:, 1] 99 | landmark = landmark / scale 100 | landmark[:, 0] = landmark[:, 0] + bbox[0] 101 | landmark[:, 1] = landmark[:, 1] + bbox[1] 102 | landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1] 103 | 104 | if i % 100 == 0: 105 | draw_landmarks(img, landmark, os.path.join(vis_path, name)) 106 | save_label(landmark, os.path.join(save_path, txt_name)) 107 | -------------------------------------------------------------------------------- /src/face3d/util/generate_list.py: -------------------------------------------------------------------------------- 1 | """This script is to generate training list files for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import os 5 | 6 | # save path to training data 7 | def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''): 8 | save_path = os.path.join(save_folder, mode) 9 | if not os.path.isdir(save_path): 10 | os.makedirs(save_path) 11 | with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd: 12 | fd.writelines([i + '\n' for i in lms_list]) 13 | 14 | with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd: 15 | fd.writelines([i + '\n' for i in imgs_list]) 16 | 17 | with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd: 18 | fd.writelines([i + '\n' for i in msks_list]) 19 | 20 | # check if the path is valid 21 | def check_list(rlms_list, rimgs_list, rmsks_list): 22 | lms_list, imgs_list, msks_list = [], [], [] 23 | for i in range(len(rlms_list)): 24 | flag = 'false' 25 | lm_path = rlms_list[i] 26 | im_path = rimgs_list[i] 27 | msk_path = rmsks_list[i] 28 | if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path): 29 | flag = 'true' 30 | lms_list.append(rlms_list[i]) 31 | imgs_list.append(rimgs_list[i]) 32 | msks_list.append(rmsks_list[i]) 33 | print(i, rlms_list[i], flag) 34 | return lms_list, imgs_list, msks_list 35 | -------------------------------------------------------------------------------- /src/face3d/util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 8 | 9 | It consists of functions such as (add a text header to the HTML file), 10 | (add a row of images to the HTML file), and (save the HTML to the disk). 
11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 12 | """ 13 | 14 | def __init__(self, web_dir, title, refresh=0): 15 | """Initialize the HTML classes 16 | 17 | Parameters: 18 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: 32 | with self.doc.head: 33 | meta(http_equiv="refresh", content=str(refresh)) 34 | 35 | def get_image_dir(self): 36 | """Return the directory that stores images""" 37 | return self.img_dir 38 | 39 | def add_header(self, text): 40 | """Insert a header to the HTML file 41 | 42 | Parameters: 43 | text (str) -- the header text 44 | """ 45 | with self.doc: 46 | h3(text) 47 | 48 | def add_images(self, ims, txts, links, width=400): 49 | """add images to the HTML file 50 | 51 | Parameters: 52 | ims (str list) -- a list of image paths 53 | txts (str list) -- a list of image names shown on the website 54 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 55 | """ 56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 57 | self.doc.add(self.t) 58 | with self.t: 59 | with tr(): 60 | for im, txt, link in zip(ims, txts, links): 61 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 62 | with p(): 63 | with a(href=os.path.join('images', link)): 64 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 65 | br() 66 | p(txt) 67 | 68 | def save(self): 69 | """save the current content to the HMTL file""" 70 | html_file = '%s/index.html' % self.web_dir 71 | f = open(html_file, 'wt') 72 | f.write(self.doc.render()) 73 | f.close() 74 | 75 | 76 | if __name__ == '__main__': # we show an example usage here. 
77 | html = HTML('web/', 'test_html') 78 | html.add_header('hello world') 79 | 80 | ims, txts, links = [], [], [] 81 | for n in range(4): 82 | ims.append('image_%d.png' % n) 83 | txts.append('text_%d' % n) 84 | links.append('image_%d.png' % n) 85 | html.add_images(ims, txts, links) 86 | html.save() 87 | -------------------------------------------------------------------------------- /src/face3d/util/load_mats.py: -------------------------------------------------------------------------------- 1 | """This script is to load 3D face model for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import numpy as np 5 | from PIL import Image 6 | from scipy.io import loadmat, savemat 7 | from array import array 8 | import os.path as osp 9 | 10 | # load expression basis 11 | def LoadExpBasis(bfm_folder='BFM'): 12 | n_vertex = 53215 13 | Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb') 14 | exp_dim = array('i') 15 | exp_dim.fromfile(Expbin, 1) 16 | expMU = array('f') 17 | expPC = array('f') 18 | expMU.fromfile(Expbin, 3*n_vertex) 19 | expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex) 20 | Expbin.close() 21 | 22 | expPC = np.array(expPC) 23 | expPC = np.reshape(expPC, [exp_dim[0], -1]) 24 | expPC = np.transpose(expPC) 25 | 26 | expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt')) 27 | 28 | return expPC, expEV 29 | 30 | 31 | # transfer original BFM09 to our face model 32 | def transferBFM09(bfm_folder='BFM'): 33 | print('Transfer BFM09 to BFM_model_front......') 34 | original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat')) 35 | shapePC = original_BFM['shapePC'] # shape basis 36 | shapeEV = original_BFM['shapeEV'] # corresponding eigen value 37 | shapeMU = original_BFM['shapeMU'] # mean face 38 | texPC = original_BFM['texPC'] # texture basis 39 | texEV = original_BFM['texEV'] # eigen value 40 | texMU = original_BFM['texMU'] # mean texture 41 | 42 | expPC, expEV = LoadExpBasis(bfm_folder) 43 | 44 | # transfer BFM09 to our face model 45 | 46 | idBase = shapePC*np.reshape(shapeEV, [-1, 199]) 47 | idBase = idBase/1e5 # unify the scale to decimeter 48 | idBase = idBase[:, :80] # use only first 80 basis 49 | 50 | exBase = expPC*np.reshape(expEV, [-1, 79]) 51 | exBase = exBase/1e5 # unify the scale to decimeter 52 | exBase = exBase[:, :64] # use only first 64 basis 53 | 54 | texBase = texPC*np.reshape(texEV, [-1, 199]) 55 | texBase = texBase[:, :80] # use only first 80 basis 56 | 57 | # our face model is cropped along face landmarks and contains only 35709 vertex. 58 | # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex. 59 | # thus we select corresponding vertex to get our face model. 
60 | 61 | index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat')) 62 | index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215) 63 | 64 | index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat')) 65 | index_shape = index_shape['trimIndex'].astype( 66 | np.int32) - 1 # starts from 0 (to 53490) 67 | index_shape = index_shape[index_exp] 68 | 69 | idBase = np.reshape(idBase, [-1, 3, 80]) 70 | idBase = idBase[index_shape, :, :] 71 | idBase = np.reshape(idBase, [-1, 80]) 72 | 73 | texBase = np.reshape(texBase, [-1, 3, 80]) 74 | texBase = texBase[index_shape, :, :] 75 | texBase = np.reshape(texBase, [-1, 80]) 76 | 77 | exBase = np.reshape(exBase, [-1, 3, 64]) 78 | exBase = exBase[index_exp, :, :] 79 | exBase = np.reshape(exBase, [-1, 64]) 80 | 81 | meanshape = np.reshape(shapeMU, [-1, 3])/1e5 82 | meanshape = meanshape[index_shape, :] 83 | meanshape = np.reshape(meanshape, [1, -1]) 84 | 85 | meantex = np.reshape(texMU, [-1, 3]) 86 | meantex = meantex[index_shape, :] 87 | meantex = np.reshape(meantex, [1, -1]) 88 | 89 | # other info contains triangles, region used for computing photometric loss, 90 | # , region used for skin texture regularization, and 68 landmarks index etc. 91 | other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat')) 92 | frontmask2_idx = other_info['frontmask2_idx'] 93 | skinmask = other_info['skinmask'] 94 | keypoints = other_info['keypoints'] 95 | point_buf = other_info['point_buf'] 96 | tri = other_info['tri'] 97 | tri_mask2 = other_info['tri_mask2'] 98 | 99 | # save our face model 100 | savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase, 101 | 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask}) 102 | 103 | 104 | # load landmarks for standard face, which is used for image preprocessing 105 | def load_lm3d(bfm_folder): 106 | 107 | Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat')) 108 | Lm3D = Lm3D['lm'] 109 | 110 | # calculate 5 facial landmarks using 68 landmarks 111 | lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 112 | Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean( 113 | Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0) 114 | Lm3D = Lm3D[[1, 2, 0, 3, 4], :] 115 | 116 | return Lm3D 117 | 118 | 119 | if __name__ == '__main__': 120 | transferBFM09() -------------------------------------------------------------------------------- /src/face3d/util/nvdiffrast.py: -------------------------------------------------------------------------------- 1 | """This script is the differentiable renderer for Deep3DFaceRecon_pytorch 2 | Attention, antialiasing step is missing in current version. 
3 | """ 4 | import pytorch3d.ops 5 | import torch 6 | import torch.nn.functional as F 7 | import kornia 8 | from kornia.geometry.camera import pixel2cam 9 | import numpy as np 10 | from typing import List 11 | from scipy.io import loadmat 12 | from torch import nn 13 | 14 | from pytorch3d.structures import Meshes 15 | from pytorch3d.renderer import ( 16 | look_at_view_transform, 17 | FoVPerspectiveCameras, 18 | DirectionalLights, 19 | RasterizationSettings, 20 | MeshRenderer, 21 | MeshRasterizer, 22 | SoftPhongShader, 23 | TexturesUV, 24 | ) 25 | 26 | # def ndc_projection(x=0.1, n=1.0, f=50.0): 27 | # return np.array([[n/x, 0, 0, 0], 28 | # [ 0, n/-x, 0, 0], 29 | # [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], 30 | # [ 0, 0, -1, 0]]).astype(np.float32) 31 | 32 | class MeshRenderer(nn.Module): 33 | def __init__(self, 34 | rasterize_fov, 35 | znear=0.1, 36 | zfar=10, 37 | rasterize_size=224): 38 | super(MeshRenderer, self).__init__() 39 | 40 | # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear 41 | # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul( 42 | # torch.diag(torch.tensor([1., -1, -1, 1]))) 43 | self.rasterize_size = rasterize_size 44 | self.fov = rasterize_fov 45 | self.znear = znear 46 | self.zfar = zfar 47 | 48 | self.rasterizer = None 49 | 50 | def forward(self, vertex, tri, feat=None): 51 | """ 52 | Return: 53 | mask -- torch.tensor, size (B, 1, H, W) 54 | depth -- torch.tensor, size (B, 1, H, W) 55 | features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None 56 | 57 | Parameters: 58 | vertex -- torch.tensor, size (B, N, 3) 59 | tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles 60 | feat(optional) -- torch.tensor, size (B, N ,C), features 61 | """ 62 | device = vertex.device 63 | rsize = int(self.rasterize_size) 64 | # ndc_proj = self.ndc_proj.to(device) 65 | # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v 66 | if vertex.shape[-1] == 3: 67 | vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1) 68 | vertex[..., 0] = -vertex[..., 0] 69 | 70 | 71 | # vertex_ndc = vertex @ ndc_proj.t() 72 | if self.rasterizer is None: 73 | self.rasterizer = MeshRasterizer() 74 | print("create rasterizer on device cuda:%d"%device.index) 75 | 76 | # ranges = None 77 | # if isinstance(tri, List) or len(tri.shape) == 3: 78 | # vum = vertex_ndc.shape[1] 79 | # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device) 80 | # fstartidx = torch.cumsum(fnum, dim=0) - fnum 81 | # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu() 82 | # for i in range(tri.shape[0]): 83 | # tri[i] = tri[i] + i*vum 84 | # vertex_ndc = torch.cat(vertex_ndc, dim=0) 85 | # tri = torch.cat(tri, dim=0) 86 | 87 | # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3] 88 | tri = tri.type(torch.int32).contiguous() 89 | 90 | # rasterize 91 | cameras = FoVPerspectiveCameras( 92 | device=device, 93 | fov=self.fov, 94 | znear=self.znear, 95 | zfar=self.zfar, 96 | ) 97 | 98 | raster_settings = RasterizationSettings( 99 | image_size=rsize 100 | ) 101 | 102 | # print(vertex.shape, tri.shape) 103 | mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0).repeat((vertex.shape[0],1,1))) 104 | 105 | fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings) 106 | rast_out = fragments.pix_to_face.squeeze(-1) 107 | depth = fragments.zbuf 108 | 109 | # render depth 110 | depth = depth.permute(0, 3, 1, 2) 111 | mask = (rast_out > 
0).float().unsqueeze(1) 112 | depth = mask * depth 113 | 114 | 115 | image = None 116 | if feat is not None: 117 | attributes = feat.reshape(-1,3)[mesh.faces_packed()] 118 | image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face, 119 | fragments.bary_coords, 120 | attributes) 121 | # print(image.shape) 122 | image = image.squeeze(-2).permute(0, 3, 1, 2) 123 | image = mask * image 124 | 125 | return mask, depth, image 126 | 127 | -------------------------------------------------------------------------------- /src/face3d/util/preprocess.py: -------------------------------------------------------------------------------- 1 | """This script contains the image preprocessing code for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import numpy as np 5 | from scipy.io import loadmat 6 | from PIL import Image 7 | import cv2 8 | import os 9 | from skimage import transform as trans 10 | import torch 11 | import warnings 12 | warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) 13 | warnings.filterwarnings("ignore", category=FutureWarning) 14 | 15 | 16 | # calculating least square problem for image alignment 17 | def POS(xp, x): 18 | npts = xp.shape[1] 19 | 20 | A = np.zeros([2*npts, 8]) 21 | 22 | A[0:2*npts-1:2, 0:3] = x.transpose() 23 | A[0:2*npts-1:2, 3] = 1 24 | 25 | A[1:2*npts:2, 4:7] = x.transpose() 26 | A[1:2*npts:2, 7] = 1 27 | 28 | b = np.reshape(xp.transpose(), [2*npts, 1]) 29 | 30 | k, _, _, _ = np.linalg.lstsq(A, b) 31 | 32 | R1 = k[0:3] 33 | R2 = k[4:7] 34 | sTx = k[3] 35 | sTy = k[7] 36 | s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2 37 | t = np.stack([sTx, sTy], axis=0) 38 | 39 | return t, s 40 | 41 | # resize and crop images for face reconstruction 42 | def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None): 43 | w0, h0 = img.size 44 | w = (w0*s).astype(np.int32) 45 | h = (h0*s).astype(np.int32) 46 | left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32) 47 | right = left + target_size 48 | up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32) 49 | below = up + target_size 50 | 51 | img = img.resize((w, h), resample=Image.BICUBIC) 52 | img = img.crop((left, up, right, below)) 53 | 54 | if mask is not None: 55 | mask = mask.resize((w, h), resample=Image.BICUBIC) 56 | mask = mask.crop((left, up, right, below)) 57 | 58 | lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] - 59 | t[1] + h0/2], axis=1)*s 60 | lm = lm - np.reshape( 61 | np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2]) 62 | 63 | return img, lm, mask 64 | 65 | # utils for face reconstruction 66 | def extract_5p(lm): 67 | lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 68 | lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean( 69 | lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0) 70 | lm5p = lm5p[[1, 2, 0, 3, 4], :] 71 | return lm5p 72 | 73 | # utils for face reconstruction 74 | def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.): 75 | """ 76 | Return: 77 | transparams --numpy.array (raw_W, raw_H, scale, tx, ty) 78 | img_new --PIL.Image (target_size, target_size, 3) 79 | lm_new --numpy.array (68, 2), y direction is opposite to v direction 80 | mask_new --PIL.Image (target_size, target_size) 81 | 82 | Parameters: 83 | img --PIL.Image (raw_H, raw_W, 3) 84 | lm --numpy.array (68, 2), y direction is opposite to v direction 85 | lm3D --numpy.array (5, 3) 86 | mask --PIL.Image (raw_H, raw_W, 3) 87 | """ 88 | 89 | w0, h0 = img.size 90 | if lm.shape[0] != 5: 91 | lm5p = 
extract_5p(lm) 92 | else: 93 | lm5p = lm 94 | 95 | # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face 96 | t, s = POS(lm5p.transpose(), lm3D.transpose()) 97 | s = rescale_factor/s 98 | 99 | # processing the image 100 | img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask) 101 | trans_params = np.array([w0, h0, s, t[0], t[1]]) 102 | 103 | return trans_params, img_new, lm_new, mask_new 104 | -------------------------------------------------------------------------------- /src/face3d/util/test_mean_face.txt: -------------------------------------------------------------------------------- 1 | -5.228591537475585938e+01 2 | 2.078247070312500000e-01 3 | -5.064269638061523438e+01 4 | -1.315765380859375000e+01 5 | -4.952939224243164062e+01 6 | -2.592591094970703125e+01 7 | -4.793047332763671875e+01 8 | -3.832135772705078125e+01 9 | -4.512159729003906250e+01 10 | -5.059623336791992188e+01 11 | -3.917720794677734375e+01 12 | -6.043736648559570312e+01 13 | -2.929953765869140625e+01 14 | -6.861183166503906250e+01 15 | -1.719801330566406250e+01 16 | -7.572736358642578125e+01 17 | -1.961936950683593750e+00 18 | -7.862001037597656250e+01 19 | 1.467941284179687500e+01 20 | -7.607844543457031250e+01 21 | 2.744073486328125000e+01 22 | -6.915261840820312500e+01 23 | 3.855677795410156250e+01 24 | -5.950350570678710938e+01 25 | 4.478240966796875000e+01 26 | -4.867547225952148438e+01 27 | 4.714337158203125000e+01 28 | -3.800830078125000000e+01 29 | 4.940315246582031250e+01 30 | -2.496297454833984375e+01 31 | 5.117234802246093750e+01 32 | -1.241538238525390625e+01 33 | 5.190507507324218750e+01 34 | 8.244247436523437500e-01 35 | -4.150688934326171875e+01 36 | 2.386329650878906250e+01 37 | -3.570307159423828125e+01 38 | 3.017010498046875000e+01 39 | -2.790358734130859375e+01 40 | 3.212951660156250000e+01 41 | -1.941773223876953125e+01 42 | 3.156523132324218750e+01 43 | -1.138106536865234375e+01 44 | 2.841992187500000000e+01 45 | 5.993263244628906250e+00 46 | 2.895182800292968750e+01 47 | 1.343590545654296875e+01 48 | 3.189880371093750000e+01 49 | 2.203153991699218750e+01 50 | 3.302221679687500000e+01 51 | 2.992478942871093750e+01 52 | 3.099150085449218750e+01 53 | 3.628388977050781250e+01 54 | 2.765748596191406250e+01 55 | -1.933914184570312500e+00 56 | 1.405374145507812500e+01 57 | -2.153038024902343750e+00 58 | 5.772636413574218750e+00 59 | -2.270050048828125000e+00 60 | -2.121643066406250000e+00 61 | -2.218330383300781250e+00 62 | -1.068978118896484375e+01 63 | -1.187252044677734375e+01 64 | -1.997912597656250000e+01 65 | -6.879402160644531250e+00 66 | -2.143579864501953125e+01 67 | -1.227821350097656250e+00 68 | -2.193494415283203125e+01 69 | 4.623237609863281250e+00 70 | -2.152721405029296875e+01 71 | 9.721397399902343750e+00 72 | -1.953671264648437500e+01 73 | -3.648714447021484375e+01 74 | 9.811126708984375000e+00 75 | -3.130242919921875000e+01 76 | 1.422447967529296875e+01 77 | -2.212834930419921875e+01 78 | 1.493019866943359375e+01 79 | -1.500880432128906250e+01 80 | 1.073588562011718750e+01 81 | -2.095037078857421875e+01 82 | 9.054298400878906250e+00 83 | -3.050099182128906250e+01 84 | 8.704177856445312500e+00 85 | 1.173237609863281250e+01 86 | 1.054329681396484375e+01 87 | 1.856353759765625000e+01 88 | 1.535009765625000000e+01 89 | 2.893331909179687500e+01 90 | 1.451992797851562500e+01 91 | 3.452944946289062500e+01 92 | 1.065280151367187500e+01 93 | 2.875990295410156250e+01 94 | 8.654792785644531250e+00 95 | 
1.942100524902343750e+01 96 | 9.422447204589843750e+00 97 | -2.204488372802734375e+01 98 | -3.983994293212890625e+01 99 | -1.324458312988281250e+01 100 | -3.467377471923828125e+01 101 | -6.749649047851562500e+00 102 | -3.092894744873046875e+01 103 | -9.183349609375000000e-01 104 | -3.196458435058593750e+01 105 | 4.220649719238281250e+00 106 | -3.090406036376953125e+01 107 | 1.089889526367187500e+01 108 | -3.497008514404296875e+01 109 | 1.874589538574218750e+01 110 | -4.065438079833984375e+01 111 | 1.124106597900390625e+01 112 | -4.438417816162109375e+01 113 | 5.181709289550781250e+00 114 | -4.649170684814453125e+01 115 | -1.158607482910156250e+00 116 | -4.680406951904296875e+01 117 | -7.918922424316406250e+00 118 | -4.671575164794921875e+01 119 | -1.452505493164062500e+01 120 | -4.416526031494140625e+01 121 | -2.005007171630859375e+01 122 | -3.997841644287109375e+01 123 | -1.054919433593750000e+01 124 | -3.849683380126953125e+01 125 | -1.051826477050781250e+00 126 | -3.794863128662109375e+01 127 | 6.412681579589843750e+00 128 | -3.804645538330078125e+01 129 | 1.627674865722656250e+01 130 | -4.039697265625000000e+01 131 | 6.373878479003906250e+00 132 | -4.087213897705078125e+01 133 | -8.551712036132812500e-01 134 | -4.157129669189453125e+01 135 | -1.014953613281250000e+01 136 | -4.128469085693359375e+01 137 | -------------------------------------------------------------------------------- /src/face3d/visualize.py: -------------------------------------------------------------------------------- 1 | # check the sync of 3dmm feature and the audio 2 | import cv2 3 | import numpy as np 4 | from src.face3d.models.bfm import ParametricFaceModel 5 | from src.face3d.models.facerecon_model import FaceReconModel 6 | import torch 7 | import subprocess, platform 8 | import scipy.io as scio 9 | from tqdm import tqdm 10 | 11 | # draft 12 | def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, exp_dim=64): 13 | 14 | coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm'] 15 | 16 | coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm'] 17 | 18 | coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257 19 | 20 | coeff_full[:, 80:144] = coeff_pred[:, 0:64] 21 | coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation 22 | coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation 23 | 24 | tmp_video_path = '/tmp/face3dtmp.mp4' 25 | 26 | facemodel = FaceReconModel(args) 27 | 28 | video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224)) 29 | 30 | for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'): 31 | cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device) 32 | 33 | facemodel.forward(cur_coeff_full, device) 34 | 35 | predicted_landmark = facemodel.pred_lm # TODO. 36 | predicted_landmark = predicted_landmark.cpu().numpy().squeeze() 37 | 38 | rendered_img = facemodel.pred_face 39 | rendered_img = 255. 
* rendered_img.cpu().numpy().squeeze().transpose(1,2,0) 40 | out_img = rendered_img[:, :, :3].astype(np.uint8) 41 | 42 | video.write(np.uint8(out_img[:,:,::-1])) 43 | 44 | video.release() 45 | 46 | command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path) 47 | subprocess.call(command, shell=platform.system() != 'Windows') 48 | 49 | -------------------------------------------------------------------------------- /src/facerender/modules/discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | from facerender.modules.util import kp2gaussian 4 | import torch 5 | 6 | 7 | class DownBlock2d(nn.Module): 8 | """ 9 | Simple block for processing video (encoder). 10 | """ 11 | 12 | def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False): 13 | super(DownBlock2d, self).__init__() 14 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) 15 | 16 | if sn: 17 | self.conv = nn.utils.spectral_norm(self.conv) 18 | 19 | if norm: 20 | self.norm = nn.InstanceNorm2d(out_features, affine=True) 21 | else: 22 | self.norm = None 23 | self.pool = pool 24 | 25 | def forward(self, x): 26 | out = x 27 | out = self.conv(out) 28 | if self.norm: 29 | out = self.norm(out) 30 | out = F.leaky_relu(out, 0.2) 31 | if self.pool: 32 | out = F.avg_pool2d(out, (2, 2)) 33 | return out 34 | 35 | 36 | class Discriminator(nn.Module): 37 | """ 38 | Discriminator similar to Pix2Pix 39 | """ 40 | 41 | def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512, 42 | sn=False, **kwargs): 43 | super(Discriminator, self).__init__() 44 | 45 | down_blocks = [] 46 | for i in range(num_blocks): 47 | down_blocks.append( 48 | DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)), 49 | min(max_features, block_expansion * (2 ** (i + 1))), 50 | norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn)) 51 | 52 | self.down_blocks = nn.ModuleList(down_blocks) 53 | self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) 54 | if sn: 55 | self.conv = nn.utils.spectral_norm(self.conv) 56 | 57 | def forward(self, x): 58 | feature_maps = [] 59 | out = x 60 | 61 | for down_block in self.down_blocks: 62 | feature_maps.append(down_block(out)) 63 | out = feature_maps[-1] 64 | prediction_map = self.conv(out) 65 | 66 | return feature_maps, prediction_map 67 | 68 | 69 | class MultiScaleDiscriminator(nn.Module): 70 | """ 71 | Multi-scale (scale) discriminator 72 | """ 73 | 74 | def __init__(self, scales=(), **kwargs): 75 | super(MultiScaleDiscriminator, self).__init__() 76 | self.scales = scales 77 | discs = {} 78 | for scale in scales: 79 | discs[str(scale).replace('.', '-')] = Discriminator(**kwargs) 80 | self.discs = nn.ModuleDict(discs) 81 | 82 | def forward(self, x): 83 | out_dict = {} 84 | for scale, disc in self.discs.items(): 85 | scale = str(scale).replace('-', '.') 86 | key = 'prediction_' + scale 87 | feature_maps, prediction_map = disc(x[key]) 88 | out_dict['feature_maps_' + scale] = feature_maps 89 | out_dict['prediction_map_' + scale] = prediction_map 90 | return out_dict 91 | -------------------------------------------------------------------------------- /src/facerender/modules/mapping.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 
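# MappingNet (defined below) encodes a short temporal window of 3DMM coefficients with a Conv1d stem and dilated residual Conv1d blocks, pools over time, and predicts binned yaw/pitch/roll, a 3-D translation and 3*num_kp expression deformations that are consumed downstream by the face renderer.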
4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class MappingNet(nn.Module): 9 | def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins): 10 | super( MappingNet, self).__init__() 11 | 12 | self.layer = layer 13 | nonlinearity = nn.LeakyReLU(0.1) 14 | 15 | self.first = nn.Sequential( 16 | torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)) 17 | 18 | for i in range(layer): 19 | net = nn.Sequential(nonlinearity, 20 | torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3)) 21 | setattr(self, 'encoder' + str(i), net) 22 | 23 | self.pooling = nn.AdaptiveAvgPool1d(1) 24 | self.output_nc = descriptor_nc 25 | 26 | self.fc_roll = nn.Linear(descriptor_nc, num_bins) 27 | self.fc_pitch = nn.Linear(descriptor_nc, num_bins) 28 | self.fc_yaw = nn.Linear(descriptor_nc, num_bins) 29 | self.fc_t = nn.Linear(descriptor_nc, 3) 30 | self.fc_exp = nn.Linear(descriptor_nc, 3*num_kp) 31 | 32 | def forward(self, input_3dmm): 33 | out = self.first(input_3dmm) 34 | for i in range(self.layer): 35 | model = getattr(self, 'encoder' + str(i)) 36 | out = model(out) + out[:,:,3:-3] 37 | out = self.pooling(out) 38 | out = out.view(out.shape[0], -1) 39 | #print('out:', out.shape) 40 | 41 | yaw = self.fc_yaw(out) 42 | pitch = self.fc_pitch(out) 43 | roll = self.fc_roll(out) 44 | t = self.fc_t(out) 45 | exp = self.fc_exp(out) 46 | 47 | return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} -------------------------------------------------------------------------------- /src/facerender/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /src/facerender/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 
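# store the result and wake the consumer blocked in get()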
29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as data parallel will trigger a callback on each module, all slave devices should 60 | call `register(id)` and obtain a `SlavePipe` to communicate with the master. 61 | - During the forward pass, the master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine the message to be passed 64 | back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register a slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | Messages are first collected from each device (including the master device), and then 106 | a callback is invoked to compute the message to be sent back to each device 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master wants to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belong to the master.' 
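# hand each slave its result through the registered FutureResult pipe; the master's own result (index 0) is returned below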
124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /src/facerender/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphic, we assign each sub-module a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by 55 | the original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have a customized `DataParallel` implementation. 
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /src/facerender/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /src/generate_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | import torch 5 | import numpy as np 6 | import random 7 | import scipy.io as scio 8 | import src.utils.audio as audio 9 | 10 | 11 | def crop_pad_audio(wav, audio_length): 12 | if len(wav) > audio_length: 13 | wav = wav[:audio_length] 14 | elif len(wav) < audio_length: 15 | wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0) 16 | return wav 17 | 18 | def parse_audio_length(audio_length, sr, fps): 19 | bit_per_frames = sr / fps 20 | 21 | num_frames = int(audio_length / bit_per_frames) 22 | audio_length = int(num_frames * bit_per_frames) 23 | 24 | return audio_length, num_frames 25 | 26 | def generate_blink_seq(num_frames): 27 | ratio = np.zeros((num_frames,1)) 28 | frame_id = 0 29 | while frame_id in range(num_frames): 30 | start = 80 31 | if frame_id+start+9<=num_frames - 1: 32 | ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5] 33 | frame_id = frame_id+start+9 34 | else: 35 | break 36 | return ratio 37 | 38 | def generate_blink_seq_randomly(num_frames): 39 | ratio = np.zeros((num_frames,1)) 40 | if num_frames<=21: 41 | return ratio 42 | frame_id = 0 43 | while frame_id in range(num_frames): 44 | start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70))) 45 | if frame_id+start+5<=num_frames - 1: 46 | ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5] 47 | frame_id = frame_id+start+5 48 | else: 49 | break 50 | return ratio 51 | 52 | def get_data(fps, first_coeff_path, 
audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True): 53 | 54 | syncnet_mel_step_size = 16 55 | 56 | pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0] 57 | audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] 58 | 59 | # 若重新训练涉及到 Lipreading,需要分割 wav 60 | if idlemode: 61 | num_frames = int(length_of_audio * fps) 62 | indiv_mels = np.zeros((num_frames, 80, 16)) 63 | else: 64 | wav = audio.load_wav(audio_path, 16000) 65 | wav_length, num_frames = parse_audio_length(len(wav), 16000, fps) 66 | wav = crop_pad_audio(wav, wav_length) 67 | orig_mel = audio.melspectrogram(wav).T 68 | spec = orig_mel.copy() # nframes 80 69 | indiv_mels = [] 70 | 71 | for i in tqdm(range(num_frames), 'mel:'): 72 | start_frame_num = i-2 73 | start_idx = int(80. * (start_frame_num / float(fps))) 74 | end_idx = start_idx + syncnet_mel_step_size 75 | seq = list(range(start_idx, end_idx)) 76 | seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ] 77 | m = spec[seq, :] 78 | indiv_mels.append(m.T) 79 | indiv_mels = np.asarray(indiv_mels) # T 80 16 80 | 81 | ratio = generate_blink_seq_randomly(num_frames) # T 82 | source_semantics_path = first_coeff_path 83 | source_semantics_dict = scio.loadmat(source_semantics_path) 84 | ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70 85 | ref_coeff = np.repeat(ref_coeff, num_frames, axis=0) 86 | 87 | if ref_eyeblink_coeff_path is not None: 88 | ratio[:num_frames] = 0 89 | refeyeblink_coeff_dict = scio.loadmat(ref_eyeblink_coeff_path) 90 | refeyeblink_coeff = refeyeblink_coeff_dict['coeff_3dmm'][:,:64] 91 | refeyeblink_num_frames = refeyeblink_coeff.shape[0] 92 | if refeyeblink_num_frames= 0 119 | if hp.symmetric_mels: 120 | return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value 121 | else: 122 | return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) 123 | 124 | def _denormalize(D): 125 | if hp.allow_clipping_in_normalization: 126 | if hp.symmetric_mels: 127 | return (((np.clip(D, -hp.max_abs_value, 128 | hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) 129 | + hp.min_level_db) 130 | else: 131 | return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 132 | 133 | if hp.symmetric_mels: 134 | return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) 135 | else: 136 | return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 137 | -------------------------------------------------------------------------------- /src/utils/face_enhancer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from gfpgan import GFPGANer 5 | 6 | from tqdm import tqdm 7 | 8 | from src.utils.videoio import load_video_to_cv2 9 | 10 | import cv2 11 | 12 | 13 | class GeneratorWithLen(object): 14 | """ From https://stackoverflow.com/a/7460929 """ 15 | 16 | def __init__(self, gen, length): 17 | self.gen = gen 18 | self.length = length 19 | 20 | def __len__(self): 21 | return self.length 22 | 23 | def __iter__(self): 24 | return self.gen 25 | 26 | def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'): 27 | gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler) 28 | return list(gen) 29 | 30 | def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'): 31 | """ Provide a generator with a __len__ 
method so that it can passed to functions that 32 | call len()""" 33 | 34 | if os.path.isfile(images): # handle video to images 35 | # TODO: Create a generator version of load_video_to_cv2 36 | images = load_video_to_cv2(images) 37 | 38 | gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler) 39 | gen_with_len = GeneratorWithLen(gen, len(images)) 40 | return gen_with_len 41 | 42 | def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'): 43 | """ Provide a generator function so that all of the enhanced images don't need 44 | to be stored in memory at the same time. This can save tons of RAM compared to 45 | the enhancer function. """ 46 | 47 | print('face enhancer....') 48 | if not isinstance(images, list) and os.path.isfile(images): # handle video to images 49 | images = load_video_to_cv2(images) 50 | 51 | # ------------------------ set up GFPGAN restorer ------------------------ 52 | if method == 'gfpgan': 53 | arch = 'clean' 54 | channel_multiplier = 2 55 | model_name = 'GFPGANv1.4' 56 | url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth' 57 | elif method == 'RestoreFormer': 58 | arch = 'RestoreFormer' 59 | channel_multiplier = 2 60 | model_name = 'RestoreFormer' 61 | url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth' 62 | elif method == 'codeformer': # TODO: 63 | arch = 'CodeFormer' 64 | channel_multiplier = 2 65 | model_name = 'CodeFormer' 66 | url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' 67 | else: 68 | raise ValueError(f'Wrong model version {method}.') 69 | 70 | 71 | # ------------------------ set up background upsampler ------------------------ 72 | if bg_upsampler == 'realesrgan': 73 | if not torch.cuda.is_available(): # CPU 74 | import warnings 75 | warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. 
' 76 | 'If you really want to use it, please modify the corresponding codes.') 77 | bg_upsampler = None 78 | else: 79 | from basicsr.archs.rrdbnet_arch import RRDBNet 80 | from realesrgan import RealESRGANer 81 | model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) 82 | bg_upsampler = RealESRGANer( 83 | scale=2, 84 | model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', 85 | model=model, 86 | tile=400, 87 | tile_pad=10, 88 | pre_pad=0, 89 | half=True) # need to set False in CPU mode 90 | else: 91 | bg_upsampler = None 92 | 93 | # determine model paths 94 | model_path = os.path.join('gfpgan/weights', model_name + '.pth') 95 | 96 | if not os.path.isfile(model_path): 97 | model_path = os.path.join('checkpoints', model_name + '.pth') 98 | 99 | if not os.path.isfile(model_path): 100 | # download pre-trained models from url 101 | model_path = url 102 | 103 | restorer = GFPGANer( 104 | model_path=model_path, 105 | upscale=2, 106 | arch=arch, 107 | channel_multiplier=channel_multiplier, 108 | bg_upsampler=bg_upsampler) 109 | 110 | # ------------------------ restore ------------------------ 111 | for idx in tqdm(range(len(images)), 'Face Enhancer:'): 112 | 113 | img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR) 114 | 115 | # restore faces and background if necessary 116 | cropped_faces, restored_faces, r_img = restorer.enhance( 117 | img, 118 | has_aligned=False, 119 | only_center_face=False, 120 | paste_back=True) 121 | 122 | r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB) 123 | yield r_img 124 | -------------------------------------------------------------------------------- /src/utils/init_path.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | def init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'): 5 | 6 | if old_version: 7 | #### load all the checkpoint of `pth` 8 | sadtalker_paths = { 9 | 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'), 10 | 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'), 11 | 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'), 12 | 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'), 13 | 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth') 14 | } 15 | 16 | use_safetensor = False 17 | elif len(glob.glob(os.path.join(checkpoint_dir, '*.safetensors'))): 18 | print('using safetensor as default') 19 | sadtalker_paths = { 20 | "checkpoint":os.path.join(checkpoint_dir, 'SadTalker_V0.0.2_'+str(size)+'.safetensors'), 21 | } 22 | use_safetensor = True 23 | else: 24 | print("WARNING: The new version of the model will be updated by safetensor, you may need to download it mannully. 
We run the old version of the checkpoint this time!") 25 | use_safetensor = False 26 | 27 | sadtalker_paths = { 28 | 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'), 29 | 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'), 30 | 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'), 31 | 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'), 32 | 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth') 33 | } 34 | 35 | sadtalker_paths['dir_of_BFM_fitting'] = os.path.join(config_dir) # , 'BFM_Fitting' 36 | sadtalker_paths['audio2pose_yaml_path'] = os.path.join(config_dir, 'auido2pose.yaml') 37 | sadtalker_paths['audio2exp_yaml_path'] = os.path.join(config_dir, 'auido2exp.yaml') 38 | sadtalker_paths['use_safetensor'] = use_safetensor # os.path.join(config_dir, 'auido2exp.yaml') 39 | 40 | if 'full' in preprocess: 41 | sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar') 42 | sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender_still.yaml') 43 | else: 44 | sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar') 45 | sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml') 46 | 47 | return sadtalker_paths -------------------------------------------------------------------------------- /src/utils/paste_pic.py: -------------------------------------------------------------------------------- 1 | import cv2, os 2 | import numpy as np 3 | from tqdm import tqdm 4 | import uuid 5 | 6 | from src.utils.videoio import save_video_with_watermark 7 | 8 | def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop=False): 9 | 10 | if not os.path.isfile(pic_path): 11 | raise ValueError('pic_path must be a valid path to video/image file') 12 | elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']: 13 | # loader for first frame 14 | full_img = cv2.imread(pic_path) 15 | else: 16 | # loader for videos 17 | video_stream = cv2.VideoCapture(pic_path) 18 | fps = video_stream.get(cv2.CAP_PROP_FPS) 19 | full_frames = [] 20 | while 1: 21 | still_reading, frame = video_stream.read() 22 | if not still_reading: 23 | video_stream.release() 24 | break 25 | break 26 | full_img = frame 27 | frame_h = full_img.shape[0] 28 | frame_w = full_img.shape[1] 29 | 30 | video_stream = cv2.VideoCapture(video_path) 31 | fps = video_stream.get(cv2.CAP_PROP_FPS) 32 | crop_frames = [] 33 | while 1: 34 | still_reading, frame = video_stream.read() 35 | if not still_reading: 36 | video_stream.release() 37 | break 38 | crop_frames.append(frame) 39 | 40 | if len(crop_info) != 3: 41 | print("you didn't crop the image") 42 | return 43 | else: 44 | r_w, r_h = crop_info[0] 45 | clx, cly, crx, cry = crop_info[1] 46 | lx, ly, rx, ry = crop_info[2] 47 | lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) 48 | # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx 49 | # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx 50 | 51 | if extended_crop: 52 | oy1, oy2, ox1, ox2 = cly, cry, clx, crx 53 | else: 54 | oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx 55 | 56 | tmp_path = str(uuid.uuid4())+'.mp4' 57 | out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h)) 58 | for crop_frame in tqdm(crop_frames, 'seamlessClone:'): 59 | p = cv2.resize(crop_frame.astype(np.uint8), (ox2-ox1, oy2 - oy1)) 60 | 61 | mask = 
255*np.ones(p.shape, p.dtype) 62 | location = ((ox1+ox2) // 2, (oy1+oy2) // 2) 63 | gen_img = cv2.seamlessClone(p, full_img, mask, location, cv2.NORMAL_CLONE) 64 | out_tmp.write(gen_img) 65 | 66 | out_tmp.release() 67 | 68 | save_video_with_watermark(tmp_path, new_audio_path, full_video_path, watermark=False) 69 | os.remove(tmp_path) 70 | -------------------------------------------------------------------------------- /src/utils/safetensor_helper.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def load_x_from_safetensor(checkpoint, key): 4 | x_generator = {} 5 | for k,v in checkpoint.items(): 6 | if key in k: 7 | x_generator[k.replace(key+'.', '')] = v 8 | return x_generator -------------------------------------------------------------------------------- /src/utils/text2speech.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from TTS.api import TTS 4 | 5 | 6 | class TTSTalker(): 7 | def __init__(self) -> None: 8 | model_name = TTS.list_models()[0] 9 | self.tts = TTS(model_name) 10 | 11 | def test(self, text, language='en'): 12 | 13 | tempf = tempfile.NamedTemporaryFile( 14 | delete = False, 15 | suffix = ('.'+'wav'), 16 | ) 17 | 18 | self.tts.tts_to_file(text, speaker=self.tts.speakers[0], language=language, file_path=tempf.name) 19 | 20 | return tempf.name -------------------------------------------------------------------------------- /src/utils/videoio.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import uuid 3 | 4 | import os 5 | 6 | import cv2 7 | 8 | def load_video_to_cv2(input_path): 9 | video_stream = cv2.VideoCapture(input_path) 10 | fps = video_stream.get(cv2.CAP_PROP_FPS) 11 | full_frames = [] 12 | while 1: 13 | still_reading, frame = video_stream.read() 14 | if not still_reading: 15 | video_stream.release() 16 | break 17 | full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 18 | return full_frames 19 | 20 | def save_video_with_watermark(video, audio, save_path, watermark=False): 21 | temp_file = str(uuid.uuid4())+'.mp4' 22 | cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file) 23 | os.system(cmd) 24 | 25 | if watermark is False: 26 | shutil.move(temp_file, save_path) 27 | else: 28 | # watermark 29 | try: 30 | ##### check if stable-diffusion-webui 31 | import webui 32 | from modules import paths 33 | watarmark_path = paths.script_path+"/extensions/SadTalker/docs/sadtalker_logo.png" 34 | except: 35 | # get the root path of sadtalker. 
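# fallback when not running inside stable-diffusion-webui: resolve the watermark logo relative to this file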
36 | dir_path = os.path.dirname(os.path.realpath(__file__)) 37 | watarmark_path = dir_path+"/../../docs/sadtalker_logo.png" 38 | 39 | cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path) 40 | os.system(cmd) 41 | os.remove(temp_file) -------------------------------------------------------------------------------- /test_curl.sh: -------------------------------------------------------------------------------- 1 | curl -X POST -F "audio_file=@/home/redhat/AiModels/SadTalker/examples/driven_audio/RD_Radio31_000.wav" -F "image_file=@/home/redhat/AiModels/SadTalker-dev/yaqing.jpg" http://localhost:8009/create_video -o output_yq_test1.mp4 2 | 3 | 4 | curl -X POST -F "preprocess=\"full\"" -F "audio_file=@/home/redhat/AiModels/SadTalker/examples/driven_audio/RD_Radio31_000.wav" -F "image_file=@/home/redhat/AiModels/SadTalker-dev/yaqing.jpg" http://localhost:8009/create_video -o output_yq_test1.mp4 5 | 6 | curl -X POST -F "preprocess=full" -F "audio_file=@/home/redhat/AiModels/SadTalker/examples/driven_audio/RD_Radio31_000.wav" -F "image_file=@/home/redhat/AiModels/SadTalker-dev/yaqing.jpg" http://localhost:8009/create_video -o output_yq_test1.mp4 7 | 8 | # curl -X POST -F "audio_file=@/home/redhat/AiModels/SadTalker/examples/driven_audio/RD_Radio31_000.wav" -F "image_file=@/home/redhat/AiModels/SadTalker/examples/source_image/art_5.png" http://localhost:8009/create_video -o output_test.mp4 9 | --------------------------------------------------------------------------------
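The curl commands above exercise a local SadTalker HTTP API on port 8009: a multipart POST to /create_video carrying an audio_file part, an image_file part, and an optional preprocess field, with the rendered MP4 returned in the response body. The sketch below is a minimal, hypothetical Python client built only from what these commands show; the endpoint, port, and form-field names come from test_curl.sh, while the use of the requests library, the example paths, and the timeout are assumptions and not part of the repository.

import requests

API_URL = "http://localhost:8009/create_video"  # endpoint and port taken from test_curl.sh

def create_video(audio_path, image_path, out_path, preprocess=None, timeout=600):
    """POST audio + source image as multipart form data and save the returned MP4."""
    with open(audio_path, "rb") as audio_f, open(image_path, "rb") as image_f:
        files = {
            "audio_file": audio_f,  # same form field as the curl calls
            "image_file": image_f,  # same form field as the curl calls
        }
        data = {"preprocess": preprocess} if preprocess else {}  # e.g. "full"
        resp = requests.post(API_URL, files=files, data=data, timeout=timeout)
    resp.raise_for_status()
    with open(out_path, "wb") as out_f:
        out_f.write(resp.content)  # response body is the generated video

if __name__ == "__main__":
    # example inputs shipped with the repository; adjust to your own files
    create_video("examples/driven_audio/RD_Radio31_000.wav",
                 "examples/source_image/art_5.png",
                 "output_test.mp4",
                 preprocess="full")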