├── SadTalker
│   ├── custom
│   │   └── put custom
│   ├── input
│   │   └── put input
│   ├── output
│   │   └── put output
│   ├── src
│   │   ├── face3d
│   │   │   ├── models
│   │   │   │   ├── arcface_torch
│   │   │   │   │   ├── docs
│   │   │   │   │   │   ├── modelzoo.md
│   │   │   │   │   │   ├── eval.md
│   │   │   │   │   │   ├── install.md
│   │   │   │   │   │   └── speed_benchmark.md
│   │   │   │   │   ├── eval
│   │   │   │   │   │   └── __init__.py
│   │   │   │   │   ├── configs
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── speed.py
│   │   │   │   │   │   ├── 3millions.py
│   │   │   │   │   │   ├── 3millions_pfc.py
│   │   │   │   │   │   ├── glint360k_mbf.py
│   │   │   │   │   │   ├── glint360k_r100.py
│   │   │   │   │   │   ├── glint360k_r18.py
│   │   │   │   │   │   ├── glint360k_r34.py
│   │   │   │   │   │   ├── glint360k_r50.py
│   │   │   │   │   │   ├── ms1mv3_mbf.py
│   │   │   │   │   │   ├── ms1mv3_r18.py
│   │   │   │   │   │   ├── ms1mv3_r2060.py
│   │   │   │   │   │   ├── ms1mv3_r34.py
│   │   │   │   │   │   ├── ms1mv3_r50.py
│   │   │   │   │   │   └── base.py
│   │   │   │   │   ├── utils
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── utils_os.py
│   │   │   │   │   │   ├── utils_config.py
│   │   │   │   │   │   ├── utils_logging.py
│   │   │   │   │   │   ├── plot.py
│   │   │   │   │   │   ├── utils_amp.py
│   │   │   │   │   │   └── utils_callbacks.py
│   │   │   │   │   ├── requirement.txt
│   │   │   │   │   ├── run.sh
│   │   │   │   │   ├── backbones
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── mobilefacenet.py
│   │   │   │   │   ├── inference.py
│   │   │   │   │   ├── losses.py
│   │   │   │   │   ├── torch2onnx.py
│   │   │   │   │   └── dataset.py
│   │   │   │   ├── __init__.py
│   │   │   │   └── losses.py
│   │   │   ├── util
│   │   │   │   ├── BBRegressorParam_r.mat
│   │   │   │   ├── __init__.py
│   │   │   │   ├── generate_list.py
│   │   │   │   ├── html.py
│   │   │   │   ├── preprocess.py
│   │   │   │   ├── test_mean_face.txt
│   │   │   │   ├── detect_lm68.py
│   │   │   │   ├── load_mats.py
│   │   │   │   └── nvdiffrast.py
│   │   │   ├── options
│   │   │   │   ├── __init__.py
│   │   │   │   ├── test_options.py
│   │   │   │   ├── inference_options.py
│   │   │   │   └── train_options.py
│   │   │   ├── visualize.py
│   │   │   ├── data
│   │   │   │   ├── image_folder.py
│   │   │   │   ├── template_dataset.py
│   │   │   │   ├── flist_dataset.py
│   │   │   │   ├── __init__.py
│   │   │   │   └── base_dataset.py
│   │   │   └── extract_kp_videos.py
│   │   ├── config
│   │   │   ├── similarity_Lm3D_all.mat
│   │   │   ├── facerender.yaml
│   │   │   ├── facerender_still.yaml
│   │   │   ├── auido2pose.yaml
│   │   │   └── auido2exp.yaml
│   │   ├── utils
│   │   │   ├── safetensor_helper.py
│   │   │   ├── text2speech.py
│   │   │   ├── init_path.py
│   │   │   ├── paste_pic.py
│   │   │   ├── videoio.py
│   │   │   ├── face_enhancer.py
│   │   │   └── audio.py
│   │   ├── facerender
│   │   │   ├── sync_batchnorm
│   │   │   │   ├── __init__.py
│   │   │   │   ├── unittest.py
│   │   │   │   ├── replicate.py
│   │   │   │   └── comm.py
│   │   │   └── modules
│   │   │       ├── mapping.py
│   │   │       └── discriminator.py
│   │   ├── audio2exp_models
│   │   │   ├── audio2exp.py
│   │   │   └── networks.py
│   │   ├── audio2pose_models
│   │   │   ├── res_unet.py
│   │   │   ├── discriminator.py
│   │   │   ├── audio_encoder.py
│   │   │   ├── audio2pose.py
│   │   │   └── networks.py
│   │   └── generate_batch.py
│   ├── requirements.txt
│   ├── requirements3d.txt
│   ├── req.txt
│   ├── scripts
│   │   ├── test.sh
│   │   └── download_models.sh
│   ├── cog.yaml
│   ├── check_ffmpeg.py
│   └── .gitignore
├── examples
│   ├── qq.jpg
│   ├── result.mp4
│   └── workflow.png
├── requirements.txt
├── __init__.py
├── nodes
│   ├── ShowText.py
│   ├── ShowAudio.py
│   ├── LoadRefVideo.py
│   ├── ShowVideo.py
│   └── SadTalkerNode.py
├── README.md
└── web
    └── js
        ├── showVideo.js
        └── showText.js
/SadTalker/custom/put custom:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/SadTalker/input/put input:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/SadTalker/output/put output:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/docs/modelzoo.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/eval/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/utils/utils_os.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/qq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haomole/Comfyui-SadTalker/HEAD/examples/qq.jpg -------------------------------------------------------------------------------- /examples/result.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haomole/Comfyui-SadTalker/HEAD/examples/result.mp4 -------------------------------------------------------------------------------- /examples/workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haomole/Comfyui-SadTalker/HEAD/examples/workflow.png -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/requirement.txt: -------------------------------------------------------------------------------- 1 | tensorboard 2 | easydict 3 | mxnet 4 | onnx 5 | sklearn 6 | -------------------------------------------------------------------------------- /SadTalker/src/config/similarity_Lm3D_all.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haomole/Comfyui-SadTalker/HEAD/SadTalker/src/config/similarity_Lm3D_all.mat -------------------------------------------------------------------------------- /SadTalker/src/face3d/util/BBRegressorParam_r.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haomole/Comfyui-SadTalker/HEAD/SadTalker/src/face3d/util/BBRegressorParam_r.mat -------------------------------------------------------------------------------- /SadTalker/src/face3d/util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | from src.face3d.util import * 3 | 4 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test).""" 2 | -------------------------------------------------------------------------------- 
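The next file, src/utils/safetensor_helper.py, exposes a small helper that filters a flat safetensors state dict by a sub-model prefix so the resulting dict can be handed to that sub-model's load_state_dict(). A minimal usage sketch follows; the checkpoint filename is the one listed in this repository's README, while the 'generator' prefix is an illustrative assumption:

```python
# Illustrative sketch only: the 'generator' prefix is an assumption for demonstration,
# and the import assumes the SadTalker folder is on sys.path (the plugin's __init__.py
# appends it, see the top-level __init__.py later in this dump).
from safetensors.torch import load_file

from src.utils.safetensor_helper import load_x_from_safetensor

# Load the combined SadTalker checkpoint as one flat {key: tensor} dict.
checkpoint = load_file("SadTalker/checkpoints/SadTalker_V0.0.2_256.safetensors")

# Keep only the entries whose key contains 'generator', with the 'generator.' prefix
# stripped, so the sub-dict matches the generator module's own parameter names.
generator_state = load_x_from_safetensor(checkpoint, "generator")
```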
/SadTalker/src/utils/safetensor_helper.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def load_x_from_safetensor(checkpoint, key): 4 | x_generator = {} 5 | for k,v in checkpoint.items(): 6 | if key in k: 7 | x_generator[k.replace(key+'.', '')] = v 8 | return x_generator -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 2 | ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh 3 | -------------------------------------------------------------------------------- /SadTalker/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.23.4 2 | face_alignment==1.3.5 3 | imageio==2.19.3 4 | imageio-ffmpeg==0.4.7 5 | librosa==0.9.2 # 6 | numba 7 | resampy==0.3.1 8 | pydub==0.25.1 9 | scipy==1.10.1 10 | kornia==0.6.8 11 | tqdm 12 | yacs==0.1.8 13 | pyyaml 14 | joblib==1.1.0 15 | scikit-image==0.19.3 16 | basicsr==1.4.2 17 | facexlib==0.3.0 18 | gradio==3.50.2 19 | gfpgan 20 | av 21 | safetensors -------------------------------------------------------------------------------- /SadTalker/requirements3d.txt: -------------------------------------------------------------------------------- 1 | numpy==1.23.4 2 | face_alignment==1.3.5 3 | imageio==2.19.3 4 | imageio-ffmpeg==0.4.7 5 | librosa==0.9.2 # 6 | numba 7 | resampy==0.3.1 8 | pydub==0.25.1 9 | scipy==1.5.3 10 | kornia==0.6.8 11 | tqdm 12 | yacs==0.1.8 13 | pyyaml 14 | joblib==1.1.0 15 | scikit-image==0.19.3 16 | basicsr==1.4.2 17 | facexlib==0.3.0 18 | trimesh==3.9.20 19 | gradio 20 | gfpgan 21 | safetensors -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.23.4 2 | face_alignment==1.3.5 3 | imageio==2.19.3 4 | imageio-ffmpeg==0.4.7 5 | librosa==0.9.2 # 6 | numba 7 | resampy==0.3.1 8 | pydub==0.25.1 9 | scipy==1.10.1 10 | kornia==0.6.8 11 | tqdm 12 | yacs==0.1.8 13 | pyyaml 14 | joblib==1.1.0 15 | scikit-image==0.19.3 16 | basicsr==1.4.2 17 | facexlib==0.3.0 18 | gradio==3.50.2 19 | gfpgan 20 | av 21 | safetensors 22 | moviepy -------------------------------------------------------------------------------- /SadTalker/req.txt: -------------------------------------------------------------------------------- 1 | llvmlite==0.38.1 2 | numpy==1.21.6 3 | face_alignment==1.3.5 4 | imageio==2.19.3 5 | imageio-ffmpeg==0.4.7 6 | librosa==0.10.0.post2 7 | numba==0.55.1 8 | resampy==0.3.1 9 | pydub==0.25.1 10 | scipy==1.10.1 11 | kornia==0.6.8 12 | tqdm 13 | yacs==0.1.8 14 | pyyaml 15 | joblib==1.1.0 16 | scikit-image==0.19.3 17 | basicsr==1.4.2 18 | facexlib==0.3.0 19 | gradio 20 | gfpgan 21 | av 22 | safetensors 23 | -------------------------------------------------------------------------------- /SadTalker/src/facerender/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /SadTalker/src/utils/text2speech.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from TTS.api import TTS 4 | 5 | 6 | class TTSTalker(): 7 | def __init__(self) -> None: 8 | model_name = TTS().list_models()[0] 9 | self.tts = TTS(model_name) 10 | 11 | def test(self, text, language='en'): 12 | 13 | tempf = tempfile.NamedTemporaryFile( 14 | delete = False, 15 | suffix = ('.'+'wav'), 16 | ) 17 | 18 | self.tts.tts_to_file(text, speaker=self.tts.speakers[0], language=language, file_path=tempf.name) 19 | 20 | return tempf.name 21 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/configs/speed.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.loss = "arcface" 7 | config.network = "r50" 8 | config.resume = False 9 | config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 1.0 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 128 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 100 * 10000 20 | config.num_epoch = 30 21 | config.warmup_epoch = -1 22 | config.decay_epoch = [10, 16, 22] 23 | config.val_targets = [] 24 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/configs/3millions.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.loss = "arcface" 7 | config.network = "r50" 8 | config.resume = False 9 | config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 1.0 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 128 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 300 * 10000 20 | config.num_epoch = 30 21 | config.warmup_epoch = -1 22 | config.decay_epoch = [10, 16, 22] 23 | config.val_targets = [] 24 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/configs/3millions_pfc.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.loss = "arcface" 7 | config.network = "r50" 8 | config.resume = False 9 | config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 0.1 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 128 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 300 * 10000 20 | config.num_epoch = 30 21 | config.warmup_epoch = -1 22 | config.decay_epoch = [10, 16, 22] 23 | config.val_targets = [] 24 | 
-------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/utils/utils_config.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os.path as osp 3 | 4 | 5 | def get_config(config_file): 6 | assert config_file.startswith('configs/'), 'config file setting must start with configs/' 7 | temp_config_name = osp.basename(config_file) 8 | temp_module_name = osp.splitext(temp_config_name)[0] 9 | config = importlib.import_module("configs.base") 10 | cfg = config.config 11 | config = importlib.import_module("configs.%s" % temp_module_name) 12 | job_cfg = config.config 13 | cfg.update(job_cfg) 14 | if cfg.output is None: 15 | cfg.output = osp.join('work_dirs', temp_module_name) 16 | return cfg -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/configs/glint360k_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.1 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 2e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/configs/glint360k_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/configs/glint360k_r18.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r18" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 
20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/configs/glint360k_r34.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r34" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/configs/glint360k_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 2e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 30 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 20, 25] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make 
training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r18" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r2060" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 64 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r34" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 
5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/docs/eval.md: -------------------------------------------------------------------------------- 1 | ## Eval on ICCV2021-MFR 2 | 3 | coming soon. 4 | 5 | 6 | ## Eval IJBC 7 | You can eval ijbc with pytorch or onnx. 8 | 9 | 10 | 1. Eval IJBC With Onnx 11 | ```shell 12 | CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50 13 | ``` 14 | 15 | 2. Eval IJBC With Pytorch 16 | ```shell 17 | CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \ 18 | --model-prefix ms1mv3_arcface_r50/backbone.pth \ 19 | --image-path IJB_release/IJBC \ 20 | --result-dir ms1mv3_arcface_r50 \ 21 | --batch-size 128 \ 22 | --job ms1mv3_arcface_r50 \ 23 | --target IJBC \ 24 | --network iresnet50 25 | ``` 26 | 27 | ## Inference 28 | 29 | ```shell 30 | python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50 31 | ``` 32 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/options/test_options.py: -------------------------------------------------------------------------------- 1 | """This script contains the test options for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | from .base_options import BaseOptions 5 | 6 | 7 | class TestOptions(BaseOptions): 8 | """This class includes test options. 9 | 10 | It also includes shared options defined in BaseOptions. 11 | """ 12 | 13 | def initialize(self, parser): 14 | parser = BaseOptions.initialize(self, parser) # define shared options 15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 16 | parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') 17 | parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.') 18 | 19 | # Dropout and Batchnorm has different behavior during training and test. 20 | self.isTrain = False 21 | return parser 22 | -------------------------------------------------------------------------------- /SadTalker/scripts/test.sh: -------------------------------------------------------------------------------- 1 | # ### some test command before commit. 
2 | # python inference.py --preprocess crop --size 256 3 | # python inference.py --preprocess crop --size 512 4 | 5 | # python inference.py --preprocess extcrop --size 256 6 | # python inference.py --preprocess extcrop --size 512 7 | 8 | # python inference.py --preprocess resize --size 256 9 | # python inference.py --preprocess resize --size 512 10 | 11 | # python inference.py --preprocess full --size 256 12 | # python inference.py --preprocess full --size 512 13 | 14 | # python inference.py --preprocess extfull --size 256 15 | # python inference.py --preprocess extfull --size 512 16 | 17 | python inference.py --preprocess full --size 256 --enhancer gfpgan 18 | python inference.py --preprocess full --size 512 --enhancer gfpgan 19 | 20 | python inference.py --preprocess full --size 256 --enhancer gfpgan --still 21 | python inference.py --preprocess full --size 512 --enhancer gfpgan --still 22 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 2 | from .mobilefacenet import get_mbf 3 | 4 | 5 | def get_model(name, **kwargs): 6 | # resnet 7 | if name == "r18": 8 | return iresnet18(False, **kwargs) 9 | elif name == "r34": 10 | return iresnet34(False, **kwargs) 11 | elif name == "r50": 12 | return iresnet50(False, **kwargs) 13 | elif name == "r100": 14 | return iresnet100(False, **kwargs) 15 | elif name == "r200": 16 | return iresnet200(False, **kwargs) 17 | elif name == "r2060": 18 | from .iresnet2060 import iresnet2060 19 | return iresnet2060(False, **kwargs) 20 | elif name == "mbf": 21 | fp16 = kwargs.get("fp16", False) 22 | num_features = kwargs.get("num_features", 512) 23 | return get_mbf(fp16=fp16, num_features=num_features) 24 | else: 25 | raise ValueError() -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os, sys 3 | 4 | SadTalkerPath = os.path.abspath(os.path.join(os.path.dirname(__file__), "SadTalker")) 5 | sys.path.append(SadTalkerPath) 6 | 7 | from .nodes.ShowText import ShowText 8 | from .nodes.ShowVideo import ShowVideo 9 | from .nodes.ShowAudio import ShowAudio 10 | from .nodes.LoadRefVideo import LoadRefVideo 11 | from .nodes.SadTalkerNode import SadTalkerNode 12 | 13 | WEB_DIRECTORY = "./web" 14 | 15 | NODE_CLASS_MAPPINGS = { 16 | "SadTalker": SadTalkerNode, 17 | "ShowVideo": ShowVideo, 18 | "ShowText": ShowText, 19 | "ShowAudio": ShowAudio, 20 | "LoadRefVideo": LoadRefVideo, 21 | } 22 | 23 | NODE_DISPLAY_NAME_MAPPINGS = { 24 | "SadTalker": "🦚 SadTalker", 25 | "ShowVideo": "🎥 Show Video", 26 | "LoadRefVideo": "🎥 Load Ref Video", 27 | "ShowText": "💬 Show Text", 28 | "ShowAudio": "🔊 Show Audio", 29 | } 30 | 31 | 32 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"] 33 | -------------------------------------------------------------------------------- /SadTalker/src/facerender/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /nodes/ShowText.py: -------------------------------------------------------------------------------- 1 | class ShowText: 2 | @classmethod 3 | def INPUT_TYPES(s): 4 | return { 5 | "required": { 6 | "text": ("STRING", {"forceInput": True}), 7 | }, 8 | "hidden": { 9 | "unique_id": "UNIQUE_ID", 10 | "extra_pnginfo": "EXTRA_PNGINFO", 11 | }, 12 | } 13 | 14 | INPUT_IS_LIST = True 15 | RETURN_TYPES = ("STRING",) 16 | FUNCTION = "notify" 17 | OUTPUT_NODE = True 18 | OUTPUT_IS_LIST = (True,) 19 | 20 | CATEGORY = "SadTalker" 21 | 22 | def notify(self, text, unique_id=None, extra_pnginfo=None): 23 | if unique_id and extra_pnginfo and "workflow" in extra_pnginfo[0]: 24 | workflow = extra_pnginfo[0]["workflow"] 25 | node = next( 26 | (x for x in workflow["nodes"] if str(x["id"]) == unique_id[0]), None 27 | ) 28 | if node: 29 | node["widgets_values"] = [text] 30 | return {"ui": {"text": text}, "result": (text,)} 31 | -------------------------------------------------------------------------------- /nodes/ShowAudio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchaudio # type: ignore 3 | import folder_paths # type: ignore 4 | 5 | 6 | class ShowAudio: 7 | SUPPORTED_FORMATS = (".wav", ".mp3", ".ogg", ".flac", ".aiff", ".aif") 8 | 9 | @classmethod 10 | def INPUT_TYPES(s): 11 | input_dir = folder_paths.get_input_directory() 12 | files = [ 13 | f 14 | for f in os.listdir(input_dir) 15 | if ( 16 | os.path.isfile(os.path.join(input_dir, f)) 17 | and f.endswith(ShowAudio.SUPPORTED_FORMATS) 18 | ) 19 | ] 20 | return {"required": {"audio": (sorted(files), {"audio_upload": True})}} 21 | 22 | CATEGORY = "SadTalker" 23 | 24 | RETURN_TYPES = ("AUDIO",) 25 | FUNCTION = "load" 26 | 27 | def load(self, audio): 28 | audio_path = folder_paths.get_annotated_filepath(audio) 29 | waveform, sample_rate = torchaudio.load(audio_path) 30 | audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} 31 | return (audio,) 32 | -------------------------------------------------------------------------------- /nodes/LoadRefVideo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import folder_paths # type: ignore 3 | try: 4 | from moviepy import VideoFileClip # type: ignore 5 | except: 6 | from moviepy.editor import VideoFileClip 7 | 8 | input_path = folder_paths.get_input_directory() 9 | 10 | 11 | class LoadRefVideo: 12 | @classmethod 13 | def INPUT_TYPES(s): 14 | files = [ 15 | f 16 | for f in os.listdir(input_path) 17 | if os.path.isfile(os.path.join(input_path, f)) 18 | and f.split(".")[-1] in ["mp4", "webm", "mkv", "avi"] 19 | ] 20 | return { 21 | "required": { 22 | "video": (files,), 23 | } 24 | } 25 | 26 | CATEGORY = "SadTalker" 27 | RETURN_TYPES = 
("VIDEO", "VIDEOSTRING") 28 | RETURN_NAMES = ("video", "video_path") 29 | OUTPUT_NODE = False 30 | FUNCTION = "load_video" 31 | 32 | def load_video(self, video): 33 | video_path = os.path.join(input_path, video) 34 | video_clip = VideoFileClip(video_path) 35 | return ( 36 | video_clip, 37 | video_path, 38 | ) 39 | -------------------------------------------------------------------------------- /nodes/ShowVideo.py: -------------------------------------------------------------------------------- 1 | class ShowVideo: 2 | @classmethod 3 | def INPUT_TYPES(s): 4 | return { 5 | "required": { 6 | "show_video_path": ("STRING", {"forceInput": True}), 7 | }, 8 | "hidden": { 9 | "unique_id": "UNIQUE_ID", 10 | "extra_pnginfo": "EXTRA_PNGINFO", 11 | }, 12 | } 13 | 14 | OUTPUT_NODE = True 15 | INPUT_IS_LIST = True 16 | RETURN_TYPES = () 17 | OUTPUT_IS_LIST = (True,) 18 | 19 | FUNCTION = "generate" 20 | CATEGORY = "SadTalker" 21 | 22 | def generate(self, show_video_path, unique_id=None, extra_pnginfo=None): 23 | if unique_id and extra_pnginfo and "workflow" in extra_pnginfo[0]: 24 | workflow = extra_pnginfo[0]["workflow"] 25 | node = next( 26 | (x for x in workflow["nodes"] if str(x["id"]) == unique_id[0]), None 27 | ) 28 | if node: 29 | node["widgets_values"] = [show_video_path] 30 | 31 | # 返回视频路径 32 | return { 33 | "ui": {"show_video_path": show_video_path}, 34 | "result": (show_video_path,), 35 | } 36 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | from backbones import get_model 8 | 9 | 10 | @torch.no_grad() 11 | def inference(weight, name, img): 12 | if img is None: 13 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) 14 | else: 15 | img = cv2.imread(img) 16 | img = cv2.resize(img, (112, 112)) 17 | 18 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 19 | img = np.transpose(img, (2, 0, 1)) 20 | img = torch.from_numpy(img).unsqueeze(0).float() 21 | img.div_(255).sub_(0.5).div_(0.5) 22 | net = get_model(name, fp16=False) 23 | net.load_state_dict(torch.load(weight)) 24 | net.eval() 25 | feat = net(img).numpy() 26 | print(feat) 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') 31 | parser.add_argument('--network', type=str, default='r50', help='backbone network') 32 | parser.add_argument('--weight', type=str, default='') 33 | parser.add_argument('--img', type=str, default=None) 34 | args = parser.parse_args() 35 | inference(args.weight, args.network, args.img) 36 | -------------------------------------------------------------------------------- /SadTalker/cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | gpu: true 3 | cuda: "11.3" 4 | python_version: "3.8" 5 | system_packages: 6 | - "ffmpeg" 7 | - "libgl1-mesa-glx" 8 | - "libglib2.0-0" 9 | python_packages: 10 | - "torch==1.12.1" 11 | - "torchvision==0.13.1" 12 | - "torchaudio==0.12.1" 13 | - "joblib==1.1.0" 14 | - "scikit-image==0.19.3" 15 | - "basicsr==1.4.2" 16 | - "facexlib==0.3.0" 17 | - "resampy==0.3.1" 18 | - "pydub==0.25.1" 19 | - "scipy==1.10.1" 20 | - "kornia==0.6.8" 21 | - "face_alignment==1.3.5" 22 | - "imageio==2.19.3" 23 | - "imageio-ffmpeg==0.4.7" 24 | - "librosa==0.9.2" # 25 | - "tqdm==4.65.0" 26 | - "yacs==0.1.8" 27 | - "gfpgan==1.3.8" 28 | - 
"dlib-bin==19.24.1" 29 | - "av==10.0.0" 30 | - "trimesh==3.9.20" 31 | run: 32 | - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/s3fd-619a316812.pth" "https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth" 33 | - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/2DFAN4-cd938726ad.zip" "https://www.adrianbulat.com/downloads/python-fan/2DFAN4-cd938726ad.zip" 34 | 35 | predict: "predict.py:Predictor" 36 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/options/inference_options.py: -------------------------------------------------------------------------------- 1 | from face3d.options.base_options import BaseOptions 2 | 3 | 4 | class InferenceOptions(BaseOptions): 5 | """This class includes test options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) # define shared options 12 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 13 | parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]') 14 | 15 | parser.add_argument('--input_dir', type=str, help='the folder of the input files') 16 | parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files') 17 | parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients') 18 | parser.add_argument('--save_split_files', action='store_true', help='save split files or not') 19 | parser.add_argument('--inference_batch_size', type=int, default=8) 20 | 21 | # Dropout and Batchnorm has different behavior during training and test. 
22 | self.isTrain = False 23 | return parser 24 | -------------------------------------------------------------------------------- /SadTalker/src/config/facerender.yaml: -------------------------------------------------------------------------------- 1 | model_params: 2 | common_params: 3 | num_kp: 15 4 | image_channel: 3 5 | feature_channel: 32 6 | estimate_jacobian: False # True 7 | kp_detector_params: 8 | temperature: 0.1 9 | block_expansion: 32 10 | max_features: 1024 11 | scale_factor: 0.25 # 0.25 12 | num_blocks: 5 13 | reshape_channel: 16384 # 16384 = 1024 * 16 14 | reshape_depth: 16 15 | he_estimator_params: 16 | block_expansion: 64 17 | max_features: 2048 18 | num_bins: 66 19 | generator_params: 20 | block_expansion: 64 21 | max_features: 512 22 | num_down_blocks: 2 23 | reshape_channel: 32 24 | reshape_depth: 16 # 512 = 32 * 16 25 | num_resblocks: 6 26 | estimate_occlusion_map: True 27 | dense_motion_params: 28 | block_expansion: 32 29 | max_features: 1024 30 | num_blocks: 5 31 | reshape_depth: 16 32 | compress: 4 33 | discriminator_params: 34 | scales: [1] 35 | block_expansion: 32 36 | max_features: 512 37 | num_blocks: 4 38 | sn: True 39 | mapping_params: 40 | coeff_nc: 70 41 | descriptor_nc: 1024 42 | layer: 3 43 | num_kp: 15 44 | num_bins: 66 45 | 46 | -------------------------------------------------------------------------------- /SadTalker/src/config/facerender_still.yaml: -------------------------------------------------------------------------------- 1 | model_params: 2 | common_params: 3 | num_kp: 15 4 | image_channel: 3 5 | feature_channel: 32 6 | estimate_jacobian: False # True 7 | kp_detector_params: 8 | temperature: 0.1 9 | block_expansion: 32 10 | max_features: 1024 11 | scale_factor: 0.25 # 0.25 12 | num_blocks: 5 13 | reshape_channel: 16384 # 16384 = 1024 * 16 14 | reshape_depth: 16 15 | he_estimator_params: 16 | block_expansion: 64 17 | max_features: 2048 18 | num_bins: 66 19 | generator_params: 20 | block_expansion: 64 21 | max_features: 512 22 | num_down_blocks: 2 23 | reshape_channel: 32 24 | reshape_depth: 16 # 512 = 32 * 16 25 | num_resblocks: 6 26 | estimate_occlusion_map: True 27 | dense_motion_params: 28 | block_expansion: 32 29 | max_features: 1024 30 | num_blocks: 5 31 | reshape_depth: 16 32 | compress: 4 33 | discriminator_params: 34 | scales: [1] 35 | block_expansion: 32 36 | max_features: 512 37 | num_blocks: 4 38 | sn: True 39 | mapping_params: 40 | coeff_nc: 73 41 | descriptor_nc: 1024 42 | layer: 3 43 | num_kp: 15 44 | num_bins: 66 45 | 46 | -------------------------------------------------------------------------------- /SadTalker/src/config/auido2pose.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt 3 | EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt 4 | TRAIN_BATCH_SIZE: 64 5 | EVAL_BATCH_SIZE: 1 6 | EXP: True 7 | EXP_DIM: 64 8 | FRAME_LEN: 32 9 | COEFF_LEN: 73 10 | NUM_CLASSES: 46 11 | AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav 12 | COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb 13 | DEBUG: True 14 | 15 | 16 | MODEL: 17 | AUDIOENCODER: 18 | LEAKY_RELU: True 19 | NORM: 'IN' 20 | DISCRIMINATOR: 21 | LEAKY_RELU: False 22 | INPUT_CHANNELS: 6 23 | CVAE: 24 | AUDIO_EMB_IN_SIZE: 512 25 | AUDIO_EMB_OUT_SIZE: 6 26 | SEQ_LEN: 32 27 | LATENT_SIZE: 64 28 | ENCODER_LAYER_SIZES: 
[192, 128] 29 | DECODER_LAYER_SIZES: [128, 192] 30 | 31 | 32 | TRAIN: 33 | MAX_EPOCH: 150 34 | GENERATOR: 35 | LR: 1.0e-4 36 | DISCRIMINATOR: 37 | LR: 1.0e-4 38 | LOSS: 39 | LAMBDA_REG: 1 40 | LAMBDA_LANDMARKS: 0 41 | LAMBDA_VERTICES: 0 42 | LAMBDA_GAN_MOTION: 0.7 43 | LAMBDA_GAN_COEFF: 0 44 | LAMBDA_KL: 1 45 | 46 | TAG: 47 | NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder 48 | 49 | 50 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/utils/utils_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value 8 | """ 9 | 10 | def __init__(self): 11 | self.val = None 12 | self.avg = None 13 | self.sum = None 14 | self.count = None 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0 19 | self.avg = 0 20 | self.sum = 0 21 | self.count = 0 22 | 23 | def update(self, val, n=1): 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | def init_logging(rank, models_root): 31 | if rank == 0: 32 | log_root = logging.getLogger() 33 | log_root.setLevel(logging.INFO) 34 | formatter = logging.Formatter("Training: %(asctime)s-%(message)s") 35 | handler_file = logging.FileHandler(os.path.join(models_root, "training.log")) 36 | handler_stream = logging.StreamHandler(sys.stdout) 37 | handler_file.setFormatter(formatter) 38 | handler_stream.setFormatter(formatter) 39 | log_root.addHandler(handler_file) 40 | log_root.addHandler(handler_stream) 41 | log_root.info('rank_id: %d' % rank) 42 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def get_loss(name): 6 | if name == "cosface": 7 | return CosFace() 8 | elif name == "arcface": 9 | return ArcFace() 10 | else: 11 | raise ValueError() 12 | 13 | 14 | class CosFace(nn.Module): 15 | def __init__(self, s=64.0, m=0.40): 16 | super(CosFace, self).__init__() 17 | self.s = s 18 | self.m = m 19 | 20 | def forward(self, cosine, label): 21 | index = torch.where(label != -1)[0] 22 | m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) 23 | m_hot.scatter_(1, label[index, None], self.m) 24 | cosine[index] -= m_hot 25 | ret = cosine * self.s 26 | return ret 27 | 28 | 29 | class ArcFace(nn.Module): 30 | def __init__(self, s=64.0, m=0.5): 31 | super(ArcFace, self).__init__() 32 | self.s = s 33 | self.m = m 34 | 35 | def forward(self, cosine: torch.Tensor, label): 36 | index = torch.where(label != -1)[0] 37 | m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) 38 | m_hot.scatter_(1, label[index, None], self.m) 39 | cosine.acos_() 40 | cosine[index] += m_hot 41 | cosine.cos_().mul_(self.s) 42 | return cosine 43 | -------------------------------------------------------------------------------- /SadTalker/src/audio2exp_models/audio2exp.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class Audio2Exp(nn.Module): 7 | def __init__(self, netG, cfg, device, prepare_training_loss=False): 8 | super(Audio2Exp, self).__init__() 9 | self.cfg = cfg 10 | self.device = device 11 | self.netG = 
netG.to(device) 12 | 13 | def test(self, batch): 14 | 15 | mel_input = batch['indiv_mels'] # bs T 1 80 16 16 | bs = mel_input.shape[0] 17 | T = mel_input.shape[1] 18 | 19 | exp_coeff_pred = [] 20 | 21 | for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames 22 | 23 | current_mel_input = mel_input[:,i:i+10] 24 | 25 | #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64 26 | ref = batch['ref'][:, :, :64][:, i:i+10] 27 | ratio = batch['ratio_gt'][:, i:i+10] #bs T 28 | 29 | audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16 30 | 31 | curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64 32 | 33 | exp_coeff_pred += [curr_exp_coeff_pred] 34 | 35 | # BS x T x 64 36 | results_dict = { 37 | 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1) 38 | } 39 | return results_dict 40 | 41 | 42 | -------------------------------------------------------------------------------- /SadTalker/src/config/auido2exp.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt 3 | EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt 4 | TRAIN_BATCH_SIZE: 32 5 | EVAL_BATCH_SIZE: 32 6 | EXP: True 7 | EXP_DIM: 64 8 | FRAME_LEN: 32 9 | COEFF_LEN: 73 10 | NUM_CLASSES: 46 11 | AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav 12 | COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm 13 | LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb 14 | DEBUG: True 15 | NUM_REPEATS: 2 16 | T: 40 17 | 18 | 19 | MODEL: 20 | FRAMEWORK: V2 21 | AUDIOENCODER: 22 | LEAKY_RELU: True 23 | NORM: 'IN' 24 | DISCRIMINATOR: 25 | LEAKY_RELU: False 26 | INPUT_CHANNELS: 6 27 | CVAE: 28 | AUDIO_EMB_IN_SIZE: 512 29 | AUDIO_EMB_OUT_SIZE: 128 30 | SEQ_LEN: 32 31 | LATENT_SIZE: 256 32 | ENCODER_LAYER_SIZES: [192, 1024] 33 | DECODER_LAYER_SIZES: [1024, 192] 34 | 35 | 36 | TRAIN: 37 | MAX_EPOCH: 300 38 | GENERATOR: 39 | LR: 2.0e-5 40 | DISCRIMINATOR: 41 | LR: 1.0e-5 42 | LOSS: 43 | W_FEAT: 0 44 | W_COEFF_EXP: 2 45 | W_LM: 1.0e-2 46 | W_LM_MOUTH: 0 47 | W_REG: 0 48 | W_SYNC: 0 49 | W_COLOR: 0 50 | W_EXPRESSION: 0 51 | W_LIPREADING: 0.01 52 | W_LIPREADING_VV: 0 53 | W_EYE_BLINK: 4 54 | 55 | TAG: 56 | NAME: small_dataset 57 | 58 | 59 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/util/generate_list.py: -------------------------------------------------------------------------------- 1 | """This script is to generate training list files for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import os 5 | 6 | # save path to training data 7 | def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''): 8 | save_path = os.path.join(save_folder, mode) 9 | if not os.path.isdir(save_path): 10 | os.makedirs(save_path) 11 | with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd: 12 | fd.writelines([i + '\n' for i in lms_list]) 13 | 14 | with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd: 15 | fd.writelines([i + '\n' for i in imgs_list]) 16 | 17 | with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd: 18 | fd.writelines([i + '\n' for i in msks_list]) 19 | 20 | # check if the path is valid 21 | def check_list(rlms_list, rimgs_list, rmsks_list): 22 | lms_list, imgs_list, msks_list = [], [], [] 23 | for i in range(len(rlms_list)): 24 | flag = 'false' 25 | 
lm_path = rlms_list[i] 26 | im_path = rimgs_list[i] 27 | msk_path = rmsks_list[i] 28 | if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path): 29 | flag = 'true' 30 | lms_list.append(rlms_list[i]) 31 | imgs_list.append(rimgs_list[i]) 32 | msks_list.append(rmsks_list[i]) 33 | print(i, rlms_list[i], flag) 34 | return lms_list, imgs_list, msks_list 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Comfyui-SadTalker 2 | 3 | ### 版权声明(Copyright) 4 | 5 | - [SadTalker](https://github.com/OpenTalker/SadTalker) 6 | 7 | ### 注意(Notice) 8 | 9 | - 首次开发插件,节点可能存在问题,如果遇到了困难,请提交 Issues.(This is the first time to develop a plugin. There may be problems with the node. If you run into difficulties, please submit Issues.) 10 | - 使用ComfyUi-aki-v1.3.7z(新版本未适配)(Use the ComfyUi-aki-v1.3.7z package; newer versions are not adapted.) 11 | 12 | ### 例子(Examples) 13 | 14 | ![Workflow](./examples/workflow.png) 15 | 16 | ### 可能遇到的问题(Problems That You May Have) 17 | 18 | - 如果报错 audio 不存在或者未定义,请升级 comfyui.(If you get the error that audio does not exist or is undefined, upgrade comfyui.) 19 | - 如果 video 显示异常,请刷新浏览器.(Reload your browser if the video displays abnormally.) 20 | - 如果报错:FileNotFoundError: [Errno 2] No such file or directory: 'F:\\\\ComfyUI-aki-v1.3\\output\\6b5735c4-ef50-4933-a34d-ccbad3b5936d.mp4',请下载ffmpeg并且将ffmpeg/bin目录加入环境变量中.(If you get an error: FileNotFoundError: [Errno 2] No such file or directory: 'F:\\\\ComfyUI-aki-v1.3\\output\\6b5735c4-ef50-4933-a34d-ccbad3b5936d.mp4', download ffmpeg and add the ffmpeg/bin directory to the environment variable.) 21 | 22 | ### 安装(Install) 23 | 24 | 1. ...custom_nodes\Comfyui-SadTalker\SadTalker\checkpoints\ 25 | - SadTalker_V0.0.2_256.safetensors(691MB) 26 | - SadTalker_V0.0.2_512.safetensors(691MB) 27 | - mapping_00109-model.pth.tar(148MB) 28 | - mapping_00229-model.pth.tar(148MB) 29 | 2. 
...(comfyui root)\gfpgan\weights\ 30 | - alignment_WFLW_4HG.pth(184MB) 31 | - detection_Resnet50_Final.pth(104MB) 32 | - GFPGANv1.4.pth(332MB) 33 | - parsing_parsenet.pth(81.3MB) 34 | 35 | ### 联系我(Contact Me) 36 | 37 | 38 | -------------------------------------------------------------------------------- /SadTalker/src/facerender/modules/mapping.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class MappingNet(nn.Module): 9 | def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins): 10 | super( MappingNet, self).__init__() 11 | 12 | self.layer = layer 13 | nonlinearity = nn.LeakyReLU(0.1) 14 | 15 | self.first = nn.Sequential( 16 | torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)) 17 | 18 | for i in range(layer): 19 | net = nn.Sequential(nonlinearity, 20 | torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3)) 21 | setattr(self, 'encoder' + str(i), net) 22 | 23 | self.pooling = nn.AdaptiveAvgPool1d(1) 24 | self.output_nc = descriptor_nc 25 | 26 | self.fc_roll = nn.Linear(descriptor_nc, num_bins) 27 | self.fc_pitch = nn.Linear(descriptor_nc, num_bins) 28 | self.fc_yaw = nn.Linear(descriptor_nc, num_bins) 29 | self.fc_t = nn.Linear(descriptor_nc, 3) 30 | self.fc_exp = nn.Linear(descriptor_nc, 3*num_kp) 31 | 32 | def forward(self, input_3dmm): 33 | out = self.first(input_3dmm) 34 | for i in range(self.layer): 35 | model = getattr(self, 'encoder' + str(i)) 36 | out = model(out) + out[:,:,3:-3] 37 | out = self.pooling(out) 38 | out = out.view(out.shape[0], -1) 39 | #print('out:', out.shape) 40 | 41 | yaw = self.fc_yaw(out) 42 | pitch = self.fc_pitch(out) 43 | roll = self.fc_roll(out) 44 | t = self.fc_t(out) 45 | exp = self.fc_exp(out) 46 | 47 | return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/docs/install.md: -------------------------------------------------------------------------------- 1 | ## v1.8.0 2 | ### Linux and Windows 3 | ```shell 4 | # CUDA 11.0 5 | pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html 6 | 7 | # CUDA 10.2 8 | pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 9 | 10 | # CPU only 11 | pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html 12 | 13 | ``` 14 | 15 | 16 | ## v1.7.1 17 | ### Linux and Windows 18 | ```shell 19 | # CUDA 11.0 20 | pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 21 | 22 | # CUDA 10.2 23 | pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 24 | 25 | # CUDA 10.1 26 | pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 27 | 28 | # CUDA 9.2 29 | pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 30 | 31 | # CPU only 32 | pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 33 | ``` 34 | 35 | 36 | ## v1.6.0 37 | 38 | ### Linux and Windows 39 | ```shell 40 | # CUDA 
10.2 41 | pip install torch==1.6.0 torchvision==0.7.0 42 | 43 | # CUDA 10.1 44 | pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html 45 | 46 | # CUDA 9.2 47 | pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html 48 | 49 | # CPU only 50 | pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html 51 | ``` -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/configs/base.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = "ms1mv3_arcface_r50" 12 | 13 | config.dataset = "ms1m-retinaface-t1" 14 | config.embedding_size = 512 15 | config.sample_rate = 1 16 | config.fp16 = False 17 | config.momentum = 0.9 18 | config.weight_decay = 5e-4 19 | config.batch_size = 128 20 | config.lr = 0.1 # batch size is 512 21 | 22 | if config.dataset == "emore": 23 | config.rec = "/train_tmp/faces_emore" 24 | config.num_classes = 85742 25 | config.num_image = 5822653 26 | config.num_epoch = 16 27 | config.warmup_epoch = -1 28 | config.decay_epoch = [8, 14, ] 29 | config.val_targets = ["lfw", ] 30 | 31 | elif config.dataset == "ms1m-retinaface-t1": 32 | config.rec = "/train_tmp/ms1m-retinaface-t1" 33 | config.num_classes = 93431 34 | config.num_image = 5179510 35 | config.num_epoch = 25 36 | config.warmup_epoch = -1 37 | config.decay_epoch = [11, 17, 22] 38 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 39 | 40 | elif config.dataset == "glint360k": 41 | config.rec = "/train_tmp/glint360k" 42 | config.num_classes = 360232 43 | config.num_image = 17091657 44 | config.num_epoch = 20 45 | config.warmup_epoch = -1 46 | config.decay_epoch = [8, 12, 15, 18] 47 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 48 | 49 | elif config.dataset == "webface": 50 | config.rec = "/train_tmp/faces_webface_112x112" 51 | config.num_classes = 10572 52 | config.num_image = "forget" 53 | config.num_epoch = 34 54 | config.warmup_epoch = -1 55 | config.decay_epoch = [20, 28, 32] 56 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 57 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/visualize.py: -------------------------------------------------------------------------------- 1 | # check the sync of 3dmm feature and the audio 2 | import cv2 3 | import numpy as np 4 | from src.face3d.models.bfm import ParametricFaceModel 5 | from src.face3d.models.facerecon_model import FaceReconModel 6 | import torch 7 | import subprocess, platform 8 | import scipy.io as scio 9 | from tqdm import tqdm 10 | 11 | # draft 12 | def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, exp_dim=64): 13 | 14 | coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm'] 15 | 16 | coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm'] 17 | 18 | coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257 19 | 20 | coeff_full[:, 80:144] = coeff_pred[:, 0:64] 21 | coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim translation 22 | coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation 23 | 24 | tmp_video_path = 
'/tmp/face3dtmp.mp4' 25 | 26 | facemodel = FaceReconModel(args) 27 | 28 | video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224)) 29 | 30 | for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'): 31 | cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device) 32 | 33 | facemodel.forward(cur_coeff_full, device) 34 | 35 | predicted_landmark = facemodel.pred_lm # TODO. 36 | predicted_landmark = predicted_landmark.cpu().numpy().squeeze() 37 | 38 | rendered_img = facemodel.pred_face 39 | rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0) 40 | out_img = rendered_img[:, :, :3].astype(np.uint8) 41 | 42 | video.write(np.uint8(out_img[:,:,::-1])) 43 | 44 | video.release() 45 | 46 | command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path) 47 | subprocess.call(command, shell=platform.system() != 'Windows') 48 | 49 | -------------------------------------------------------------------------------- /web/js/showVideo.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | import { api } from "../../../scripts/api.js"; 3 | import { ComfyWidgets } from "../../../scripts/widgets.js"; 4 | 5 | let video; 6 | 7 | const ext = { 8 | name: "SadTalker.ShowVideo", 9 | 10 | async beforeRegisterNodeDef (nodeType, nodeData, app) { 11 | if (nodeData.name === "ShowVideo") { 12 | function populate (videoPath) { 13 | 14 | const filePath = videoPath[0] 15 | if (video && filePath) { 16 | const fileName = filePath.split('\\').pop(); 17 | video.src = api.apiURL(`/view?filename=${fileName}`); 18 | video.play(); 19 | } 20 | } 21 | 22 | // When the node is executed we will be sent the input video path, display this in the widget 23 | const onExecuted = nodeType.prototype.onExecuted; 24 | nodeType.prototype.onExecuted = function (message) { 25 | onExecuted?.apply(this, arguments); 26 | populate.call(this, message.show_video_path); 27 | }; 28 | 29 | const onConfigure = nodeType.prototype.onConfigure; 30 | nodeType.prototype.onConfigure = function () { 31 | onConfigure?.apply(this, arguments); 32 | if (this.widgets_values?.length) { 33 | populate.call(this, this.widgets_values); 34 | } 35 | }; 36 | } 37 | }, 38 | 39 | loadedGraphNode (node, app) { 40 | if (node.type == "ShowVideo") { 41 | if (video) return; 42 | const container = document.createElement("div"); 43 | container.style.background = "rgba(0,0,0,0.25)"; 44 | container.style.textAlign = "center"; 45 | video = document.createElement("video") 46 | video.style.height = video.style.width = "100%"; 47 | video.controls = true 48 | video.classList.add("comfy-video") 49 | video.setAttribute("name", "media") 50 | container.replaceChildren(video); 51 | node.addDOMWidget("video", "VIDEO", container) 52 | } 53 | } 54 | }; 55 | 56 | 57 | 58 | app.registerExtension(ext); 59 | -------------------------------------------------------------------------------- /web/js/showText.js: -------------------------------------------------------------------------------- 1 | import { app } from "../../../scripts/app.js"; 2 | import { ComfyWidgets } from "../../../scripts/widgets.js"; 3 | 4 | const ext = { 5 | name: "SadTalker.ShowText", 6 | 7 | async beforeRegisterNodeDef (nodeType, nodeData, app) { 8 | if (nodeData.name === "ShowText") { 9 | function populate (text) { 10 | if (this.widgets) { 11 | for (let i = 1; i < this.widgets.length; i++) { 12 | this.widgets[i].onRemove?.(); 13 | } 14 | 
this.widgets.length = 1; 15 | } 16 | 17 | const v = [...text]; 18 | if (!v[0]) { 19 | v.shift(); 20 | } 21 | for (const list of v) { 22 | const w = ComfyWidgets["STRING"](this, "text", ["STRING", { multiline: true }], app).widget; 23 | w.inputEl.readOnly = true; 24 | w.inputEl.style.opacity = 0.6; 25 | w.value = list; 26 | } 27 | 28 | requestAnimationFrame(() => { 29 | const sz = this.computeSize(); 30 | if (sz[0] < this.size[0]) { 31 | sz[0] = this.size[0]; 32 | } 33 | if (sz[1] < this.size[1]) { 34 | sz[1] = this.size[1]; 35 | } 36 | this.onResize?.(sz); 37 | app.graph.setDirtyCanvas(true, false); 38 | }); 39 | } 40 | 41 | // When the node is executed we will be sent the input text, display this in the widget 42 | const onExecuted = nodeType.prototype.onExecuted; 43 | nodeType.prototype.onExecuted = function (message) { 44 | onExecuted?.apply(this, arguments); 45 | populate.call(this, message.text); 46 | }; 47 | 48 | const onConfigure = nodeType.prototype.onConfigure; 49 | nodeType.prototype.onConfigure = function () { 50 | onConfigure?.apply(this, arguments); 51 | if (this.widgets_values?.length) { 52 | populate.call(this, this.widgets_values); 53 | } 54 | }; 55 | } 56 | 57 | // This fires for every node definition so only log once 58 | // delete ext.beforeRegisterNodeDef; 59 | }, 60 | }; 61 | 62 | app.registerExtension(ext); 63 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 
5 | """ 6 | import numpy as np 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | import os 11 | import os.path 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 16 | '.tif', '.TIF', '.tiff', '.TIFF', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir, max_dataset_size=float("inf")): 25 | images = [] 26 | assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir, followlinks=True)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | return images[:min(max_dataset_size, len(images))] 34 | 35 | 36 | def default_loader(path): 37 | return Image.open(path).convert('RGB') 38 | 39 | 40 | class ImageFolder(data.Dataset): 41 | 42 | def __init__(self, root, transform=None, return_paths=False, 43 | loader=default_loader): 44 | imgs = make_dataset(root) 45 | if len(imgs) == 0: 46 | raise(RuntimeError("Found 0 images in: " + root + "\n" 47 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 48 | 49 | self.root = root 50 | self.imgs = imgs 51 | self.transform = transform 52 | self.return_paths = return_paths 53 | self.loader = loader 54 | 55 | def __getitem__(self, index): 56 | path = self.imgs[index] 57 | img = self.loader(path) 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | if self.return_paths: 61 | return img, path 62 | else: 63 | return img 64 | 65 | def __len__(self): 66 | return len(self.imgs) 67 | -------------------------------------------------------------------------------- /SadTalker/src/audio2pose_models/res_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from src.audio2pose_models.networks import ResidualConv, Upsample 4 | 5 | 6 | class ResUnet(nn.Module): 7 | def __init__(self, channel=1, filters=[32, 64, 128, 256]): 8 | super(ResUnet, self).__init__() 9 | 10 | self.input_layer = nn.Sequential( 11 | nn.Conv2d(channel, filters[0], kernel_size=3, padding=1), 12 | nn.BatchNorm2d(filters[0]), 13 | nn.ReLU(), 14 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), 15 | ) 16 | self.input_skip = nn.Sequential( 17 | nn.Conv2d(channel, filters[0], kernel_size=3, padding=1) 18 | ) 19 | 20 | self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1) 21 | self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1) 22 | 23 | self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1) 24 | 25 | self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1)) 26 | self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1) 27 | 28 | self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1)) 29 | self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1) 30 | 31 | self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1)) 32 | self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1) 33 | 34 | self.output_layer = nn.Sequential( 35 | nn.Conv2d(filters[0], 1, 1, 1), 36 | nn.Sigmoid(), 37 | ) 38 | 39 | def forward(self, x): 40 | # Encode 41 | x1 = self.input_layer(x) + self.input_skip(x) 42 | 
x2 = self.residual_conv_1(x1) 43 | x3 = self.residual_conv_2(x2) 44 | # Bridge 45 | x4 = self.bridge(x3) 46 | 47 | # Decode 48 | x4 = self.upsample_1(x4) 49 | x5 = torch.cat([x4, x3], dim=1) 50 | 51 | x6 = self.up_residual_conv1(x5) 52 | 53 | x6 = self.upsample_2(x6) 54 | x7 = torch.cat([x6, x2], dim=1) 55 | 56 | x8 = self.up_residual_conv2(x7) 57 | 58 | x8 = self.upsample_3(x8) 59 | x9 = torch.cat([x8, x1], dim=1) 60 | 61 | x10 = self.up_residual_conv3(x9) 62 | 63 | output = self.output_layer(x10) 64 | 65 | return output -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/utils/plot.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | from pathlib import Path 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pandas as pd 9 | from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap 10 | from prettytable import PrettyTable 11 | from sklearn.metrics import roc_curve, auc 12 | 13 | image_path = "/data/anxiang/IJB_release/IJBC" 14 | files = [ 15 | "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy" 16 | ] 17 | 18 | 19 | def read_template_pair_list(path): 20 | pairs = pd.read_csv(path, sep=' ', header=None).values 21 | t1 = pairs[:, 0].astype(np.int) 22 | t2 = pairs[:, 1].astype(np.int) 23 | label = pairs[:, 2].astype(np.int) 24 | return t1, t2, label 25 | 26 | 27 | p1, p2, label = read_template_pair_list( 28 | os.path.join('%s/meta' % image_path, 29 | '%s_template_pair_label.txt' % 'ijbc')) 30 | 31 | methods = [] 32 | scores = [] 33 | for file in files: 34 | methods.append(file.split('/')[-2]) 35 | scores.append(np.load(file)) 36 | 37 | methods = np.array(methods) 38 | scores = dict(zip(methods, scores)) 39 | colours = dict( 40 | zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) 41 | x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] 42 | tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) 43 | fig = plt.figure() 44 | for method in methods: 45 | fpr, tpr, _ = roc_curve(label, scores[method]) 46 | roc_auc = auc(fpr, tpr) 47 | fpr = np.flipud(fpr) 48 | tpr = np.flipud(tpr) # select largest tpr at same fpr 49 | plt.plot(fpr, 50 | tpr, 51 | color=colours[method], 52 | lw=1, 53 | label=('[%s (AUC = %0.4f %%)]' % 54 | (method.split('-')[-1], roc_auc * 100))) 55 | tpr_fpr_row = [] 56 | tpr_fpr_row.append("%s-%s" % (method, "IJBC")) 57 | for fpr_iter in np.arange(len(x_labels)): 58 | _, min_index = min( 59 | list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) 60 | tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) 61 | tpr_fpr_table.add_row(tpr_fpr_row) 62 | plt.xlim([10 ** -6, 0.1]) 63 | plt.ylim([0.3, 1.0]) 64 | plt.grid(linestyle='--', linewidth=1) 65 | plt.xticks(x_labels) 66 | plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) 67 | plt.xscale('log') 68 | plt.xlabel('False Positive Rate') 69 | plt.ylabel('True Positive Rate') 70 | plt.title('ROC on IJB') 71 | plt.legend(loc="lower right") 72 | print(tpr_fpr_table) 73 | -------------------------------------------------------------------------------- /SadTalker/check_ffmpeg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | from pathlib import Path 5 | import winreg 6 | 7 | def check_ffmpeg_path(): 8 | path_list = os.environ['Path'].split(';') 9 | ffmpeg_found = False 10 | 11 | for path in path_list: 12 | if 
'ffmpeg' in path.lower() and 'bin' in path.lower(): 13 | ffmpeg_found = True 14 | print("FFmpeg already installed, skipping...") 15 | break 16 | 17 | return ffmpeg_found 18 | 19 | def add_ffmpeg_path_to_user_variable(): 20 | ffmpeg_bin_path = Path('.\\ffmpeg\\bin') 21 | if ffmpeg_bin_path.is_dir(): 22 | abs_path = str(ffmpeg_bin_path.resolve()) 23 | 24 | try: 25 | key = winreg.OpenKey( 26 | winreg.HKEY_CURRENT_USER, 27 | r"Environment", 28 | 0, 29 | winreg.KEY_READ | winreg.KEY_WRITE 30 | ) 31 | 32 | try: 33 | current_path, _ = winreg.QueryValueEx(key, "Path") 34 | if abs_path not in current_path: 35 | new_path = f"{current_path};{abs_path}" 36 | winreg.SetValueEx(key, "Path", 0, winreg.REG_EXPAND_SZ, new_path) 37 | print(f"Added FFmpeg path to user variable 'Path': {abs_path}") 38 | else: 39 | print("FFmpeg path already exists in the user variable 'Path'.") 40 | finally: 41 | winreg.CloseKey(key) 42 | except WindowsError: 43 | print("Error: Unable to modify user variable 'Path'.") 44 | sys.exit(1) 45 | 46 | else: 47 | print("Error: ffmpeg\\bin folder not found in the current path.") 48 | sys.exit(1) 49 | 50 | def get_default_browser(): 51 | browser = "" 52 | try: 53 | cmd = r'reg query HKEY_CURRENT_USER\Software\Microsoft\Windows\Shell\Associations\UrlAssociations\http\UserChoice /v ProgId' 54 | output = subprocess.check_output(cmd, shell=True).decode() 55 | browser = output.split()[-1].split('\\')[-1] 56 | except Exception as e: 57 | print(f"Error: {e}") 58 | browser = "Unknown" 59 | return browser 60 | 61 | def main(): 62 | if not check_ffmpeg_path(): 63 | add_ffmpeg_path_to_user_variable() 64 | default_browser = get_default_browser() 65 | if not "chrome" in default_browser.lower() and not "edge" in default_browser.lower() and not "firefox" in default_browser.lower(): 66 | print("默认浏览器不符合要求,可能会影响使用,请更换为Chrome或Edge浏览器") 67 | 68 | if __name__ == "__main__": 69 | main() -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/torch2onnx.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | import torch 4 | 5 | 6 | def convert_onnx(net, path_module, output, opset=11, simplify=False): 7 | assert isinstance(net, torch.nn.Module) 8 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) 9 | img = img.astype(np.float64) 10 | img = (img / 255.0 - 0.5) / 0.5 # torch style norm 11 | img = img.transpose((2, 0, 1)) 12 | img = torch.from_numpy(img).unsqueeze(0).float() 13 | 14 | weight = torch.load(path_module) 15 | net.load_state_dict(weight) 16 | net.eval() 17 | torch.onnx.export( 18 | net, 19 | img, 20 | output, 21 | keep_initializers_as_inputs=False, 22 | verbose=False, 23 | opset_version=opset, 24 | ) 25 | model = onnx.load(output) 26 | graph = model.graph 27 | graph.input[0].type.tensor_type.shape.dim[0].dim_param = "None" 28 | if simplify: 29 | from onnxsim import simplify 30 | 31 | model, check = simplify(model) 32 | assert check, "Simplified ONNX model could not be validated" 33 | onnx.save(model, output) 34 | 35 | 36 | if __name__ == "__main__": 37 | import os 38 | import argparse 39 | from backbones import get_model 40 | 41 | parser = argparse.ArgumentParser(description="ArcFace PyTorch to onnx") 42 | parser.add_argument("input", type=str, help="input backbone.pth file or path") 43 | parser.add_argument("--output", type=str, default=None, help="output onnx path") 44 | parser.add_argument("--network", type=str, default=None, 
help="backbone network") 45 | parser.add_argument("--simplify", type=bool, default=False, help="onnx simplify") 46 | args = parser.parse_args() 47 | input_file = args.input 48 | if os.path.isdir(input_file): 49 | input_file = os.path.join(input_file, "backbone.pth") 50 | assert os.path.exists(input_file) 51 | model_name = os.path.basename(os.path.dirname(input_file)).lower() 52 | params = model_name.split("_") 53 | if len(params) >= 3 and params[1] in ("arcface", "cosface"): 54 | if args.network is None: 55 | args.network = params[2] 56 | assert args.network is not None 57 | print(args) 58 | backbone_onnx = get_model(args.network, dropout=0) 59 | 60 | output_path = args.output 61 | if output_path is None: 62 | output_path = os.path.join(os.path.dirname(__file__), "onnx") 63 | if not os.path.exists(output_path): 64 | os.makedirs(output_path) 65 | assert os.path.isdir(output_path) 66 | output_file = os.path.join(output_path, "%s.onnx" % model_name) 67 | convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify) 68 | -------------------------------------------------------------------------------- /SadTalker/src/utils/init_path.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | def init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'): 5 | 6 | if old_version: 7 | #### load all the checkpoint of `pth` 8 | sadtalker_paths = { 9 | 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'), 10 | 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'), 11 | 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'), 12 | 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'), 13 | 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth') 14 | } 15 | 16 | use_safetensor = False 17 | elif len(glob.glob(os.path.join(checkpoint_dir, '*.safetensors'))): 18 | print('using safetensor as default') 19 | sadtalker_paths = { 20 | "checkpoint":os.path.join(checkpoint_dir, 'SadTalker_V0.0.2_'+str(size)+'.safetensors'), 21 | } 22 | use_safetensor = True 23 | else: 24 | print("WARNING: The new version of the model will be updated by safetensor, you may need to download it mannully. 
We run the old version of the checkpoint this time!") 25 | use_safetensor = False 26 | 27 | sadtalker_paths = { 28 | 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'), 29 | 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'), 30 | 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'), 31 | 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'), 32 | 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth') 33 | } 34 | 35 | sadtalker_paths['dir_of_BFM_fitting'] = os.path.join(config_dir) # , 'BFM_Fitting' 36 | sadtalker_paths['audio2pose_yaml_path'] = os.path.join(config_dir, 'auido2pose.yaml') 37 | sadtalker_paths['audio2exp_yaml_path'] = os.path.join(config_dir, 'auido2exp.yaml') 38 | sadtalker_paths['use_safetensor'] = use_safetensor # os.path.join(config_dir, 'auido2exp.yaml') 39 | 40 | if 'full' in preprocess: 41 | sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar') 42 | sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender_still.yaml') 43 | else: 44 | sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar') 45 | sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml') 46 | 47 | return sadtalker_paths -------------------------------------------------------------------------------- /SadTalker/src/utils/paste_pic.py: -------------------------------------------------------------------------------- 1 | import cv2, os 2 | import numpy as np 3 | from tqdm import tqdm 4 | import uuid 5 | 6 | from src.utils.videoio import save_video_with_watermark 7 | 8 | 9 | def paste_pic( 10 | video_path, 11 | pic_path, 12 | crop_info, 13 | new_audio_path, 14 | full_video_path, 15 | extended_crop=False, 16 | ): 17 | 18 | if not os.path.isfile(pic_path): 19 | raise ValueError("pic_path must be a valid path to video/image file") 20 | elif pic_path.split(".")[-1] in ["jpg", "png", "jpeg"]: 21 | # loader for first frame 22 | full_img = cv2.imread(pic_path) 23 | else: 24 | # loader for videos 25 | video_stream = cv2.VideoCapture(pic_path) 26 | fps = video_stream.get(cv2.CAP_PROP_FPS) 27 | full_frames = [] 28 | while 1: 29 | still_reading, frame = video_stream.read() 30 | if not still_reading: 31 | video_stream.release() 32 | break 33 | break 34 | full_img = frame 35 | frame_h = full_img.shape[0] 36 | frame_w = full_img.shape[1] 37 | 38 | video_stream = cv2.VideoCapture(video_path) 39 | fps = video_stream.get(cv2.CAP_PROP_FPS) 40 | crop_frames = [] 41 | while 1: 42 | still_reading, frame = video_stream.read() 43 | if not still_reading: 44 | video_stream.release() 45 | break 46 | crop_frames.append(frame) 47 | 48 | if len(crop_info) != 3: 49 | print("you didn't crop the image") 50 | return 51 | else: 52 | r_w, r_h = crop_info[0] 53 | clx, cly, crx, cry = crop_info[1] 54 | lx, ly, rx, ry = crop_info[2] 55 | lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry) 56 | # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx 57 | # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx 58 | 59 | if extended_crop: 60 | oy1, oy2, ox1, ox2 = cly, cry, clx, crx 61 | else: 62 | oy1, oy2, ox1, ox2 = cly + ly, cly + ry, clx + lx, clx + rx 63 | 64 | tmp_path = str(uuid.uuid4()) + ".mp4" 65 | out_tmp = cv2.VideoWriter( 66 | tmp_path, cv2.VideoWriter_fourcc(*"MP4V"), fps, (frame_w, frame_h) 67 | ) 68 | for crop_frame in tqdm(crop_frames, "seamlessClone:"): 69 | p = 
cv2.resize(crop_frame.astype(np.uint8), (ox2 - ox1, oy2 - oy1)) 70 | 71 | mask = 255 * np.ones(p.shape, p.dtype) 72 | location = ((ox1 + ox2) // 2, (oy1 + oy2) // 2) 73 | gen_img = cv2.seamlessClone(p, full_img, mask, location, cv2.NORMAL_CLONE) 74 | out_tmp.write(gen_img) 75 | 76 | out_tmp.release() 77 | 78 | save_video_with_watermark( 79 | tmp_path, new_audio_path, full_video_path, watermark=False 80 | ) 81 | os.remove(tmp_path) 82 | -------------------------------------------------------------------------------- /SadTalker/src/utils/videoio.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import uuid 3 | import os 4 | import cv2 # type: ignore 5 | import logging 6 | import folder_paths # type: ignore 7 | import subprocess 8 | output_directory = folder_paths.get_output_directory() 9 | 10 | 11 | def load_video_to_cv2(input_path): 12 | video_stream = cv2.VideoCapture(input_path) 13 | fps = video_stream.get(cv2.CAP_PROP_FPS) 14 | full_frames = [] 15 | while True: 16 | still_reading, frame = video_stream.read() 17 | if not still_reading: 18 | video_stream.release() 19 | break 20 | full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 21 | return full_frames, fps 22 | 23 | 24 | def save_video_with_watermark(video, audio, save_path, watermark=False, target_fps=30): 25 | temp_video = os.path.join(output_directory, str(uuid.uuid4()) + "_video.mp4") 26 | 27 | cmd = (r'ffmpeg -y -hide_banner -loglevel error -i "%s" -r %d -qscale 0 "%s"') % ( 28 | video, 29 | target_fps, 30 | temp_video, 31 | ) 32 | os.system(cmd) 33 | 34 | temp_file = os.path.join(output_directory, str(uuid.uuid4()) + ".mp4") 35 | cmd = ( 36 | r'ffmpeg -i "%s" -i "%s" ' 37 | r"-c:v libx264 -b:v 5000k -maxrate 5000k -bufsize 10000k " 38 | r'-c:a aac -b:a 122k -ac 2 -ar 44100 -strict experimental "%s" ' 39 | ) % (temp_video, audio, temp_file) 40 | 41 | os.system(cmd) 42 | 43 | try: 44 | if not watermark: 45 | # If no watermark is needed, copy the temporary file straight to the save path 46 | shutil.copy(temp_file, save_path) 47 | else: 48 | # Get the path of the watermark image 49 | dir_path = os.path.dirname(os.path.realpath(__file__)) 50 | watermark_path = os.path.join(dir_path, "../../docs/sadtalker_logo.png") 51 | 52 | # Build the FFmpeg command that overlays the watermark 53 | cmd = [ 54 | "ffmpeg", 55 | "-y", 56 | "-hide_banner", 57 | "-loglevel", 58 | "error", 59 | "-i", 60 | temp_file, 61 | "-i", 62 | watermark_path, 63 | "-filter_complex", 64 | "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10", 65 | "-c:v", 66 | "copy", 67 | "-c:a", 68 | "aac", 69 | "-b:a", 70 | "128k", 71 | "-ar", 72 | "16000", 73 | "-shortest", 74 | save_path, 75 | ] 76 | 77 | subprocess.run(cmd) # cmd is a list here, so run it via subprocess instead of os.system 78 | except Exception as e: 79 | logging.error("Error: %s", e) 80 | 81 | # Make sure the temporary video files are removed 82 | if os.path.exists(temp_video): 83 | os.remove(temp_video) 84 | if os.path.exists(temp_file): 85 | os.remove(temp_file) 86 | -------------------------------------------------------------------------------- /SadTalker/scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | mkdir ./checkpoints 2 | 3 | # legacy download links 4 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth 5 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth 6 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth 7 | # wget -nc 
https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar 8 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat 9 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth 10 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar 11 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar 12 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip 13 | # unzip -n ./checkpoints/hub.zip -d ./checkpoints/ 14 | 15 | 16 | #### download the new links. 17 | wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar 18 | wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar 19 | wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_256.safetensors -O ./checkpoints/SadTalker_V0.0.2_256.safetensors 20 | wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_512.safetensors -O ./checkpoints/SadTalker_V0.0.2_512.safetensors 21 | 22 | 23 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip 24 | # unzip -n ./checkpoints/BFM_Fitting.zip -d ./checkpoints/ 25 | 26 | ### enhancer 27 | mkdir -p ./gfpgan/weights 28 | wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth -O ./gfpgan/weights/alignment_WFLW_4HG.pth 29 | wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth -O ./gfpgan/weights/detection_Resnet50_Final.pth 30 | wget -nc https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -O ./gfpgan/weights/GFPGANv1.4.pth 31 | wget -nc https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth -O ./gfpgan/weights/parsing_parsenet.pth 32 | 33 | -------------------------------------------------------------------------------- /SadTalker/src/audio2pose_models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | class ConvNormRelu(nn.Module): 6 | def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False, 7 | kernel_size=None, stride=None, padding=None, norm='BN', leaky=False): 8 | super().__init__() 9 | if kernel_size is None: 10 | if downsample: 11 | kernel_size, stride, padding = 4, 2, 1 12 | else: 13 | kernel_size, stride, padding = 3, 1, 1 14 | 15 | if conv_type == '2d': 16 | self.conv = nn.Conv2d( 17 | in_channels, 18 | out_channels, 19 | kernel_size, 20 | stride, 21 | padding, 22 | bias=False, 23 | ) 24 | if norm == 'BN': 25 | self.norm = nn.BatchNorm2d(out_channels) 26 | elif norm == 'IN': 27 | self.norm = nn.InstanceNorm2d(out_channels) 28 | else: 29 | raise NotImplementedError 30 | elif conv_type == '1d': 31 | self.conv = nn.Conv1d( 32 | in_channels, 33 | out_channels, 34 | kernel_size, 35 | stride, 36 | 
padding, 37 | bias=False, 38 | ) 39 | if norm == 'BN': 40 | self.norm = nn.BatchNorm1d(out_channels) 41 | elif norm == 'IN': 42 | self.norm = nn.InstanceNorm1d(out_channels) 43 | else: 44 | raise NotImplementedError 45 | nn.init.kaiming_normal_(self.conv.weight) 46 | 47 | self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True) 48 | 49 | def forward(self, x): 50 | x = self.conv(x) 51 | if isinstance(self.norm, nn.InstanceNorm1d): 52 | x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C] 53 | else: 54 | x = self.norm(x) 55 | x = self.act(x) 56 | return x 57 | 58 | 59 | class PoseSequenceDiscriminator(nn.Module): 60 | def __init__(self, cfg): 61 | super().__init__() 62 | self.cfg = cfg 63 | leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU 64 | 65 | self.seq = nn.Sequential( 66 | ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64 67 | ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32 68 | ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16 69 | nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16 70 | ) 71 | 72 | def forward(self, x): 73 | x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2) 74 | x = self.seq(x) 75 | x = x.squeeze(1) 76 | return x -------------------------------------------------------------------------------- /SadTalker/src/audio2pose_models/audio_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | class Conv2d(nn.Module): 6 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | self.conv_block = nn.Sequential( 9 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 10 | nn.BatchNorm2d(cout) 11 | ) 12 | self.act = nn.ReLU() 13 | self.residual = residual 14 | 15 | def forward(self, x): 16 | out = self.conv_block(x) 17 | if self.residual: 18 | out += x 19 | return self.act(out) 20 | 21 | class AudioEncoder(nn.Module): 22 | def __init__(self, wav2lip_checkpoint, device): 23 | super(AudioEncoder, self).__init__() 24 | 25 | self.audio_encoder = nn.Sequential( 26 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 27 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 28 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 29 | 30 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 31 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 32 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 33 | 34 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1), 35 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 36 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 37 | 38 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 39 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 40 | 41 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0), 42 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 43 | 44 | #### load the pre-trained audio_encoder, we do not need to load wav2lip model here. 
45 | # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict'] 46 | # state_dict = self.audio_encoder.state_dict() 47 | 48 | # for k,v in wav2lip_state_dict.items(): 49 | # if 'audio_encoder' in k: 50 | # state_dict[k.replace('module.audio_encoder.', '')] = v 51 | # self.audio_encoder.load_state_dict(state_dict) 52 | 53 | 54 | def forward(self, audio_sequences): 55 | # audio_sequences = (B, T, 1, 80, 16) 56 | B = audio_sequences.size(0) 57 | 58 | audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) 59 | 60 | audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1 61 | dim = audio_embedding.shape[1] 62 | audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1)) 63 | 64 | return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512 65 | -------------------------------------------------------------------------------- /SadTalker/src/facerender/modules/discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | from facerender.modules.util import kp2gaussian 4 | import torch 5 | 6 | 7 | class DownBlock2d(nn.Module): 8 | """ 9 | Simple block for processing video (encoder). 10 | """ 11 | 12 | def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False): 13 | super(DownBlock2d, self).__init__() 14 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) 15 | 16 | if sn: 17 | self.conv = nn.utils.spectral_norm(self.conv) 18 | 19 | if norm: 20 | self.norm = nn.InstanceNorm2d(out_features, affine=True) 21 | else: 22 | self.norm = None 23 | self.pool = pool 24 | 25 | def forward(self, x): 26 | out = x 27 | out = self.conv(out) 28 | if self.norm: 29 | out = self.norm(out) 30 | out = F.leaky_relu(out, 0.2) 31 | if self.pool: 32 | out = F.avg_pool2d(out, (2, 2)) 33 | return out 34 | 35 | 36 | class Discriminator(nn.Module): 37 | """ 38 | Discriminator similar to Pix2Pix 39 | """ 40 | 41 | def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512, 42 | sn=False, **kwargs): 43 | super(Discriminator, self).__init__() 44 | 45 | down_blocks = [] 46 | for i in range(num_blocks): 47 | down_blocks.append( 48 | DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)), 49 | min(max_features, block_expansion * (2 ** (i + 1))), 50 | norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn)) 51 | 52 | self.down_blocks = nn.ModuleList(down_blocks) 53 | self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) 54 | if sn: 55 | self.conv = nn.utils.spectral_norm(self.conv) 56 | 57 | def forward(self, x): 58 | feature_maps = [] 59 | out = x 60 | 61 | for down_block in self.down_blocks: 62 | feature_maps.append(down_block(out)) 63 | out = feature_maps[-1] 64 | prediction_map = self.conv(out) 65 | 66 | return feature_maps, prediction_map 67 | 68 | 69 | class MultiScaleDiscriminator(nn.Module): 70 | """ 71 | Multi-scale (scale) discriminator 72 | """ 73 | 74 | def __init__(self, scales=(), **kwargs): 75 | super(MultiScaleDiscriminator, self).__init__() 76 | self.scales = scales 77 | discs = {} 78 | for scale in scales: 79 | discs[str(scale).replace('.', '-')] = Discriminator(**kwargs) 80 | self.discs = nn.ModuleDict(discs) 81 | 82 | def forward(self, x): 83 | out_dict = {} 84 | for scale, disc in self.discs.items(): 85 | scale = 
str(scale).replace('-', '.') 86 | key = 'prediction_' + scale 87 | feature_maps, prediction_map = disc(x[key]) 88 | out_dict['feature_maps_' + scale] = feature_maps 89 | out_dict['prediction_map_' + scale] = prediction_map 90 | return out_dict 91 | -------------------------------------------------------------------------------- /SadTalker/src/audio2exp_models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | class Conv2d(nn.Module): 6 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | self.conv_block = nn.Sequential( 9 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 10 | nn.BatchNorm2d(cout) 11 | ) 12 | self.act = nn.ReLU() 13 | self.residual = residual 14 | self.use_act = use_act 15 | 16 | def forward(self, x): 17 | out = self.conv_block(x) 18 | if self.residual: 19 | out += x 20 | 21 | if self.use_act: 22 | return self.act(out) 23 | else: 24 | return out 25 | 26 | class SimpleWrapperV2(nn.Module): 27 | def __init__(self) -> None: 28 | super().__init__() 29 | self.audio_encoder = nn.Sequential( 30 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 31 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 32 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 33 | 34 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 35 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 36 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 37 | 38 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1), 39 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 40 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 41 | 42 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 43 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 44 | 45 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0), 46 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0), 47 | ) 48 | 49 | #### load the pre-trained audio_encoder 50 | #self.audio_encoder = self.audio_encoder.to(device) 51 | ''' 52 | wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict'] 53 | state_dict = self.audio_encoder.state_dict() 54 | 55 | for k,v in wav2lip_state_dict.items(): 56 | if 'audio_encoder' in k: 57 | print('init:', k) 58 | state_dict[k.replace('module.audio_encoder.', '')] = v 59 | self.audio_encoder.load_state_dict(state_dict) 60 | ''' 61 | 62 | self.mapping1 = nn.Linear(512+64+1, 64) 63 | #self.mapping2 = nn.Linear(30, 64) 64 | #nn.init.constant_(self.mapping1.weight, 0.) 65 | nn.init.constant_(self.mapping1.bias, 0.) 66 | 67 | def forward(self, x, ref, ratio): 68 | x = self.audio_encoder(x).view(x.size(0), -1) 69 | ref_reshape = ref.reshape(x.size(0), -1) 70 | ratio = ratio.reshape(x.size(0), -1) 71 | 72 | y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1)) 73 | out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial 74 | return out 75 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 
2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from src.face3d.models.base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called DatasetNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "face3d.models." + model_name + "_model" 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace('_', '') + 'model' 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() \ 38 | and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 43 | exit(0) 44 | 45 | return model 46 | 47 | 48 | def get_option_setter(model_name): 49 | """Return the static method of the model class.""" 50 | model_class = find_model_using_name(model_name) 51 | return model_class.modify_commandline_options 52 | 53 | 54 | def create_model(opt): 55 | """Create a model given the option. 56 | 57 | This function warps the class CustomDatasetDataLoader. 58 | This is the main interface between this package and 'train.py'/'test.py' 59 | 60 | Example: 61 | >>> from models import create_model 62 | >>> model = create_model(opt) 63 | """ 64 | model = find_model_using_name(opt.model) 65 | instance = model(opt) 66 | print("model [%s] was created" % type(instance).__name__) 67 | return instance 68 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 8 | 9 | It consists of functions such as (add a text header to the HTML file), 10 | (add a row of images to the HTML file), and (save the HTML to the disk). 
11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 12 | """ 13 | 14 | def __init__(self, web_dir, title, refresh=0): 15 | """Initialize the HTML classes 16 | 17 | Parameters: 18 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: 32 | with self.doc.head: 33 | meta(http_equiv="refresh", content=str(refresh)) 34 | 35 | def get_image_dir(self): 36 | """Return the directory that stores images""" 37 | return self.img_dir 38 | 39 | def add_header(self, text): 40 | """Insert a header to the HTML file 41 | 42 | Parameters: 43 | text (str) -- the header text 44 | """ 45 | with self.doc: 46 | h3(text) 47 | 48 | def add_images(self, ims, txts, links, width=400): 49 | """add images to the HTML file 50 | 51 | Parameters: 52 | ims (str list) -- a list of image paths 53 | txts (str list) -- a list of image names shown on the website 54 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 55 | """ 56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 57 | self.doc.add(self.t) 58 | with self.t: 59 | with tr(): 60 | for im, txt, link in zip(ims, txts, links): 61 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 62 | with p(): 63 | with a(href=os.path.join('images', link)): 64 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 65 | br() 66 | p(txt) 67 | 68 | def save(self): 69 | """save the current content to the HMTL file""" 70 | html_file = '%s/index.html' % self.web_dir 71 | f = open(html_file, 'wt') 72 | f.write(self.doc.render()) 73 | f.close() 74 | 75 | 76 | if __name__ == '__main__': # we show an example usage here. 77 | html = HTML('web/', 'test_html') 78 | html.add_header('hello world') 79 | 80 | ims, txts, links = [], [], [] 81 | for n in range(4): 82 | ims.append('image_%d.png' % n) 83 | txts.append('text_%d' % n) 84 | links.append('image_%d.png' % n) 85 | html.add_images(ims, txts, links) 86 | html.save() 87 | -------------------------------------------------------------------------------- /SadTalker/src/facerender/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 
36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/utils/utils_amp.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import torch 4 | 5 | if torch.__version__ < '1.9': 6 | Iterable = torch._six.container_abcs.Iterable 7 | else: 8 | import collections 9 | 10 | Iterable = collections.abc.Iterable 11 | from torch.cuda.amp import GradScaler 12 | 13 | 14 | class _MultiDeviceReplicator(object): 15 | """ 16 | Lazily serves copies of a tensor to requested devices. Copies are cached per-device. 
17 | """ 18 | 19 | def __init__(self, master_tensor: torch.Tensor) -> None: 20 | assert master_tensor.is_cuda 21 | self.master = master_tensor 22 | self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} 23 | 24 | def get(self, device) -> torch.Tensor: 25 | retval = self._per_device_tensors.get(device, None) 26 | if retval is None: 27 | retval = self.master.to(device=device, non_blocking=True, copy=True) 28 | self._per_device_tensors[device] = retval 29 | return retval 30 | 31 | 32 | class MaxClipGradScaler(GradScaler): 33 | def __init__(self, init_scale, max_scale: float, growth_interval=100): 34 | GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval) 35 | self.max_scale = max_scale 36 | 37 | def scale_clip(self): 38 | if self.get_scale() == self.max_scale: 39 | self.set_growth_factor(1) 40 | elif self.get_scale() < self.max_scale: 41 | self.set_growth_factor(2) 42 | elif self.get_scale() > self.max_scale: 43 | self._scale.fill_(self.max_scale) 44 | self.set_growth_factor(1) 45 | 46 | def scale(self, outputs): 47 | """ 48 | Multiplies ('scales') a tensor or list of tensors by the scale factor. 49 | 50 | Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned 51 | unmodified. 52 | 53 | Arguments: 54 | outputs (Tensor or iterable of Tensors): Outputs to scale. 55 | """ 56 | if not self._enabled: 57 | return outputs 58 | self.scale_clip() 59 | # Short-circuit for the common case. 60 | if isinstance(outputs, torch.Tensor): 61 | assert outputs.is_cuda 62 | if self._scale is None: 63 | self._lazy_init_scale_growth_tracker(outputs.device) 64 | assert self._scale is not None 65 | return outputs * self._scale.to(device=outputs.device, non_blocking=True) 66 | 67 | # Invoke the more complex machinery only if we're treating multiple outputs. 68 | stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale 69 | 70 | def apply_scale(val): 71 | if isinstance(val, torch.Tensor): 72 | assert val.is_cuda 73 | if len(stash) == 0: 74 | if self._scale is None: 75 | self._lazy_init_scale_growth_tracker(val.device) 76 | assert self._scale is not None 77 | stash.append(_MultiDeviceReplicator(self._scale)) 78 | return val * stash[0].get(val.device) 79 | elif isinstance(val, Iterable): 80 | iterable = map(apply_scale, val) 81 | if isinstance(val, list) or isinstance(val, tuple): 82 | return type(val)(iterable) 83 | else: 84 | return iterable 85 | else: 86 | raise ValueError("outputs must be a Tensor or an iterable of Tensors") 87 | 88 | return apply_scale(outputs) 89 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/data/template_dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset class template 2 | 3 | This module provides a template for users to implement custom datasets. 4 | You can specify '--dataset_mode template' to use this dataset. 5 | The class name should be consistent with both the filename and its dataset_mode option. 6 | The filename should be _dataset.py 7 | The class name should be Dataset.py 8 | You need to implement the following functions: 9 | -- : Add dataset-specific options and rewrite default values for existing options. 10 | -- <__init__>: Initialize this dataset class. 11 | -- <__getitem__>: Return a data point and its metadata information. 12 | -- <__len__>: Return the number of images. 
13 | """ 14 | from data.base_dataset import BaseDataset, get_transform 15 | # from data.image_folder import make_dataset 16 | # from PIL import Image 17 | 18 | 19 | class TemplateDataset(BaseDataset): 20 | """A template dataset class for you to implement custom datasets.""" 21 | @staticmethod 22 | def modify_commandline_options(parser, is_train): 23 | """Add new dataset-specific options, and rewrite default values for existing options. 24 | 25 | Parameters: 26 | parser -- original option parser 27 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 28 | 29 | Returns: 30 | the modified parser. 31 | """ 32 | parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option') 33 | parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values 34 | return parser 35 | 36 | def __init__(self, opt): 37 | """Initialize this dataset class. 38 | 39 | Parameters: 40 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 41 | 42 | A few things can be done here. 43 | - save the options (have been done in BaseDataset) 44 | - get image paths and meta information of the dataset. 45 | - define the image transformation. 46 | """ 47 | # save the option and dataset root 48 | BaseDataset.__init__(self, opt) 49 | # get the image paths of your dataset; 50 | self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root 51 | # define the default transform function. You can use ; You can also define your custom transform function 52 | self.transform = get_transform(opt) 53 | 54 | def __getitem__(self, index): 55 | """Return a data point and its metadata information. 56 | 57 | Parameters: 58 | index -- a random integer for data indexing 59 | 60 | Returns: 61 | a dictionary of data with their names. It usually contains the data itself and its metadata information. 62 | 63 | Step 1: get a random image path: e.g., path = self.image_paths[index] 64 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). 65 | Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image) 66 | Step 4: return a data point as a dictionary. 67 | """ 68 | path = 'temp' # needs to be a string 69 | data_A = None # needs to be a tensor 70 | data_B = None # needs to be a tensor 71 | return {'data_A': data_A, 'data_B': data_B, 'path': path} 72 | 73 | def __len__(self): 74 | """Return the total number of images.""" 75 | return len(self.image_paths) 76 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/options/train_options.py: -------------------------------------------------------------------------------- 1 | """This script contains the training options for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | from .base_options import BaseOptions 5 | from util import util 6 | 7 | class TrainOptions(BaseOptions): 8 | """This class includes training options. 9 | 10 | It also includes shared options defined in BaseOptions. 
11 | """ 12 | 13 | def initialize(self, parser): 14 | parser = BaseOptions.initialize(self, parser) 15 | # dataset parameters 16 | # for train 17 | parser.add_argument('--data_root', type=str, default='./', help='dataset root') 18 | parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set') 19 | parser.add_argument('--batch_size', type=int, default=32) 20 | parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]') 21 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 22 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 23 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 24 | parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]') 25 | parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation') 26 | 27 | # for val 28 | parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set') 29 | parser.add_argument('--batch_size_val', type=int, default=32) 30 | 31 | 32 | # visualization parameters 33 | parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen') 34 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 35 | 36 | # network saving and loading parameters 37 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 38 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') 39 | parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq') 40 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') 41 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 42 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 43 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 44 | parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint') 45 | 46 | # training parameters 47 | parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate') 48 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') 49 | parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. 
[linear | step | plateau | cosine]') 50 | parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches') 51 | 52 | self.isTrain = True 53 | return parser 54 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/util/preprocess.py: -------------------------------------------------------------------------------- 1 | """This script contains the image preprocessing code for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import numpy as np 5 | from scipy.io import loadmat 6 | from PIL import Image 7 | import cv2 8 | import os 9 | from skimage import transform as trans 10 | import torch 11 | import warnings 12 | 13 | warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) 14 | warnings.filterwarnings("ignore", category=FutureWarning) 15 | 16 | 17 | # calculating least square problem for image alignment 18 | def POS(xp, x): 19 | npts = xp.shape[1] 20 | 21 | A = np.zeros([2 * npts, 8]) 22 | 23 | A[0 : 2 * npts - 1 : 2, 0:3] = x.transpose() 24 | A[0 : 2 * npts - 1 : 2, 3] = 1 25 | 26 | A[1 : 2 * npts : 2, 4:7] = x.transpose() 27 | A[1 : 2 * npts : 2, 7] = 1 28 | 29 | b = np.reshape(xp.transpose(), [2 * npts, 1]) 30 | 31 | k, _, _, _ = np.linalg.lstsq(A, b) 32 | 33 | R1 = k[0:3] 34 | R2 = k[4:7] 35 | sTx = k[3] 36 | sTy = k[7] 37 | s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2 38 | t = np.stack([sTx, sTy], axis=0) 39 | 40 | return t, s 41 | 42 | 43 | # resize and crop images for face reconstruction 44 | def resize_n_crop_img(img, lm, t, s, target_size=224.0, mask=None): 45 | w0, h0 = img.size 46 | w = (w0 * s).astype(np.int32) 47 | h = (h0 * s).astype(np.int32) 48 | left = (w / 2 - target_size / 2 + float((t[0] - w0 / 2) * s)).astype(np.int32) 49 | right = left + target_size 50 | up = (h / 2 - target_size / 2 + float((h0 / 2 - t[1]) * s)).astype(np.int32) 51 | below = up + target_size 52 | 53 | img = img.resize((w, h), resample=Image.BICUBIC) 54 | img = img.crop((left, up, right, below)) 55 | 56 | if mask is not None: 57 | mask = mask.resize((w, h), resample=Image.BICUBIC) 58 | mask = mask.crop((left, up, right, below)) 59 | 60 | lm = np.stack([lm[:, 0] - t[0] + w0 / 2, lm[:, 1] - t[1] + h0 / 2], axis=1) * s 61 | lm = lm - np.reshape( 62 | np.array([(w / 2 - target_size / 2), (h / 2 - target_size / 2)]), [1, 2] 63 | ) 64 | 65 | return img, lm, mask 66 | 67 | 68 | # utils for face reconstruction 69 | def extract_5p(lm): 70 | lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 71 | lm5p = np.stack( 72 | [ 73 | lm[lm_idx[0], :], 74 | np.mean(lm[lm_idx[[1, 2]], :], 0), 75 | np.mean(lm[lm_idx[[3, 4]], :], 0), 76 | lm[lm_idx[5], :], 77 | lm[lm_idx[6], :], 78 | ], 79 | axis=0, 80 | ) 81 | lm5p = lm5p[[1, 2, 0, 3, 4], :] 82 | return lm5p 83 | 84 | 85 | # utils for face reconstruction 86 | def align_img(img, lm, lm3D, mask=None, target_size=224.0, rescale_factor=102.0): 87 | """ 88 | Return: 89 | transparams --numpy.array (raw_W, raw_H, scale, tx, ty) 90 | img_new --PIL.Image (target_size, target_size, 3) 91 | lm_new --numpy.array (68, 2), y direction is opposite to v direction 92 | mask_new --PIL.Image (target_size, target_size) 93 | 94 | Parameters: 95 | img --PIL.Image (raw_H, raw_W, 3) 96 | lm --numpy.array (68, 2), y direction is opposite to v direction 97 | lm3D --numpy.array (5, 3) 98 | mask --PIL.Image (raw_H, raw_W, 3) 99 | """ 100 | 101 | w0, h0 = img.size 102 | if lm.shape[0] != 5: 103 | lm5p = extract_5p(lm) 104 | else: 105 | lm5p = lm 106 | 107 | # calculate translation and 
scale factors using 5 facial landmarks and standard landmarks of a 3D face 108 | t, s = POS(lm5p.transpose(), lm3D.transpose()) 109 | s = rescale_factor / s 110 | 111 | # processing the image 112 | img_new, lm_new, mask_new = resize_n_crop_img( 113 | img, lm, t, s, target_size=target_size, mask=mask 114 | ) 115 | trans_params = np.array([w0, h0, s, t[0].item(), t[1].item()], dtype=np.float32) 116 | 117 | return trans_params, img_new, lm_new, mask_new 118 | -------------------------------------------------------------------------------- /SadTalker/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | examples/results/* 163 | gfpgan/* 164 | checkpoints/* 165 | assets/* 166 | results/* 167 | Dockerfile 168 | start_docker.sh 169 | start.sh 170 | 171 | checkpoints 172 | 173 | # Mac 174 | .DS_Store 175 | -------------------------------------------------------------------------------- /SadTalker/src/audio2pose_models/audio2pose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from src.audio2pose_models.cvae import CVAE 4 | from src.audio2pose_models.discriminator import PoseSequenceDiscriminator 5 | from src.audio2pose_models.audio_encoder import AudioEncoder 6 | 7 | class Audio2Pose(nn.Module): 8 | def __init__(self, cfg, wav2lip_checkpoint, device='cuda'): 9 | super().__init__() 10 | self.cfg = cfg 11 | self.seq_len = cfg.MODEL.CVAE.SEQ_LEN 12 | self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE 13 | self.device = device 14 | 15 | self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device) 16 | self.audio_encoder.eval() 17 | for param in self.audio_encoder.parameters(): 18 | param.requires_grad = False 19 | 20 | self.netG = CVAE(cfg) 21 | self.netD_motion = PoseSequenceDiscriminator(cfg) 22 | 23 | 24 | def forward(self, x): 25 | 26 | batch = {} 27 | coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73 28 | batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6 29 | batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6 30 | batch['class'] = x['class'].squeeze(0).cuda() # bs 31 | indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16 32 | 33 | # forward 34 | audio_emb_list = [] 35 | audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512 36 | batch['audio_emb'] = audio_emb 37 | batch = self.netG(batch) 38 | 39 | pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6 40 | pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6 41 | pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6 42 | 43 | batch['pose_pred'] = pose_pred 44 | batch['pose_gt'] = pose_gt 45 | 46 | return batch 47 | 48 | def test(self, x): 49 | 50 | batch = {} 51 | ref = x['ref'] #bs 1 70 52 | batch['ref'] = x['ref'][:,0,-6:] 53 | batch['class'] = x['class'] 54 | bs = ref.shape[0] 55 | 56 | indiv_mels= x['indiv_mels'] # bs T 1 80 16 57 | indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame 58 | 
num_frames = x['num_frames'] 59 | num_frames = int(num_frames) - 1 60 | 61 | # 62 | div = num_frames//self.seq_len 63 | re = num_frames%self.seq_len 64 | audio_emb_list = [] 65 | pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype, 66 | device=batch['ref'].device)] 67 | 68 | for i in range(div): 69 | z = torch.randn(bs, self.latent_dim).to(ref.device) 70 | batch['z'] = z 71 | audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512 72 | batch['audio_emb'] = audio_emb 73 | batch = self.netG.test(batch) 74 | pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6 75 | 76 | if re != 0: 77 | z = torch.randn(bs, self.latent_dim).to(ref.device) 78 | batch['z'] = z 79 | audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512 80 | if audio_emb.shape[1] != self.seq_len: 81 | pad_dim = self.seq_len-audio_emb.shape[1] 82 | pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1) 83 | audio_emb = torch.cat([pad_audio_emb, audio_emb], 1) 84 | batch['audio_emb'] = audio_emb 85 | batch = self.netG.test(batch) 86 | pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:]) 87 | 88 | pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1) 89 | batch['pose_motion_pred'] = pose_motion_pred 90 | 91 | pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6 92 | 93 | batch['pose_pred'] = pose_pred 94 | return batch 95 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/util/test_mean_face.txt: -------------------------------------------------------------------------------- 1 | -5.228591537475585938e+01 2 | 2.078247070312500000e-01 3 | -5.064269638061523438e+01 4 | -1.315765380859375000e+01 5 | -4.952939224243164062e+01 6 | -2.592591094970703125e+01 7 | -4.793047332763671875e+01 8 | -3.832135772705078125e+01 9 | -4.512159729003906250e+01 10 | -5.059623336791992188e+01 11 | -3.917720794677734375e+01 12 | -6.043736648559570312e+01 13 | -2.929953765869140625e+01 14 | -6.861183166503906250e+01 15 | -1.719801330566406250e+01 16 | -7.572736358642578125e+01 17 | -1.961936950683593750e+00 18 | -7.862001037597656250e+01 19 | 1.467941284179687500e+01 20 | -7.607844543457031250e+01 21 | 2.744073486328125000e+01 22 | -6.915261840820312500e+01 23 | 3.855677795410156250e+01 24 | -5.950350570678710938e+01 25 | 4.478240966796875000e+01 26 | -4.867547225952148438e+01 27 | 4.714337158203125000e+01 28 | -3.800830078125000000e+01 29 | 4.940315246582031250e+01 30 | -2.496297454833984375e+01 31 | 5.117234802246093750e+01 32 | -1.241538238525390625e+01 33 | 5.190507507324218750e+01 34 | 8.244247436523437500e-01 35 | -4.150688934326171875e+01 36 | 2.386329650878906250e+01 37 | -3.570307159423828125e+01 38 | 3.017010498046875000e+01 39 | -2.790358734130859375e+01 40 | 3.212951660156250000e+01 41 | -1.941773223876953125e+01 42 | 3.156523132324218750e+01 43 | -1.138106536865234375e+01 44 | 2.841992187500000000e+01 45 | 5.993263244628906250e+00 46 | 2.895182800292968750e+01 47 | 1.343590545654296875e+01 48 | 3.189880371093750000e+01 49 | 2.203153991699218750e+01 50 | 3.302221679687500000e+01 51 | 2.992478942871093750e+01 52 | 3.099150085449218750e+01 53 | 3.628388977050781250e+01 54 | 2.765748596191406250e+01 55 | -1.933914184570312500e+00 56 | 1.405374145507812500e+01 57 | -2.153038024902343750e+00 58 | 5.772636413574218750e+00 59 | -2.270050048828125000e+00 60 | -2.121643066406250000e+00 61 | -2.218330383300781250e+00 62 | 
-1.068978118896484375e+01 63 | -1.187252044677734375e+01 64 | -1.997912597656250000e+01 65 | -6.879402160644531250e+00 66 | -2.143579864501953125e+01 67 | -1.227821350097656250e+00 68 | -2.193494415283203125e+01 69 | 4.623237609863281250e+00 70 | -2.152721405029296875e+01 71 | 9.721397399902343750e+00 72 | -1.953671264648437500e+01 73 | -3.648714447021484375e+01 74 | 9.811126708984375000e+00 75 | -3.130242919921875000e+01 76 | 1.422447967529296875e+01 77 | -2.212834930419921875e+01 78 | 1.493019866943359375e+01 79 | -1.500880432128906250e+01 80 | 1.073588562011718750e+01 81 | -2.095037078857421875e+01 82 | 9.054298400878906250e+00 83 | -3.050099182128906250e+01 84 | 8.704177856445312500e+00 85 | 1.173237609863281250e+01 86 | 1.054329681396484375e+01 87 | 1.856353759765625000e+01 88 | 1.535009765625000000e+01 89 | 2.893331909179687500e+01 90 | 1.451992797851562500e+01 91 | 3.452944946289062500e+01 92 | 1.065280151367187500e+01 93 | 2.875990295410156250e+01 94 | 8.654792785644531250e+00 95 | 1.942100524902343750e+01 96 | 9.422447204589843750e+00 97 | -2.204488372802734375e+01 98 | -3.983994293212890625e+01 99 | -1.324458312988281250e+01 100 | -3.467377471923828125e+01 101 | -6.749649047851562500e+00 102 | -3.092894744873046875e+01 103 | -9.183349609375000000e-01 104 | -3.196458435058593750e+01 105 | 4.220649719238281250e+00 106 | -3.090406036376953125e+01 107 | 1.089889526367187500e+01 108 | -3.497008514404296875e+01 109 | 1.874589538574218750e+01 110 | -4.065438079833984375e+01 111 | 1.124106597900390625e+01 112 | -4.438417816162109375e+01 113 | 5.181709289550781250e+00 114 | -4.649170684814453125e+01 115 | -1.158607482910156250e+00 116 | -4.680406951904296875e+01 117 | -7.918922424316406250e+00 118 | -4.671575164794921875e+01 119 | -1.452505493164062500e+01 120 | -4.416526031494140625e+01 121 | -2.005007171630859375e+01 122 | -3.997841644287109375e+01 123 | -1.054919433593750000e+01 124 | -3.849683380126953125e+01 125 | -1.051826477050781250e+00 126 | -3.794863128662109375e+01 127 | 6.412681579589843750e+00 128 | -3.804645538330078125e+01 129 | 1.627674865722656250e+01 130 | -4.039697265625000000e+01 131 | 6.373878479003906250e+00 132 | -4.087213897705078125e+01 133 | -8.551712036132812500e-01 134 | -4.157129669189453125e+01 135 | -1.014953613281250000e+01 136 | -4.128469085693359375e+01 137 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/extract_kp_videos.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import glob 5 | import argparse 6 | import face_alignment 7 | import numpy as np 8 | from PIL import Image 9 | from tqdm import tqdm 10 | from itertools import cycle 11 | 12 | from torch.multiprocessing import Pool, Process, set_start_method 13 | 14 | class KeypointExtractor(): 15 | def __init__(self, device): 16 | self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, 17 | device=device) 18 | 19 | def extract_keypoint(self, images, name=None, info=True): 20 | if isinstance(images, list): 21 | keypoints = [] 22 | if info: 23 | i_range = tqdm(images,desc='landmark Det:') 24 | else: 25 | i_range = images 26 | 27 | for image in i_range: 28 | current_kp = self.extract_keypoint(image) 29 | if np.mean(current_kp) == -1 and keypoints: 30 | keypoints.append(keypoints[-1]) 31 | else: 32 | keypoints.append(current_kp[None]) 33 | 34 | keypoints = np.concatenate(keypoints, 0) 35 | np.savetxt(os.path.splitext(name)[0]+'.txt', 
keypoints.reshape(-1)) 36 | return keypoints 37 | else: 38 | while True: 39 | try: 40 | keypoints = self.detector.get_landmarks_from_image(np.array(images))[0] 41 | break 42 | except RuntimeError as e: 43 | if str(e).startswith('CUDA'): 44 | print("Warning: out of memory, sleep for 1s") 45 | time.sleep(1) 46 | else: 47 | print(e) 48 | break 49 | except TypeError: 50 | print('No face detected in this image') 51 | shape = [68, 2] 52 | keypoints = -1. * np.ones(shape) 53 | break 54 | if name is not None: 55 | np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1)) 56 | return keypoints 57 | 58 | def read_video(filename): 59 | frames = [] 60 | cap = cv2.VideoCapture(filename) 61 | while cap.isOpened(): 62 | ret, frame = cap.read() 63 | if ret: 64 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 65 | frame = Image.fromarray(frame) 66 | frames.append(frame) 67 | else: 68 | break 69 | cap.release() 70 | return frames 71 | 72 | def run(data): 73 | filename, opt, device = data 74 | os.environ['CUDA_VISIBLE_DEVICES'] = device 75 | kp_extractor = KeypointExtractor(device='cuda') # KeypointExtractor requires a device; CUDA_VISIBLE_DEVICES above pins this process to a single GPU 76 | images = read_video(filename) 77 | name = filename.split('/')[-2:] 78 | os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True) 79 | kp_extractor.extract_keypoint( 80 | images, 81 | name=os.path.join(opt.output_dir, name[-2], name[-1]) 82 | ) 83 | 84 | if __name__ == '__main__': 85 | set_start_method('spawn') 86 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 87 | parser.add_argument('--input_dir', type=str, help='the folder of the input files') 88 | parser.add_argument('--output_dir', type=str, help='the folder of the output files') 89 | parser.add_argument('--device_ids', type=str, default='0,1') 90 | parser.add_argument('--workers', type=int, default=4) 91 | 92 | opt = parser.parse_args() 93 | filenames = list() 94 | VIDEO_EXTENSIONS_LOWERCASE = {'mp4'} 95 | VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE}) 96 | extensions = VIDEO_EXTENSIONS 97 | 98 | for ext in extensions: 99 | print(f'{opt.input_dir}/*.{ext}') 100 | filenames.extend(glob.glob(f'{opt.input_dir}/*.{ext}')) # collect matches for every extension instead of overwriting the list 101 | filenames = sorted(filenames) 102 | print('Total number of videos:', len(filenames)) 103 | pool = Pool(opt.workers) 104 | args_list = cycle([opt]) 105 | device_ids = opt.device_ids.split(",") 106 | device_ids = cycle(device_ids) 107 | for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))): 108 | pass 109 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/dataset.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import os 3 | import queue as Queue 4 | import threading 5 | 6 | import mxnet as mx 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import DataLoader, Dataset 10 | from torchvision import transforms 11 | 12 | 13 | class BackgroundGenerator(threading.Thread): 14 | def __init__(self, generator, local_rank, max_prefetch=6): 15 | super(BackgroundGenerator, self).__init__() 16 | self.queue = Queue.Queue(max_prefetch) 17 | self.generator = generator 18 | self.local_rank = local_rank 19 | self.daemon = True 20 | self.start() 21 | 22 | def run(self): 23 | torch.cuda.set_device(self.local_rank) 24 | for item in self.generator: 25 | self.queue.put(item) 26 | self.queue.put(None) 27 | 28 | def next(self): 29 | next_item = self.queue.get() 30 | if next_item is None: 31 
| raise StopIteration 32 | return next_item 33 | 34 | def __next__(self): 35 | return self.next() 36 | 37 | def __iter__(self): 38 | return self 39 | 40 | 41 | class DataLoaderX(DataLoader): 42 | 43 | def __init__(self, local_rank, **kwargs): 44 | super(DataLoaderX, self).__init__(**kwargs) 45 | self.stream = torch.cuda.Stream(local_rank) 46 | self.local_rank = local_rank 47 | 48 | def __iter__(self): 49 | self.iter = super(DataLoaderX, self).__iter__() 50 | self.iter = BackgroundGenerator(self.iter, self.local_rank) 51 | self.preload() 52 | return self 53 | 54 | def preload(self): 55 | self.batch = next(self.iter, None) 56 | if self.batch is None: 57 | return None 58 | with torch.cuda.stream(self.stream): 59 | for k in range(len(self.batch)): 60 | self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True) 61 | 62 | def __next__(self): 63 | torch.cuda.current_stream().wait_stream(self.stream) 64 | batch = self.batch 65 | if batch is None: 66 | raise StopIteration 67 | self.preload() 68 | return batch 69 | 70 | 71 | class MXFaceDataset(Dataset): 72 | def __init__(self, root_dir, local_rank): 73 | super(MXFaceDataset, self).__init__() 74 | self.transform = transforms.Compose( 75 | [transforms.ToPILImage(), 76 | transforms.RandomHorizontalFlip(), 77 | transforms.ToTensor(), 78 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 79 | ]) 80 | self.root_dir = root_dir 81 | self.local_rank = local_rank 82 | path_imgrec = os.path.join(root_dir, 'train.rec') 83 | path_imgidx = os.path.join(root_dir, 'train.idx') 84 | self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') 85 | s = self.imgrec.read_idx(0) 86 | header, _ = mx.recordio.unpack(s) 87 | if header.flag > 0: 88 | self.header0 = (int(header.label[0]), int(header.label[1])) 89 | self.imgidx = np.array(range(1, int(header.label[0]))) 90 | else: 91 | self.imgidx = np.array(list(self.imgrec.keys)) 92 | 93 | def __getitem__(self, index): 94 | idx = self.imgidx[index] 95 | s = self.imgrec.read_idx(idx) 96 | header, img = mx.recordio.unpack(s) 97 | label = header.label 98 | if not isinstance(label, numbers.Number): 99 | label = label[0] 100 | label = torch.tensor(label, dtype=torch.long) 101 | sample = mx.image.imdecode(img).asnumpy() 102 | if self.transform is not None: 103 | sample = self.transform(sample) 104 | return sample, label 105 | 106 | def __len__(self): 107 | return len(self.imgidx) 108 | 109 | 110 | class SyntheticDataset(Dataset): 111 | def __init__(self, local_rank): 112 | super(SyntheticDataset, self).__init__() 113 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) 114 | img = np.transpose(img, (2, 0, 1)) 115 | img = torch.from_numpy(img).squeeze(0).float() 116 | img = ((img / 255) - 0.5) / 0.5 117 | self.img = img 118 | self.label = 1 119 | 120 | def __getitem__(self, index): 121 | return self.img, self.label 122 | 123 | def __len__(self): 124 | return 1000000 125 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/util/detect_lm68.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from scipy.io import loadmat 5 | import tensorflow as tf 6 | from util.preprocess import align_for_lm 7 | from shutil import move 8 | 9 | mean_face = np.loadtxt('util/test_mean_face.txt') 10 | mean_face = mean_face.reshape([68, 2]) 11 | 12 | def save_label(labels, save_path): 13 | np.savetxt(save_path, labels) 14 | 15 | def draw_landmarks(img, 
landmark, save_name): 16 | landmark = landmark 17 | lm_img = np.zeros([img.shape[0], img.shape[1], 3]) 18 | lm_img[:] = img.astype(np.float32) 19 | landmark = np.round(landmark).astype(np.int32) 20 | 21 | for i in range(len(landmark)): 22 | for j in range(-1, 1): 23 | for k in range(-1, 1): 24 | if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \ 25 | img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \ 26 | landmark[i, 0]+k > 0 and \ 27 | landmark[i, 0]+k < img.shape[1]: 28 | lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k, 29 | :] = np.array([0, 0, 255]) 30 | lm_img = lm_img.astype(np.uint8) 31 | 32 | cv2.imwrite(save_name, lm_img) 33 | 34 | 35 | def load_data(img_name, txt_name): 36 | return cv2.imread(img_name), np.loadtxt(txt_name) 37 | 38 | # create tensorflow graph for landmark detector 39 | def load_lm_graph(graph_filename): 40 | with tf.gfile.GFile(graph_filename, 'rb') as f: 41 | graph_def = tf.GraphDef() 42 | graph_def.ParseFromString(f.read()) 43 | 44 | with tf.Graph().as_default() as graph: 45 | tf.import_graph_def(graph_def, name='net') 46 | img_224 = graph.get_tensor_by_name('net/input_imgs:0') 47 | output_lm = graph.get_tensor_by_name('net/lm:0') 48 | lm_sess = tf.Session(graph=graph) 49 | 50 | return lm_sess,img_224,output_lm 51 | 52 | # landmark detection 53 | def detect_68p(img_path,sess,input_op,output_op): 54 | print('detecting landmarks......') 55 | names = [i for i in sorted(os.listdir( 56 | img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i] 57 | vis_path = os.path.join(img_path, 'vis') 58 | remove_path = os.path.join(img_path, 'remove') 59 | save_path = os.path.join(img_path, 'landmarks') 60 | if not os.path.isdir(vis_path): 61 | os.makedirs(vis_path) 62 | if not os.path.isdir(remove_path): 63 | os.makedirs(remove_path) 64 | if not os.path.isdir(save_path): 65 | os.makedirs(save_path) 66 | 67 | for i in range(0, len(names)): 68 | name = names[i] 69 | print('%05d' % (i), ' ', name) 70 | full_image_name = os.path.join(img_path, name) 71 | txt_name = '.'.join(name.split('.')[:-1]) + '.txt' 72 | full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image 73 | 74 | # if an image does not have detected 5 facial landmarks, remove it from the training list 75 | if not os.path.isfile(full_txt_name): 76 | move(full_image_name, os.path.join(remove_path, name)) 77 | continue 78 | 79 | # load data 80 | img, five_points = load_data(full_image_name, full_txt_name) 81 | input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection 82 | 83 | # if the alignment fails, remove corresponding image from the training list 84 | if scale == 0: 85 | move(full_txt_name, os.path.join( 86 | remove_path, txt_name)) 87 | move(full_image_name, os.path.join(remove_path, name)) 88 | continue 89 | 90 | # detect landmarks 91 | input_img = np.reshape( 92 | input_img, [1, 224, 224, 3]).astype(np.float32) 93 | landmark = sess.run( 94 | output_op, feed_dict={input_op: input_img}) 95 | 96 | # transform back to original image coordinate 97 | landmark = landmark.reshape([68, 2]) + mean_face 98 | landmark[:, 1] = 223 - landmark[:, 1] 99 | landmark = landmark / scale 100 | landmark[:, 0] = landmark[:, 0] + bbox[0] 101 | landmark[:, 1] = landmark[:, 1] + bbox[1] 102 | landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1] 103 | 104 | if i % 100 == 0: 105 | draw_landmarks(img, landmark, os.path.join(vis_path, name)) 106 | save_label(landmark, os.path.join(save_path, txt_name)) 107 | 
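A minimal usage sketch for the 68-landmark pipeline defined above in detect_lm68.py. It assumes face3d/ is the working directory (matching the file's own 'from util.preprocess import ...' imports); the frozen TensorFlow graph path 'checkpoints/lm_model/68lm_detector.pb' and the image folder 'datasets/train_images' are illustrative assumptions, and detect_68p expects that folder to already contain a 'detections/' subfolder with a 5-point landmark .txt file per image.

from util.detect_lm68 import load_lm_graph, detect_68p

lm_sess, input_op, output_op = load_lm_graph('checkpoints/lm_model/68lm_detector.pb')  # assumed checkpoint path
detect_68p('datasets/train_images', lm_sess, input_op, output_op)  # writes landmarks/, vis/ and remove/ under the image folder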
-------------------------------------------------------------------------------- /SadTalker/src/face3d/data/flist_dataset.py: -------------------------------------------------------------------------------- 1 | """This script defines the custom dataset for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import os.path 5 | from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine 6 | from data.image_folder import make_dataset 7 | from PIL import Image 8 | import random 9 | import util.util as util 10 | import numpy as np 11 | import json 12 | import torch 13 | from scipy.io import loadmat, savemat 14 | import pickle 15 | from util.preprocess import align_img, estimate_norm 16 | from util.load_mats import load_lm3d 17 | 18 | 19 | def default_flist_reader(flist): 20 | """ 21 | flist format: impath label\nimpath label\n ...(same to caffe's filelist) 22 | """ 23 | imlist = [] 24 | with open(flist, 'r') as rf: 25 | for line in rf.readlines(): 26 | impath = line.strip() 27 | imlist.append(impath) 28 | 29 | return imlist 30 | 31 | def jason_flist_reader(flist): 32 | with open(flist, 'r') as fp: 33 | info = json.load(fp) 34 | return info 35 | 36 | def parse_label(label): 37 | return torch.tensor(np.array(label).astype(np.float32)) 38 | 39 | 40 | class FlistDataset(BaseDataset): 41 | """ 42 | It requires one directories to host training images '/path/to/data/train' 43 | You can train the model with the dataset flag '--dataroot /path/to/data'. 44 | """ 45 | 46 | def __init__(self, opt): 47 | """Initialize this dataset class. 48 | 49 | Parameters: 50 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 51 | """ 52 | BaseDataset.__init__(self, opt) 53 | 54 | self.lm3d_std = load_lm3d(opt.bfm_folder) 55 | 56 | msk_names = default_flist_reader(opt.flist) 57 | self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names] 58 | 59 | self.size = len(self.msk_paths) 60 | self.opt = opt 61 | 62 | self.name = 'train' if opt.isTrain else 'val' 63 | if '_' in opt.flist: 64 | self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0] 65 | 66 | 67 | def __getitem__(self, index): 68 | """Return a data point and its metadata information. 69 | 70 | Parameters: 71 | index (int) -- a random integer for data indexing 72 | 73 | Returns a dictionary that contains A, B, A_paths and B_paths 74 | img (tensor) -- an image in the input domain 75 | msk (tensor) -- its corresponding attention mask 76 | lm (tensor) -- its corresponding 3d landmarks 77 | im_paths (str) -- image paths 78 | aug_flag (bool) -- a flag used to tell whether its raw or augmented 79 | """ 80 | msk_path = self.msk_paths[index % self.size] # make sure index is within then range 81 | img_path = msk_path.replace('mask/', '') 82 | lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt' 83 | 84 | raw_img = Image.open(img_path).convert('RGB') 85 | raw_msk = Image.open(msk_path).convert('RGB') 86 | raw_lm = np.loadtxt(lm_path).astype(np.float32) 87 | 88 | _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk) 89 | 90 | aug_flag = self.opt.use_aug and self.opt.isTrain 91 | if aug_flag: 92 | img, lm, msk = self._augmentation(img, lm, self.opt, msk) 93 | 94 | _, H = img.size 95 | M = estimate_norm(lm, H) 96 | transform = get_transform() 97 | img_tensor = transform(img) 98 | msk_tensor = transform(msk)[:1, ...] 
99 | lm_tensor = parse_label(lm) 100 | M_tensor = parse_label(M) 101 | 102 | 103 | return {'imgs': img_tensor, 104 | 'lms': lm_tensor, 105 | 'msks': msk_tensor, 106 | 'M': M_tensor, 107 | 'im_paths': img_path, 108 | 'aug_flag': aug_flag, 109 | 'dataset': self.name} 110 | 111 | def _augmentation(self, img, lm, opt, msk=None): 112 | affine, affine_inv, flip = get_affine_mat(opt, img.size) 113 | img = apply_img_affine(img, affine_inv) 114 | lm = apply_lm_affine(lm, affine, flip, img.size) 115 | if msk is not None: 116 | msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR) 117 | return img, lm, msk 118 | 119 | 120 | 121 | 122 | def __len__(self): 123 | """Return the total number of images in the dataset. 124 | """ 125 | return self.size 126 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from kornia.geometry import warp_affine 5 | import torch.nn.functional as F 6 | 7 | def resize_n_crop(image, M, dsize=112): 8 | # image: (b, c, h, w) 9 | # M : (b, 2, 3) 10 | return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True) 11 | 12 | ### perceptual level loss 13 | class PerceptualLoss(nn.Module): 14 | def __init__(self, recog_net, input_size=112): 15 | super(PerceptualLoss, self).__init__() 16 | self.recog_net = recog_net 17 | self.preprocess = lambda x: 2 * x - 1 18 | self.input_size=input_size 19 | def forward(self, imageA, imageB, M): 20 | """ 21 | 1 - cosine distance 22 | Parameters: 23 | imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order 24 | imageB --same as imageA 25 | """ 26 | 27 | imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size)) 28 | imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size)) 29 | 30 | # freeze bn 31 | self.recog_net.eval() 32 | 33 | id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2) 34 | id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2) 35 | cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) 36 | # assert torch.sum((cosine_d > 1).float()) == 0 37 | return torch.sum(1 - cosine_d) / cosine_d.shape[0] 38 | 39 | def perceptual_loss(id_featureA, id_featureB): 40 | cosine_d = torch.sum(id_featureA * id_featureB, dim=-1) 41 | # assert torch.sum((cosine_d > 1).float()) == 0 42 | return torch.sum(1 - cosine_d) / cosine_d.shape[0] 43 | 44 | ### image level loss 45 | def photo_loss(imageA, imageB, mask, eps=1e-6): 46 | """ 47 | l2 norm (with sqrt, to ensure backward stability, use eps, otherwise Nan may occur) 48 | Parameters: 49 | imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order 50 | imageB --same as imageA 51 | """ 52 | loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask 53 | loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device)) 54 | return loss 55 | 56 | def landmark_loss(predict_lm, gt_lm, weight=None): 57 | """ 58 | weighted mse loss 59 | Parameters: 60 | predict_lm --torch.tensor (B, 68, 2) 61 | gt_lm --torch.tensor (B, 68, 2) 62 | weight --numpy.array (1, 68) 63 | """ 64 | if weight is None: 65 | weight = np.ones([68]) 66 | weight[28:31] = 20 67 | weight[-8:] = 20 68 | weight = np.expand_dims(weight, 0) 69 | weight = torch.tensor(weight).to(predict_lm.device) 70 | loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight 71 | loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1]) 72 | 
return loss 73 | 74 | 75 | ### regularization 76 | def reg_loss(coeffs_dict, opt=None): 77 | """ 78 | l2 norm without the sqrt, from yu's implementation (mse) 79 | tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss 80 | Parameters: 81 | coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans 82 | 83 | """ 84 | # coefficient regularization to ensure plausible 3d faces 85 | if opt: 86 | w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex 87 | else: 88 | w_id, w_exp, w_tex = 1, 1, 1 89 | creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \ 90 | w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \ 91 | w_tex * torch.sum(coeffs_dict['tex'] ** 2) 92 | creg_loss = creg_loss / coeffs_dict['id'].shape[0] 93 | 94 | # gamma regularization to ensure a nearly-monochromatic light 95 | gamma = coeffs_dict['gamma'].reshape([-1, 3, 9]) 96 | gamma_mean = torch.mean(gamma, dim=1, keepdims=True) 97 | gamma_loss = torch.mean((gamma - gamma_mean) ** 2) 98 | 99 | return creg_loss, gamma_loss 100 | 101 | def reflectance_loss(texture, mask): 102 | """ 103 | minimize texture variance (mse), albedo regularization to ensure a uniform skin albedo 104 | Parameters: 105 | texture --torch.tensor, (B, N, 3) 106 | mask --torch.tensor, (N), 1 or 0 107 | 108 | """ 109 | mask = mask.reshape([1, mask.shape[0], 1]) 110 | texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask) 111 | loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask)) 112 | return loss 113 | 114 | -------------------------------------------------------------------------------- /SadTalker/src/audio2pose_models/networks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class ResidualConv(nn.Module): 6 | def __init__(self, input_dim, output_dim, stride, padding): 7 | super(ResidualConv, self).__init__() 8 | 9 | self.conv_block = nn.Sequential( 10 | nn.BatchNorm2d(input_dim), 11 | nn.ReLU(), 12 | nn.Conv2d( 13 | input_dim, output_dim, kernel_size=3, stride=stride, padding=padding 14 | ), 15 | nn.BatchNorm2d(output_dim), 16 | nn.ReLU(), 17 | nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), 18 | ) 19 | self.conv_skip = nn.Sequential( 20 | nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), 21 | nn.BatchNorm2d(output_dim), 22 | ) 23 | 24 | def forward(self, x): 25 | 26 | return self.conv_block(x) + self.conv_skip(x) 27 | 28 | 29 | class Upsample(nn.Module): 30 | def __init__(self, input_dim, output_dim, kernel, stride): 31 | super(Upsample, self).__init__() 32 | 33 | self.upsample = nn.ConvTranspose2d( 34 | input_dim, output_dim, kernel_size=kernel, stride=stride 35 | ) 36 | 37 | def forward(self, x): 38 | return self.upsample(x) 39 | 40 | 41 | class Squeeze_Excite_Block(nn.Module): 42 | def __init__(self, channel, reduction=16): 43 | super(Squeeze_Excite_Block, self).__init__() 44 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 45 | self.fc = nn.Sequential( 46 | nn.Linear(channel, channel // reduction, bias=False), 47 | nn.ReLU(inplace=True), 48 | nn.Linear(channel // reduction, channel, bias=False), 49 | nn.Sigmoid(), 50 | ) 51 | 52 | def forward(self, x): 53 | b, c, _, _ = x.size() 54 | y = self.avg_pool(x).view(b, c) 55 | y = self.fc(y).view(b, c, 1, 1) 56 | return x * y.expand_as(x) 57 | 58 | 59 | class ASPP(nn.Module): 60 | def __init__(self, in_dims, out_dims, rate=[6, 12, 18]): 61 | super(ASPP, self).__init__() 62 | 63 | 
self.aspp_block1 = nn.Sequential( 64 | nn.Conv2d( 65 | in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0] 66 | ), 67 | nn.ReLU(inplace=True), 68 | nn.BatchNorm2d(out_dims), 69 | ) 70 | self.aspp_block2 = nn.Sequential( 71 | nn.Conv2d( 72 | in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1] 73 | ), 74 | nn.ReLU(inplace=True), 75 | nn.BatchNorm2d(out_dims), 76 | ) 77 | self.aspp_block3 = nn.Sequential( 78 | nn.Conv2d( 79 | in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2] 80 | ), 81 | nn.ReLU(inplace=True), 82 | nn.BatchNorm2d(out_dims), 83 | ) 84 | 85 | self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1) 86 | self._init_weights() 87 | 88 | def forward(self, x): 89 | x1 = self.aspp_block1(x) 90 | x2 = self.aspp_block2(x) 91 | x3 = self.aspp_block3(x) 92 | out = torch.cat([x1, x2, x3], dim=1) 93 | return self.output(out) 94 | 95 | def _init_weights(self): 96 | for m in self.modules(): 97 | if isinstance(m, nn.Conv2d): 98 | nn.init.kaiming_normal_(m.weight) 99 | elif isinstance(m, nn.BatchNorm2d): 100 | m.weight.data.fill_(1) 101 | m.bias.data.zero_() 102 | 103 | 104 | class Upsample_(nn.Module): 105 | def __init__(self, scale=2): 106 | super(Upsample_, self).__init__() 107 | 108 | self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale) 109 | 110 | def forward(self, x): 111 | return self.upsample(x) 112 | 113 | 114 | class AttentionBlock(nn.Module): 115 | def __init__(self, input_encoder, input_decoder, output_dim): 116 | super(AttentionBlock, self).__init__() 117 | 118 | self.conv_encoder = nn.Sequential( 119 | nn.BatchNorm2d(input_encoder), 120 | nn.ReLU(), 121 | nn.Conv2d(input_encoder, output_dim, 3, padding=1), 122 | nn.MaxPool2d(2, 2), 123 | ) 124 | 125 | self.conv_decoder = nn.Sequential( 126 | nn.BatchNorm2d(input_decoder), 127 | nn.ReLU(), 128 | nn.Conv2d(input_decoder, output_dim, 3, padding=1), 129 | ) 130 | 131 | self.conv_attn = nn.Sequential( 132 | nn.BatchNorm2d(output_dim), 133 | nn.ReLU(), 134 | nn.Conv2d(output_dim, 1, 1), 135 | ) 136 | 137 | def forward(self, x1, x2): 138 | out = self.conv_encoder(x1) + self.conv_decoder(x2) 139 | out = self.conv_attn(out) 140 | return out * x2 -------------------------------------------------------------------------------- /SadTalker/src/face3d/util/load_mats.py: -------------------------------------------------------------------------------- 1 | """This script is to load 3D face model for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import numpy as np 5 | from PIL import Image 6 | from scipy.io import loadmat, savemat 7 | from array import array 8 | import os.path as osp 9 | 10 | # load expression basis 11 | def LoadExpBasis(bfm_folder='BFM'): 12 | n_vertex = 53215 13 | Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb') 14 | exp_dim = array('i') 15 | exp_dim.fromfile(Expbin, 1) 16 | expMU = array('f') 17 | expPC = array('f') 18 | expMU.fromfile(Expbin, 3*n_vertex) 19 | expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex) 20 | Expbin.close() 21 | 22 | expPC = np.array(expPC) 23 | expPC = np.reshape(expPC, [exp_dim[0], -1]) 24 | expPC = np.transpose(expPC) 25 | 26 | expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt')) 27 | 28 | return expPC, expEV 29 | 30 | 31 | # transfer original BFM09 to our face model 32 | def transferBFM09(bfm_folder='BFM'): 33 | print('Transfer BFM09 to BFM_model_front......') 34 | original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat')) 35 | shapePC = original_BFM['shapePC'] # shape basis 36 | shapeEV = original_BFM['shapeEV'] # corresponding 
eigen value 37 | shapeMU = original_BFM['shapeMU'] # mean face 38 | texPC = original_BFM['texPC'] # texture basis 39 | texEV = original_BFM['texEV'] # eigen value 40 | texMU = original_BFM['texMU'] # mean texture 41 | 42 | expPC, expEV = LoadExpBasis(bfm_folder) 43 | 44 | # transfer BFM09 to our face model 45 | 46 | idBase = shapePC*np.reshape(shapeEV, [-1, 199]) 47 | idBase = idBase/1e5 # unify the scale to decimeter 48 | idBase = idBase[:, :80] # use only first 80 basis 49 | 50 | exBase = expPC*np.reshape(expEV, [-1, 79]) 51 | exBase = exBase/1e5 # unify the scale to decimeter 52 | exBase = exBase[:, :64] # use only first 64 basis 53 | 54 | texBase = texPC*np.reshape(texEV, [-1, 199]) 55 | texBase = texBase[:, :80] # use only first 80 basis 56 | 57 | # our face model is cropped along face landmarks and contains only 35709 vertex. 58 | # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex. 59 | # thus we select corresponding vertex to get our face model. 60 | 61 | index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat')) 62 | index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215) 63 | 64 | index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat')) 65 | index_shape = index_shape['trimIndex'].astype( 66 | np.int32) - 1 # starts from 0 (to 53490) 67 | index_shape = index_shape[index_exp] 68 | 69 | idBase = np.reshape(idBase, [-1, 3, 80]) 70 | idBase = idBase[index_shape, :, :] 71 | idBase = np.reshape(idBase, [-1, 80]) 72 | 73 | texBase = np.reshape(texBase, [-1, 3, 80]) 74 | texBase = texBase[index_shape, :, :] 75 | texBase = np.reshape(texBase, [-1, 80]) 76 | 77 | exBase = np.reshape(exBase, [-1, 3, 64]) 78 | exBase = exBase[index_exp, :, :] 79 | exBase = np.reshape(exBase, [-1, 64]) 80 | 81 | meanshape = np.reshape(shapeMU, [-1, 3])/1e5 82 | meanshape = meanshape[index_shape, :] 83 | meanshape = np.reshape(meanshape, [1, -1]) 84 | 85 | meantex = np.reshape(texMU, [-1, 3]) 86 | meantex = meantex[index_shape, :] 87 | meantex = np.reshape(meantex, [1, -1]) 88 | 89 | # other info contains triangles, region used for computing photometric loss, 90 | # region used for skin texture regularization, and 68 landmarks index etc. 
91 | other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat')) 92 | frontmask2_idx = other_info['frontmask2_idx'] 93 | skinmask = other_info['skinmask'] 94 | keypoints = other_info['keypoints'] 95 | point_buf = other_info['point_buf'] 96 | tri = other_info['tri'] 97 | tri_mask2 = other_info['tri_mask2'] 98 | 99 | # save our face model 100 | savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase, 101 | 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask}) 102 | 103 | 104 | # load landmarks for standard face, which is used for image preprocessing 105 | def load_lm3d(bfm_folder): 106 | 107 | Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat')) 108 | Lm3D = Lm3D['lm'] 109 | 110 | # calculate 5 facial landmarks using 68 landmarks 111 | lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 112 | Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean( 113 | Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0) 114 | Lm3D = Lm3D[[1, 2, 0, 3, 4], :] 115 | 116 | return Lm3D 117 | 118 | 119 | if __name__ == '__main__': 120 | transferBFM09() -------------------------------------------------------------------------------- /SadTalker/src/generate_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | import torch 5 | import numpy as np 6 | import random 7 | import scipy.io as scio 8 | import src.utils.audio as audio 9 | 10 | def crop_pad_audio(wav, audio_length): 11 | if len(wav) > audio_length: 12 | wav = wav[:audio_length] 13 | elif len(wav) < audio_length: 14 | wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0) 15 | return wav 16 | 17 | def parse_audio_length(audio_length, sr, fps): 18 | bit_per_frames = sr / fps 19 | 20 | num_frames = int(audio_length / bit_per_frames) 21 | audio_length = int(num_frames * bit_per_frames) 22 | 23 | return audio_length, num_frames 24 | 25 | def generate_blink_seq(num_frames): 26 | ratio = np.zeros((num_frames,1)) 27 | frame_id = 0 28 | while frame_id in range(num_frames): 29 | start = 80 30 | if frame_id+start+9<=num_frames - 1: 31 | ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5] 32 | frame_id = frame_id+start+9 33 | else: 34 | break 35 | return ratio 36 | 37 | def generate_blink_seq_randomly(num_frames): 38 | ratio = np.zeros((num_frames,1)) 39 | if num_frames<=20: 40 | return ratio 41 | frame_id = 0 42 | while frame_id in range(num_frames): 43 | start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70))) 44 | if frame_id+start+5<=num_frames - 1: 45 | ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5] 46 | frame_id = frame_id+start+5 47 | else: 48 | break 49 | return ratio 50 | 51 | def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True): 52 | 53 | syncnet_mel_step_size = 16 54 | fps = 25 55 | 56 | pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0] 57 | audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0] 58 | 59 | 60 | if idlemode: 61 | num_frames = int(length_of_audio * 25) 62 | indiv_mels = np.zeros((num_frames, 80, 16)) 63 | else: 64 | wav = audio.load_wav(audio_path, 16000) 65 | wav_length, num_frames = 
parse_audio_length(len(wav), 16000, 25) 66 | wav = crop_pad_audio(wav, wav_length) 67 | orig_mel = audio.melspectrogram(wav).T 68 | spec = orig_mel.copy() # nframes 80 69 | indiv_mels = [] 70 | 71 | for i in tqdm(range(num_frames), 'mel:'): 72 | start_frame_num = i-2 73 | start_idx = int(80. * (start_frame_num / float(fps))) 74 | end_idx = start_idx + syncnet_mel_step_size 75 | seq = list(range(start_idx, end_idx)) 76 | seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ] 77 | m = spec[seq, :] 78 | indiv_mels.append(m.T) 79 | indiv_mels = np.asarray(indiv_mels) # T 80 16 80 | 81 | ratio = generate_blink_seq_randomly(num_frames) # T 82 | source_semantics_path = first_coeff_path 83 | source_semantics_dict = scio.loadmat(source_semantics_path) 84 | ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70 85 | ref_coeff = np.repeat(ref_coeff, num_frames, axis=0) 86 | 87 | if ref_eyeblink_coeff_path is not None: 88 | ratio[:num_frames] = 0 89 | refeyeblink_coeff_dict = scio.loadmat(ref_eyeblink_coeff_path) 90 | refeyeblink_coeff = refeyeblink_coeff_dict['coeff_3dmm'][:,:64] 91 | refeyeblink_num_frames = refeyeblink_coeff.shape[0] 92 | if refeyeblink_num_frames: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- : (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 12 | """ 13 | import numpy as np 14 | import importlib 15 | import torch.utils.data 16 | from face3d.data.base_dataset import BaseDataset 17 | 18 | 19 | def find_dataset_using_name(dataset_name): 20 | """Import the module "data/[dataset_name]_dataset.py". 21 | 22 | In the file, the class called DatasetNameDataset() will 23 | be instantiated. It has to be a subclass of BaseDataset, 24 | and it is case-insensitive. 25 | """ 26 | dataset_filename = "data." + dataset_name + "_dataset" 27 | datasetlib = importlib.import_module(dataset_filename) 28 | 29 | dataset = None 30 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 31 | for name, cls in datasetlib.__dict__.items(): 32 | if name.lower() == target_dataset_name.lower() \ 33 | and issubclass(cls, BaseDataset): 34 | dataset = cls 35 | 36 | if dataset is None: 37 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 38 | 39 | return dataset 40 | 41 | 42 | def get_option_setter(dataset_name): 43 | """Return the static method of the dataset class.""" 44 | dataset_class = find_dataset_using_name(dataset_name) 45 | return dataset_class.modify_commandline_options 46 | 47 | 48 | def create_dataset(opt, rank=0): 49 | """Create a dataset given the option. 50 | 51 | This function wraps the class CustomDatasetDataLoader. 
52 | This is the main interface between this package and 'train.py'/'test.py' 53 | 54 | Example: 55 | >>> from data import create_dataset 56 | >>> dataset = create_dataset(opt) 57 | """ 58 | data_loader = CustomDatasetDataLoader(opt, rank=rank) 59 | dataset = data_loader.load_data() 60 | return dataset 61 | 62 | class CustomDatasetDataLoader(): 63 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 64 | 65 | def __init__(self, opt, rank=0): 66 | """Initialize this class 67 | 68 | Step 1: create a dataset instance given the name [dataset_mode] 69 | Step 2: create a multi-threaded data loader. 70 | """ 71 | self.opt = opt 72 | dataset_class = find_dataset_using_name(opt.dataset_mode) 73 | self.dataset = dataset_class(opt) 74 | self.sampler = None 75 | print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__)) 76 | if opt.use_ddp and opt.isTrain: 77 | world_size = opt.world_size 78 | self.sampler = torch.utils.data.distributed.DistributedSampler( 79 | self.dataset, 80 | num_replicas=world_size, 81 | rank=rank, 82 | shuffle=not opt.serial_batches 83 | ) 84 | self.dataloader = torch.utils.data.DataLoader( 85 | self.dataset, 86 | sampler=self.sampler, 87 | num_workers=int(opt.num_threads / world_size), 88 | batch_size=int(opt.batch_size / world_size), 89 | drop_last=True) 90 | else: 91 | self.dataloader = torch.utils.data.DataLoader( 92 | self.dataset, 93 | batch_size=opt.batch_size, 94 | shuffle=(not opt.serial_batches) and opt.isTrain, 95 | num_workers=int(opt.num_threads), 96 | drop_last=True 97 | ) 98 | 99 | def set_epoch(self, epoch): 100 | self.dataset.current_epoch = epoch 101 | if self.sampler is not None: 102 | self.sampler.set_epoch(epoch) 103 | 104 | def load_data(self): 105 | return self 106 | 107 | def __len__(self): 108 | """Return the number of data in the dataset""" 109 | return min(len(self.dataset), self.opt.max_dataset_size) 110 | 111 | def __iter__(self): 112 | """Return a batch of data""" 113 | for i, data in enumerate(self.dataloader): 114 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 115 | break 116 | yield data 117 | -------------------------------------------------------------------------------- /SadTalker/src/facerender/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 
29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger a callback of each module, all slave devices should 60 | call `register(id)` and obtain a `SlavePipe` to communicate with the master. 61 | - During the forward pass, the master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine the message to be passed 64 | back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register a slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages are first collected from each device (including the master device), and then 106 | a callback will be invoked to compute the message to be sent back to each device 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master wants to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belong to the master.' 
124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /SadTalker/src/utils/audio.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import librosa.filters 3 | import numpy as np 4 | # import tensorflow as tf 5 | from scipy import signal 6 | from scipy.io import wavfile 7 | from src.utils.hparams import hparams as hp 8 | 9 | def load_wav(path, sr): 10 | return librosa.core.load(path, sr=sr)[0] 11 | 12 | def save_wav(wav, path, sr): 13 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 14 | #proposed by @dsmiller 15 | wavfile.write(path, sr, wav.astype(np.int16)) 16 | 17 | def save_wavenet_wav(wav, path, sr): 18 | librosa.output.write_wav(path, wav, sr=sr) 19 | 20 | def preemphasis(wav, k, preemphasize=True): 21 | if preemphasize: 22 | return signal.lfilter([1, -k], [1], wav) 23 | return wav 24 | 25 | def inv_preemphasis(wav, k, inv_preemphasize=True): 26 | if inv_preemphasize: 27 | return signal.lfilter([1], [1, -k], wav) 28 | return wav 29 | 30 | def get_hop_size(): 31 | hop_size = hp.hop_size 32 | if hop_size is None: 33 | assert hp.frame_shift_ms is not None 34 | hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate) 35 | return hop_size 36 | 37 | def linearspectrogram(wav): 38 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 39 | S = _amp_to_db(np.abs(D)) - hp.ref_level_db 40 | 41 | if hp.signal_normalization: 42 | return _normalize(S) 43 | return S 44 | 45 | def melspectrogram(wav): 46 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 47 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db 48 | 49 | if hp.signal_normalization: 50 | return _normalize(S) 51 | return S 52 | 53 | def _lws_processor(): 54 | import lws 55 | return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech") 56 | 57 | def _stft(y): 58 | if hp.use_lws: 59 | return _lws_processor(hp).stft(y).T 60 | else: 61 | return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size) 62 | 63 | ########################################################## 64 | #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) 
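# num_frames() and pad_lr() reproduce lws-style framing: the signal is padded by
# pad = fsize - fshift on each side, the frame count M is derived from that padded
# length, and pad_lr() returns an asymmetric right pad so that (M - 1) * fshift + fsize
# exactly covers the padded signal. librosa_pad_lr() is the analogous computation for
# the librosa code path.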
65 | def num_frames(length, fsize, fshift): 66 | """Compute number of time frames of spectrogram 67 | """ 68 | pad = (fsize - fshift) 69 | if length % fshift == 0: 70 | M = (length + pad * 2 - fsize) // fshift + 1 71 | else: 72 | M = (length + pad * 2 - fsize) // fshift + 2 73 | return M 74 | 75 | 76 | def pad_lr(x, fsize, fshift): 77 | """Compute left and right padding 78 | """ 79 | M = num_frames(len(x), fsize, fshift) 80 | pad = (fsize - fshift) 81 | T = len(x) + 2 * pad 82 | r = (M - 1) * fshift + fsize - T 83 | return pad, pad + r 84 | ########################################################## 85 | #Librosa correct padding 86 | def librosa_pad_lr(x, fsize, fshift): 87 | return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] 88 | 89 | # Conversions 90 | _mel_basis = None 91 | 92 | def _linear_to_mel(spectogram): 93 | global _mel_basis 94 | if _mel_basis is None: 95 | _mel_basis = _build_mel_basis() 96 | return np.dot(_mel_basis, spectogram) 97 | 98 | def _build_mel_basis(): 99 | assert hp.fmax <= hp.sample_rate // 2 100 | return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, 101 | fmin=hp.fmin, fmax=hp.fmax) 102 | 103 | def _amp_to_db(x): 104 | min_level = np.exp(hp.min_level_db / 20 * np.log(10)) 105 | return 20 * np.log10(np.maximum(min_level, x)) 106 | 107 | def _db_to_amp(x): 108 | return np.power(10.0, (x) * 0.05) 109 | 110 | def _normalize(S): 111 | if hp.allow_clipping_in_normalization: 112 | if hp.symmetric_mels: 113 | return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value, 114 | -hp.max_abs_value, hp.max_abs_value) 115 | else: 116 | return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value) 117 | 118 | assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 119 | if hp.symmetric_mels: 120 | return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value 121 | else: 122 | return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) 123 | 124 | def _denormalize(D): 125 | if hp.allow_clipping_in_normalization: 126 | if hp.symmetric_mels: 127 | return (((np.clip(D, -hp.max_abs_value, 128 | hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) 129 | + hp.min_level_db) 130 | else: 131 | return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 132 | 133 | if hp.symmetric_mels: 134 | return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) 135 | else: 136 | return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 137 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/util/nvdiffrast.py: -------------------------------------------------------------------------------- 1 | """This script is the differentiable renderer for Deep3DFaceRecon_pytorch 2 | Attention, antialiasing step is missing in current version. 
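Note: despite the filename, this variant rasterizes with pytorch3d (MeshRasterizer /
FoVPerspectiveCameras) rather than the nvdiffrast library.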
3 | """ 4 | import pytorch3d.ops 5 | import torch 6 | import torch.nn.functional as F 7 | import kornia 8 | from kornia.geometry.camera import pixel2cam 9 | import numpy as np 10 | from typing import List 11 | from scipy.io import loadmat 12 | from torch import nn 13 | 14 | from pytorch3d.structures import Meshes 15 | from pytorch3d.renderer import ( 16 | look_at_view_transform, 17 | FoVPerspectiveCameras, 18 | DirectionalLights, 19 | RasterizationSettings, 20 | MeshRenderer, 21 | MeshRasterizer, 22 | SoftPhongShader, 23 | TexturesUV, 24 | ) 25 | 26 | # def ndc_projection(x=0.1, n=1.0, f=50.0): 27 | # return np.array([[n/x, 0, 0, 0], 28 | # [ 0, n/-x, 0, 0], 29 | # [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], 30 | # [ 0, 0, -1, 0]]).astype(np.float32) 31 | 32 | class MeshRenderer(nn.Module): 33 | def __init__(self, 34 | rasterize_fov, 35 | znear=0.1, 36 | zfar=10, 37 | rasterize_size=224): 38 | super(MeshRenderer, self).__init__() 39 | 40 | # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear 41 | # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul( 42 | # torch.diag(torch.tensor([1., -1, -1, 1]))) 43 | self.rasterize_size = rasterize_size 44 | self.fov = rasterize_fov 45 | self.znear = znear 46 | self.zfar = zfar 47 | 48 | self.rasterizer = None 49 | 50 | def forward(self, vertex, tri, feat=None): 51 | """ 52 | Return: 53 | mask -- torch.tensor, size (B, 1, H, W) 54 | depth -- torch.tensor, size (B, 1, H, W) 55 | features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None 56 | 57 | Parameters: 58 | vertex -- torch.tensor, size (B, N, 3) 59 | tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles 60 | feat(optional) -- torch.tensor, size (B, N ,C), features 61 | """ 62 | device = vertex.device 63 | rsize = int(self.rasterize_size) 64 | # ndc_proj = self.ndc_proj.to(device) 65 | # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v 66 | if vertex.shape[-1] == 3: 67 | vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1) 68 | vertex[..., 0] = -vertex[..., 0] 69 | 70 | 71 | # vertex_ndc = vertex @ ndc_proj.t() 72 | if self.rasterizer is None: 73 | self.rasterizer = MeshRasterizer() 74 | print("create rasterizer on device cuda:%d"%device.index) 75 | 76 | # ranges = None 77 | # if isinstance(tri, List) or len(tri.shape) == 3: 78 | # vum = vertex_ndc.shape[1] 79 | # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device) 80 | # fstartidx = torch.cumsum(fnum, dim=0) - fnum 81 | # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu() 82 | # for i in range(tri.shape[0]): 83 | # tri[i] = tri[i] + i*vum 84 | # vertex_ndc = torch.cat(vertex_ndc, dim=0) 85 | # tri = torch.cat(tri, dim=0) 86 | 87 | # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3] 88 | tri = tri.type(torch.int32).contiguous() 89 | 90 | # rasterize 91 | cameras = FoVPerspectiveCameras( 92 | device=device, 93 | fov=self.fov, 94 | znear=self.znear, 95 | zfar=self.zfar, 96 | ) 97 | 98 | raster_settings = RasterizationSettings( 99 | image_size=rsize 100 | ) 101 | 102 | # print(vertex.shape, tri.shape) 103 | mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0).repeat((vertex.shape[0],1,1))) 104 | 105 | fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings) 106 | rast_out = fragments.pix_to_face.squeeze(-1) 107 | depth = fragments.zbuf 108 | 109 | # render depth 110 | depth = depth.permute(0, 3, 1, 2) 111 | mask = (rast_out > 
0).float().unsqueeze(1) 112 | depth = mask * depth 113 | 114 | 115 | image = None 116 | if feat is not None: 117 | attributes = feat.reshape(-1,3)[mesh.faces_packed()] 118 | image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face, 119 | fragments.bary_coords, 120 | attributes) 121 | # print(image.shape) 122 | image = image.squeeze(-2).permute(0, 3, 1, 2) 123 | image = mask * image 124 | 125 | return mask, depth, image 126 | 127 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 4 | """ 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from abc import ABC, abstractmethod 11 | 12 | 13 | class BaseDataset(data.Dataset, ABC): 14 | """This class is an abstract base class (ABC) for datasets. 15 | 16 | To create a subclass, you need to implement the following four functions: 17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 18 | -- <__len__>: return the size of dataset. 19 | -- <__getitem__>: get a data point. 20 | -- : (optionally) add dataset-specific options and set default options. 21 | """ 22 | 23 | def __init__(self, opt): 24 | """Initialize the class; save the options in the class 25 | 26 | Parameters: 27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 28 | """ 29 | self.opt = opt 30 | # self.root = opt.dataroot 31 | self.current_epoch = 0 32 | 33 | @staticmethod 34 | def modify_commandline_options(parser, is_train): 35 | """Add new dataset-specific options, and rewrite default values for existing options. 36 | 37 | Parameters: 38 | parser -- original option parser 39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 40 | 41 | Returns: 42 | the modified parser. 43 | """ 44 | return parser 45 | 46 | @abstractmethod 47 | def __len__(self): 48 | """Return the total number of images in the dataset.""" 49 | return 0 50 | 51 | @abstractmethod 52 | def __getitem__(self, index): 53 | """Return a data point and its metadata information. 54 | 55 | Parameters: 56 | index - - a random integer for data indexing 57 | 58 | Returns: 59 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
60 | """ 61 | pass 62 | 63 | 64 | def get_transform(grayscale=False): 65 | transform_list = [] 66 | if grayscale: 67 | transform_list.append(transforms.Grayscale(1)) 68 | transform_list += [transforms.ToTensor()] 69 | return transforms.Compose(transform_list) 70 | 71 | def get_affine_mat(opt, size): 72 | shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False 73 | w, h = size 74 | 75 | if 'shift' in opt.preprocess: 76 | shift_pixs = int(opt.shift_pixs) 77 | shift_x = random.randint(-shift_pixs, shift_pixs) 78 | shift_y = random.randint(-shift_pixs, shift_pixs) 79 | if 'scale' in opt.preprocess: 80 | scale = 1 + opt.scale_delta * (2 * random.random() - 1) 81 | if 'rot' in opt.preprocess: 82 | rot_angle = opt.rot_angle * (2 * random.random() - 1) 83 | rot_rad = -rot_angle * np.pi/180 84 | if 'flip' in opt.preprocess: 85 | flip = random.random() > 0.5 86 | 87 | shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3]) 88 | flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3]) 89 | shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3]) 90 | rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3]) 91 | scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3]) 92 | shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3]) 93 | 94 | affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin 95 | affine_inv = np.linalg.inv(affine) 96 | return affine, affine_inv, flip 97 | 98 | def apply_img_affine(img, affine_inv, method=Image.BICUBIC): 99 | return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC) 100 | 101 | def apply_lm_affine(landmark, affine, flip, size): 102 | _, h = size 103 | lm = landmark.copy() 104 | lm[:, 1] = h - 1 - lm[:, 1] 105 | lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1) 106 | lm = lm @ np.transpose(affine) 107 | lm[:, :2] = lm[:, :2] / lm[:, 2:] 108 | lm = lm[:, :2] 109 | lm[:, 1] = h - 1 - lm[:, 1] 110 | if flip: 111 | lm_ = lm.copy() 112 | lm_[:17] = lm[16::-1] 113 | lm_[17:22] = lm[26:21:-1] 114 | lm_[22:27] = lm[21:16:-1] 115 | lm_[31:36] = lm[35:30:-1] 116 | lm_[36:40] = lm[45:41:-1] 117 | lm_[40:42] = lm[47:45:-1] 118 | lm_[42:46] = lm[39:35:-1] 119 | lm_[46:48] = lm[41:39:-1] 120 | lm_[48:55] = lm[54:47:-1] 121 | lm_[55:60] = lm[59:54:-1] 122 | lm_[60:65] = lm[64:59:-1] 123 | lm_[65:68] = lm[67:64:-1] 124 | lm = lm_ 125 | return lm 126 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/backbones/mobilefacenet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py 3 | Original author cavalleria 4 | ''' 5 | 6 | import torch.nn as nn 7 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module 8 | import torch 9 | 10 | 11 | class Flatten(Module): 12 | def forward(self, x): 13 | return x.view(x.size(0), -1) 14 | 15 | 16 | class ConvBlock(Module): 17 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 18 | super(ConvBlock, self).__init__() 19 | self.layers = nn.Sequential( 20 | Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), 21 | BatchNorm2d(num_features=out_c), 22 | PReLU(num_parameters=out_c) 23 | 
) 24 | 25 | def forward(self, x): 26 | return self.layers(x) 27 | 28 | 29 | class LinearBlock(Module): 30 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 31 | super(LinearBlock, self).__init__() 32 | self.layers = nn.Sequential( 33 | Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), 34 | BatchNorm2d(num_features=out_c) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.layers(x) 39 | 40 | 41 | class DepthWise(Module): 42 | def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): 43 | super(DepthWise, self).__init__() 44 | self.residual = residual 45 | self.layers = nn.Sequential( 46 | ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), 47 | ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), 48 | LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) 49 | ) 50 | 51 | def forward(self, x): 52 | short_cut = None 53 | if self.residual: 54 | short_cut = x 55 | x = self.layers(x) 56 | if self.residual: 57 | output = short_cut + x 58 | else: 59 | output = x 60 | return output 61 | 62 | 63 | class Residual(Module): 64 | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): 65 | super(Residual, self).__init__() 66 | modules = [] 67 | for _ in range(num_block): 68 | modules.append(DepthWise(c, c, True, kernel, stride, padding, groups)) 69 | self.layers = Sequential(*modules) 70 | 71 | def forward(self, x): 72 | return self.layers(x) 73 | 74 | 75 | class GDC(Module): 76 | def __init__(self, embedding_size): 77 | super(GDC, self).__init__() 78 | self.layers = nn.Sequential( 79 | LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), 80 | Flatten(), 81 | Linear(512, embedding_size, bias=False), 82 | BatchNorm1d(embedding_size)) 83 | 84 | def forward(self, x): 85 | return self.layers(x) 86 | 87 | 88 | class MobileFaceNet(Module): 89 | def __init__(self, fp16=False, num_features=512): 90 | super(MobileFaceNet, self).__init__() 91 | scale = 2 92 | self.fp16 = fp16 93 | self.layers = nn.Sequential( 94 | ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)), 95 | ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64), 96 | DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), 97 | Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 98 | DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), 99 | Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 100 | DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512), 101 | Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 102 | ) 103 | self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) 104 | self.features = GDC(num_features) 105 | self._initialize_weights() 106 | 107 | def _initialize_weights(self): 108 | for m in self.modules(): 109 | if isinstance(m, nn.Conv2d): 110 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 111 | if m.bias is not None: 112 | m.bias.data.zero_() 113 | elif isinstance(m, nn.BatchNorm2d): 114 | m.weight.data.fill_(1) 115 | m.bias.data.zero_() 116 | elif isinstance(m, nn.Linear): 117 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 
nonlinearity='relu') 118 | if m.bias is not None: 119 | m.bias.data.zero_() 120 | 121 | def forward(self, x): 122 | with torch.cuda.amp.autocast(self.fp16): 123 | x = self.layers(x) 124 | x = self.conv_sep(x.float() if self.fp16 else x) 125 | x = self.features(x) 126 | return x 127 | 128 | 129 | def get_mbf(fp16, num_features): 130 | return MobileFaceNet(fp16, num_features) -------------------------------------------------------------------------------- /nodes/SadTalkerNode.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import torch # type: ignore 5 | import numpy as np # type: ignore 6 | from PIL import Image # type: ignore 7 | import soundfile as sf # type: ignore 8 | from datetime import datetime 9 | import folder_paths # type: ignore 10 | 11 | 12 | from ..SadTalker.src.Pb import SadTalker 13 | 14 | 15 | class SadTalkerNode: 16 | 17 | def __init__(self): 18 | self.checkpoint_path = os.path.abspath( 19 | os.path.join(os.path.dirname(__file__), "..", "SadTalker", "checkpoints") 20 | ) 21 | self.config_path = os.path.abspath( 22 | os.path.join(os.path.dirname(__file__), "..", "SadTalker", "src", "config") 23 | ) 24 | self.output_dir = os.path.abspath( 25 | os.path.join(os.path.dirname(__file__), "..", "SadTalker", "output") 26 | ) 27 | self.comfy_output_dir = folder_paths.get_output_directory() 28 | self.input_dir = os.path.abspath( 29 | os.path.join(os.path.dirname(__file__), "..", "SadTalker", "input") 30 | ) 31 | os.makedirs(self.output_dir, exist_ok=True) 32 | os.makedirs(self.input_dir, exist_ok=True) 33 | os.makedirs(self.comfy_output_dir, exist_ok=True) 34 | 35 | # 实例化 SadTalker 36 | self.sad_talker = SadTalker( 37 | checkpoint_path=self.checkpoint_path, 38 | config_path=self.config_path, 39 | lazy_load=True, 40 | ) 41 | 42 | @classmethod 43 | def INPUT_TYPES(cls): 44 | return { 45 | "required": { 46 | "image": ("IMAGE",), 47 | "audio": ("AUDIO",), 48 | "poseStyle": ("INT", {"default": 0, "min": 0, "max": 46}), 49 | "faceModelResolution": (["256", "512"], {"default": "256"}), 50 | "preprocess": ( 51 | ["crop", "resize", "full", "extcrop", "extfull"], 52 | {"default": "crop"}, 53 | ), 54 | "stillMode": ("BOOLEAN", {"default": False}), 55 | "batchSizeInGeneration": ( 56 | "INT", 57 | {"default": 2, "min": 0, "max": 10}, 58 | ), 59 | "gfpganAsFaceEnhancer": ("BOOLEAN", {"default": False}), 60 | "useIdleMode": ("BOOLEAN", {"default": False}), 61 | "idleModeTime": ("INT", {"default": 5, "min": 1, "max": 90}), 62 | "useRefVideo": ("BOOLEAN", {"default": False}), 63 | "refInfo": ( 64 | ["pose", "blink", "pose+blink", "all"], 65 | {"default": "pose"}, 66 | ), 67 | }, 68 | "optional": { 69 | "refVideo": ("VIDEOSTRING",), 70 | }, 71 | } 72 | 73 | RETURN_TYPES = ( 74 | "STRING", 75 | "STRING", 76 | ) 77 | RETURN_NAMES = ( 78 | "video_path", 79 | "show_video_path", 80 | ) 81 | FUNCTION = "generate" 82 | CATEGORY = "SadTalker" 83 | 84 | def generate( 85 | self, 86 | image, 87 | audio, 88 | poseStyle, 89 | faceModelResolution, 90 | preprocess, 91 | stillMode, 92 | batchSizeInGeneration, 93 | gfpganAsFaceEnhancer, 94 | useIdleMode, 95 | idleModeTime, 96 | useRefVideo, 97 | refInfo, 98 | refVideo=None, 99 | ): 100 | 101 | # 生成时间戳目录 102 | timestamp = datetime.now().strftime("%Y%m%d%H%M%S") 103 | input_dir = os.path.join(self.input_dir, timestamp) 104 | os.makedirs(input_dir, exist_ok=True) 105 | 106 | # 保存图像 107 | if isinstance(image, torch.Tensor): 108 | image = image.cpu().numpy() 109 | if image.ndim == 4: 
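            # ComfyUI IMAGE tensors are float (batch, height, width, channel) in [0, 1];
            # keep only the first frame of the batch before converting to uint8 for saving.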
110 |             image = image[0]
111 |         if image.ndim == 3:
112 |             if image.shape[2] in {1, 3}:
113 |                 if image.shape[2] == 1:
114 |                     image = np.squeeze(image, axis=2)
115 |             else:
116 |                 raise ValueError("Unexpected number of channels in image tensor")
117 | 
118 |         image = (image * 255).astype(np.uint8)
119 |         image_path = os.path.join(input_dir, "input_image.png")
120 |         Image.fromarray(image).save(image_path)
121 | 
122 |         # Extract and save the driving audio
123 |         waveform = audio["waveform"].cpu().numpy().squeeze()
124 |         sample_rate = audio["sample_rate"]
125 |         if waveform.ndim == 2:
126 |             waveform = waveform.T
127 |         audio_path = os.path.join(input_dir, "input_audio.wav")
128 |         sf.write(audio_path, waveform, sample_rate)
129 | 
130 |         result = self.sad_talker.test(
131 |             source_image=image_path,
132 |             driven_audio=audio_path,
133 |             preprocess=preprocess,
134 |             still_mode=stillMode,
135 |             use_enhancer=gfpganAsFaceEnhancer,
136 |             batch_size=batchSizeInGeneration,
137 |             size=int(faceModelResolution),
138 |             pose_style=poseStyle,
139 |             exp_scale=1.0,
140 |             use_ref_video=useRefVideo,
141 |             ref_video=refVideo,
142 |             ref_info=refInfo,
143 |             use_idle_mode=useIdleMode,
144 |             length_of_audio=idleModeTime,
145 |             use_blink=True,
146 |             result_dir=self.output_dir,
147 |         )
148 | 
149 |         comfy_out = os.path.join(self.comfy_output_dir, timestamp + ".mp4")
150 |         shutil.copy(result, comfy_out)
151 | 
152 |         return (
153 |             result,
154 |             comfy_out,
155 |         )
156 | 
-------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/utils/utils_callbacks.py: --------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from typing import List
4 | 
5 | import torch
6 | 
7 | from eval import verification
8 | from utils.utils_logging import AverageMeter
9 | 
10 | 
11 | class CallBackVerification(object):
12 |     def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)):
13 |         self.frequent: int = frequent
14 |         self.rank: int = rank
15 |         self.highest_acc: float = 0.0
16 |         self.highest_acc_list: List[float] = [0.0] * len(val_targets)
17 |         self.ver_list: List[object] = []
18 |         self.ver_name_list: List[str] = []
19 |         if self.rank == 0:
20 |             self.init_dataset(
21 |                 val_targets=val_targets, data_dir=rec_prefix, image_size=image_size
22 |             )
23 | 
24 |     def ver_test(self, backbone: torch.nn.Module, global_step: int):
25 |         results = []
26 |         for i in range(len(self.ver_list)):
27 |             acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
28 |                 self.ver_list[i], backbone, 10, 10
29 |             )
30 |             if acc2 > self.highest_acc_list[i]:
31 |                 self.highest_acc_list[i] = acc2
32 |             results.append(acc2)
33 | 
34 |     def init_dataset(self, val_targets, data_dir, image_size):
35 |         for name in val_targets:
36 |             path = os.path.join(data_dir, name + ".bin")
37 |             if os.path.exists(path):
38 |                 data_set = verification.load_bin(path, image_size)
39 |                 self.ver_list.append(data_set)
40 |                 self.ver_name_list.append(name)
41 | 
42 |     def __call__(self, num_update, backbone: torch.nn.Module):
43 |         if self.rank == 0 and num_update > 0 and num_update % self.frequent == 0:
44 |             backbone.eval()
45 |             self.ver_test(backbone, num_update)
46 |             backbone.train()
47 | 
48 | 
49 | class CallBackLogging(object):
50 |     def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None):
51 |         self.frequent: int = frequent
52 |         self.rank: int = rank
53 |         self.time_start = time.time()
54 |         self.total_step: int = total_step
55 |         self.batch_size: int = batch_size
56 |         self.world_size: int
= world_size 57 | self.writer = writer 58 | 59 | self.init = False 60 | self.tic = 0 61 | 62 | def __call__( 63 | self, 64 | global_step: int, 65 | loss: AverageMeter, 66 | epoch: int, 67 | fp16: bool, 68 | learning_rate: float, 69 | grad_scaler: torch.cuda.amp.GradScaler, 70 | ): 71 | if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0: 72 | if self.init: 73 | try: 74 | speed: float = ( 75 | self.frequent * self.batch_size / (time.time() - self.tic) 76 | ) 77 | speed_total = speed * self.world_size 78 | except ZeroDivisionError: 79 | speed_total = float("inf") 80 | 81 | time_now = (time.time() - self.time_start) / 3600 82 | time_total = time_now / ((global_step + 1) / self.total_step) 83 | time_for_end = time_total - time_now 84 | if self.writer is not None: 85 | self.writer.add_scalar("time_for_end", time_for_end, global_step) 86 | self.writer.add_scalar("learning_rate", learning_rate, global_step) 87 | self.writer.add_scalar("loss", loss.avg, global_step) 88 | if fp16: 89 | msg = ( 90 | "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " 91 | "Fp16 Grad Scale: %2.f Required: %1.f hours" 92 | % ( 93 | speed_total, 94 | loss.avg, 95 | learning_rate, 96 | epoch, 97 | global_step, 98 | grad_scaler.get_scale(), 99 | time_for_end, 100 | ) 101 | ) 102 | else: 103 | msg = ( 104 | "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " 105 | "Required: %1.f hours" 106 | % ( 107 | speed_total, 108 | loss.avg, 109 | learning_rate, 110 | epoch, 111 | global_step, 112 | time_for_end, 113 | ) 114 | ) 115 | loss.reset() 116 | self.tic = time.time() 117 | else: 118 | self.init = True 119 | self.tic = time.time() 120 | 121 | 122 | class CallBackModelCheckpoint(object): 123 | def __init__(self, rank, output="./"): 124 | self.rank: int = rank 125 | self.output: str = output 126 | 127 | def __call__( 128 | self, 129 | global_step, 130 | backbone, 131 | partial_fc, 132 | ): 133 | if global_step > 100 and self.rank == 0: 134 | path_module = os.path.join(self.output, "backbone.pth") 135 | torch.save(backbone.module.state_dict(), path_module) 136 | 137 | if global_step > 100 and partial_fc is not None: 138 | partial_fc.save_params() 139 | -------------------------------------------------------------------------------- /SadTalker/src/face3d/models/arcface_torch/docs/speed_benchmark.md: -------------------------------------------------------------------------------- 1 | ## Test Training Speed 2 | 3 | - Test Commands 4 | 5 | You need to use the following two commands to test the Partial FC training performance. 6 | The number of identites is **3 millions** (synthetic data), turn mixed precision training on, backbone is resnet50, 7 | batch size is 1024. 
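(If you benchmark on fewer than 8 GPUs, scale `--nproc_per_node` in the commands below accordingly.)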
8 | ```shell
9 | # Model Parallel
10 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions
11 | # Partial FC 0.1
12 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc
13 | ```
14 | 
15 | - GPU Memory
16 | 
17 | ```
18 | # (Model Parallel) gpustat -i
19 | [0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB
20 | [1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB
21 | [2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB
22 | [3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB
23 | [4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB
24 | [5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB
25 | [6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB
26 | [7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB
27 | 
28 | # (Partial FC 0.1) gpustat -i
29 | [0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB
30 | [1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB
31 | [2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB
32 | [3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB
33 | [4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB
34 | [5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB
35 | [6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB
36 | [7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB
37 | ```
38 | 
39 | - Training Speed
40 | 
41 | ```
42 | # (Model Parallel) training.log
43 | Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100
44 | Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
45 | Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
46 | Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
47 | Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
48 | 
49 | # (Partial FC 0.1) training.log
50 | Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100
51 | Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
52 | Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
53 | Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
54 | Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
55 | ```
56 | 
57 | In this test case, Partial FC 0.1 uses only about 1/3 of the GPU memory required by model parallel,
58 | and trains roughly 2.5 times faster.
59 | 
60 | 
61 | ## Speed Benchmark
62 | 
63 | 1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better)
64 | 
65 | | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
66 | | :--- | :--- | :--- | :--- |
67 | |125000 | 4681 | 4824 | 5004 |
68 | |250000 | 4047 | 4521 | 4976 |
69 | |500000 | 3087 | 4013 | 4900 |
70 | |1000000 | 2090 | 3449 | 4803 |
71 | |1400000 | 1672 | 3043 | 4738 |
72 | |2000000 | - | 2593 | 4626 |
73 | |4000000 | - | 1748 | 4208 |
74 | |5500000 | - | 1389 | 3975 |
75 | |8000000 | - | - | 3565 |
76 | |16000000 | - | - | 2679 |
77 | |29000000 | - | - | 1855 |
78 | 
79 | 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
80 | 
81 | | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
82 | | :--- | :--- | :--- | :--- |
83 | |125000 | 7358 | 5306 | 4868 |
84 | |250000 | 9940 | 5826 | 5004 |
85 | |500000 | 14220 | 7114 | 5202 |
86 | |1000000 | 23708 | 9966 | 5620 |
87 | |1400000 | 32252 | 11178 | 6056 |
88 | |2000000 | - | 13978 | 6472 |
89 | |4000000 | - | 23238 | 8284 |
90 | |5500000 | - | 32188 | 9854 |
91 | |8000000 | - | - | 12310 |
92 | |16000000 | - | - | 19950 |
93 | |29000000 | - | - | 32324 |
94 | 
--------------------------------------------------------------------------------