├── SadTalker
├── custom
│ └── put custom
├── input
│ └── put input
├── output
│ └── put output
├── src
│ ├── face3d
│ │ ├── models
│ │ │ ├── arcface_torch
│ │ │ │ ├── docs
│ │ │ │ │ ├── modelzoo.md
│ │ │ │ │ ├── eval.md
│ │ │ │ │ ├── install.md
│ │ │ │ │ └── speed_benchmark.md
│ │ │ │ ├── eval
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── configs
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── speed.py
│ │ │ │ │ ├── 3millions.py
│ │ │ │ │ ├── 3millions_pfc.py
│ │ │ │ │ ├── glint360k_mbf.py
│ │ │ │ │ ├── glint360k_r100.py
│ │ │ │ │ ├── glint360k_r18.py
│ │ │ │ │ ├── glint360k_r34.py
│ │ │ │ │ ├── glint360k_r50.py
│ │ │ │ │ ├── ms1mv3_mbf.py
│ │ │ │ │ ├── ms1mv3_r18.py
│ │ │ │ │ ├── ms1mv3_r2060.py
│ │ │ │ │ ├── ms1mv3_r34.py
│ │ │ │ │ ├── ms1mv3_r50.py
│ │ │ │ │ └── base.py
│ │ │ │ ├── utils
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── utils_os.py
│ │ │ │ │ ├── utils_config.py
│ │ │ │ │ ├── utils_logging.py
│ │ │ │ │ ├── plot.py
│ │ │ │ │ ├── utils_amp.py
│ │ │ │ │ └── utils_callbacks.py
│ │ │ │ ├── requirement.txt
│ │ │ │ ├── run.sh
│ │ │ │ ├── backbones
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── mobilefacenet.py
│ │ │ │ ├── inference.py
│ │ │ │ ├── losses.py
│ │ │ │ ├── torch2onnx.py
│ │ │ │ └── dataset.py
│ │ │ ├── __init__.py
│ │ │ └── losses.py
│ │ ├── util
│ │ │ ├── BBRegressorParam_r.mat
│ │ │ ├── __init__.py
│ │ │ ├── generate_list.py
│ │ │ ├── html.py
│ │ │ ├── preprocess.py
│ │ │ ├── test_mean_face.txt
│ │ │ ├── detect_lm68.py
│ │ │ ├── load_mats.py
│ │ │ └── nvdiffrast.py
│ │ ├── options
│ │ │ ├── __init__.py
│ │ │ ├── test_options.py
│ │ │ ├── inference_options.py
│ │ │ └── train_options.py
│ │ ├── visualize.py
│ │ ├── data
│ │ │ ├── image_folder.py
│ │ │ ├── template_dataset.py
│ │ │ ├── flist_dataset.py
│ │ │ ├── __init__.py
│ │ │ └── base_dataset.py
│ │ └── extract_kp_videos.py
│ ├── config
│ │ ├── similarity_Lm3D_all.mat
│ │ ├── facerender.yaml
│ │ ├── facerender_still.yaml
│ │ ├── auido2pose.yaml
│ │ └── auido2exp.yaml
│ ├── utils
│ │ ├── safetensor_helper.py
│ │ ├── text2speech.py
│ │ ├── init_path.py
│ │ ├── paste_pic.py
│ │ ├── videoio.py
│ │ ├── face_enhancer.py
│ │ └── audio.py
│ ├── facerender
│ │ ├── sync_batchnorm
│ │ │ ├── __init__.py
│ │ │ ├── unittest.py
│ │ │ ├── replicate.py
│ │ │ └── comm.py
│ │ └── modules
│ │ │ ├── mapping.py
│ │ │ └── discriminator.py
│ ├── audio2exp_models
│ │ ├── audio2exp.py
│ │ └── networks.py
│ ├── audio2pose_models
│ │ ├── res_unet.py
│ │ ├── discriminator.py
│ │ ├── audio_encoder.py
│ │ ├── audio2pose.py
│ │ └── networks.py
│ └── generate_batch.py
├── requirements.txt
├── requirements3d.txt
├── req.txt
├── scripts
│ ├── test.sh
│ └── download_models.sh
├── cog.yaml
├── check_ffmpeg.py
└── .gitignore
├── examples
│ ├── qq.jpg
│ ├── result.mp4
│ └── workflow.png
├── requirements.txt
├── __init__.py
├── nodes
│ ├── ShowText.py
│ ├── ShowAudio.py
│ ├── LoadRefVideo.py
│ ├── ShowVideo.py
│ └── SadTalkerNode.py
├── README.md
└── web
│ └── js
│ │ ├── showVideo.js
│ │ └── showText.js
/SadTalker/custom/put custom:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/SadTalker/input/put input:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/SadTalker/output/put output:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/docs/modelzoo.md:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/eval/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/configs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/utils/utils_os.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/qq.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haomole/Comfyui-SadTalker/HEAD/examples/qq.jpg
--------------------------------------------------------------------------------
/examples/result.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haomole/Comfyui-SadTalker/HEAD/examples/result.mp4
--------------------------------------------------------------------------------
/examples/workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haomole/Comfyui-SadTalker/HEAD/examples/workflow.png
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/requirement.txt:
--------------------------------------------------------------------------------
1 | tensorboard
2 | easydict
3 | mxnet
4 | onnx
5 | scikit-learn
6 |
--------------------------------------------------------------------------------
/SadTalker/src/config/similarity_Lm3D_all.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haomole/Comfyui-SadTalker/HEAD/SadTalker/src/config/similarity_Lm3D_all.mat
--------------------------------------------------------------------------------
/SadTalker/src/face3d/util/BBRegressorParam_r.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haomole/Comfyui-SadTalker/HEAD/SadTalker/src/face3d/util/BBRegressorParam_r.mat
--------------------------------------------------------------------------------
/SadTalker/src/face3d/util/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 | from src.face3d.util import *
3 |
4 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
2 |
--------------------------------------------------------------------------------
/SadTalker/src/utils/safetensor_helper.py:
--------------------------------------------------------------------------------
1 | # Return the entries of `checkpoint` whose names contain `key`, with the "key." prefix stripped.
2 |
3 | def load_x_from_safetensor(checkpoint, key):
4 | x_generator = {}
5 | for k,v in checkpoint.items():
6 | if key in k:
7 | x_generator[k.replace(key+'.', '')] = v
8 | return x_generator
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/run.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
2 | ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh
3 |
--------------------------------------------------------------------------------
/SadTalker/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.23.4
2 | face_alignment==1.3.5
3 | imageio==2.19.3
4 | imageio-ffmpeg==0.4.7
5 | librosa==0.9.2 #
6 | numba
7 | resampy==0.3.1
8 | pydub==0.25.1
9 | scipy==1.10.1
10 | kornia==0.6.8
11 | tqdm
12 | yacs==0.1.8
13 | pyyaml
14 | joblib==1.1.0
15 | scikit-image==0.19.3
16 | basicsr==1.4.2
17 | facexlib==0.3.0
18 | gradio==3.50.2
19 | gfpgan
20 | av
21 | safetensors
--------------------------------------------------------------------------------
/SadTalker/requirements3d.txt:
--------------------------------------------------------------------------------
1 | numpy==1.23.4
2 | face_alignment==1.3.5
3 | imageio==2.19.3
4 | imageio-ffmpeg==0.4.7
5 | librosa==0.9.2 #
6 | numba
7 | resampy==0.3.1
8 | pydub==0.25.1
9 | scipy==1.5.3
10 | kornia==0.6.8
11 | tqdm
12 | yacs==0.1.8
13 | pyyaml
14 | joblib==1.1.0
15 | scikit-image==0.19.3
16 | basicsr==1.4.2
17 | facexlib==0.3.0
18 | trimesh==3.9.20
19 | gradio
20 | gfpgan
21 | safetensors
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.23.4
2 | face_alignment==1.3.5
3 | imageio==2.19.3
4 | imageio-ffmpeg==0.4.7
5 | librosa==0.9.2 #
6 | numba
7 | resampy==0.3.1
8 | pydub==0.25.1
9 | scipy==1.10.1
10 | kornia==0.6.8
11 | tqdm
12 | yacs==0.1.8
13 | pyyaml
14 | joblib==1.1.0
15 | scikit-image==0.19.3
16 | basicsr==1.4.2
17 | facexlib==0.3.0
18 | gradio==3.50.2
19 | gfpgan
20 | av
21 | safetensors
22 | moviepy
--------------------------------------------------------------------------------
/SadTalker/req.txt:
--------------------------------------------------------------------------------
1 | llvmlite==0.38.1
2 | numpy==1.21.6
3 | face_alignment==1.3.5
4 | imageio==2.19.3
5 | imageio-ffmpeg==0.4.7
6 | librosa==0.10.0.post2
7 | numba==0.55.1
8 | resampy==0.3.1
9 | pydub==0.25.1
10 | scipy==1.10.1
11 | kornia==0.6.8
12 | tqdm
13 | yacs==0.1.8
14 | pyyaml
15 | joblib==1.1.0
16 | scikit-image==0.19.3
17 | basicsr==1.4.2
18 | facexlib==0.3.0
19 | gradio
20 | gfpgan
21 | av
22 | safetensors
23 |
--------------------------------------------------------------------------------
/SadTalker/src/facerender/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/SadTalker/src/utils/text2speech.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | from TTS.api import TTS
4 |
5 |
6 | class TTSTalker():
7 | def __init__(self) -> None:
8 | model_name = TTS().list_models()[0]
9 | self.tts = TTS(model_name)
10 |
11 | def test(self, text, language='en'):
12 |
13 | tempf = tempfile.NamedTemporaryFile(
14 | delete = False,
15 | suffix = ('.'+'wav'),
16 | )
17 |
18 | self.tts.tts_to_file(text, speaker=self.tts.speakers[0], language=language, file_path=tempf.name)
19 |
20 | return tempf.name
21 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/configs/speed.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # configs for test speed
4 |
5 | config = edict()
6 | config.loss = "arcface"
7 | config.network = "r50"
8 | config.resume = False
9 | config.output = None
10 | config.embedding_size = 512
11 | config.sample_rate = 1.0
12 | config.fp16 = True
13 | config.momentum = 0.9
14 | config.weight_decay = 5e-4
15 | config.batch_size = 128
16 | config.lr = 0.1 # batch size is 512
17 |
18 | config.rec = "synthetic"
19 | config.num_classes = 100 * 10000
20 | config.num_epoch = 30
21 | config.warmup_epoch = -1
22 | config.decay_epoch = [10, 16, 22]
23 | config.val_targets = []
24 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/configs/3millions.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # configs for test speed
4 |
5 | config = edict()
6 | config.loss = "arcface"
7 | config.network = "r50"
8 | config.resume = False
9 | config.output = None
10 | config.embedding_size = 512
11 | config.sample_rate = 1.0
12 | config.fp16 = True
13 | config.momentum = 0.9
14 | config.weight_decay = 5e-4
15 | config.batch_size = 128
16 | config.lr = 0.1 # batch size is 512
17 |
18 | config.rec = "synthetic"
19 | config.num_classes = 300 * 10000
20 | config.num_epoch = 30
21 | config.warmup_epoch = -1
22 | config.decay_epoch = [10, 16, 22]
23 | config.val_targets = []
24 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/configs/3millions_pfc.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # configs for test speed
4 |
5 | config = edict()
6 | config.loss = "arcface"
7 | config.network = "r50"
8 | config.resume = False
9 | config.output = None
10 | config.embedding_size = 512
11 | config.sample_rate = 0.1
12 | config.fp16 = True
13 | config.momentum = 0.9
14 | config.weight_decay = 5e-4
15 | config.batch_size = 128
16 | config.lr = 0.1 # batch size is 512
17 |
18 | config.rec = "synthetic"
19 | config.num_classes = 300 * 10000
20 | config.num_epoch = 30
21 | config.warmup_epoch = -1
22 | config.decay_epoch = [10, 16, 22]
23 | config.val_targets = []
24 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/utils/utils_config.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os.path as osp
3 |
4 |
5 | def get_config(config_file):
6 | assert config_file.startswith('configs/'), 'config file setting must start with configs/'
7 | temp_config_name = osp.basename(config_file)
8 | temp_module_name = osp.splitext(temp_config_name)[0]
9 | config = importlib.import_module("configs.base")
10 | cfg = config.config
11 | config = importlib.import_module("configs.%s" % temp_module_name)
12 | job_cfg = config.config
13 | cfg.update(job_cfg)
14 | if cfg.output is None:
15 | cfg.output = osp.join('work_dirs', temp_module_name)
16 | return cfg
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/configs/glint360k_mbf.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.loss = "cosface"
9 | config.network = "mbf"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.1
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 2e-4
17 | config.batch_size = 128
18 | config.lr = 0.1 # batch size is 512
19 |
20 | config.rec = "/train_tmp/glint360k"
21 | config.num_classes = 360232
22 | config.num_image = 17091657
23 | config.num_epoch = 20
24 | config.warmup_epoch = -1
25 | config.decay_epoch = [8, 12, 15, 18]
26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
27 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/configs/glint360k_r100.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.loss = "cosface"
9 | config.network = "r100"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.1 # batch size is 512
19 |
20 | config.rec = "/train_tmp/glint360k"
21 | config.num_classes = 360232
22 | config.num_image = 17091657
23 | config.num_epoch = 20
24 | config.warmup_epoch = -1
25 | config.decay_epoch = [8, 12, 15, 18]
26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
27 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/configs/glint360k_r18.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.loss = "cosface"
9 | config.network = "r18"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.1 # batch size is 512
19 |
20 | config.rec = "/train_tmp/glint360k"
21 | config.num_classes = 360232
22 | config.num_image = 17091657
23 | config.num_epoch = 20
24 | config.warmup_epoch = -1
25 | config.decay_epoch = [8, 12, 15, 18]
26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
27 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/configs/glint360k_r34.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.loss = "cosface"
9 | config.network = "r34"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.1 # batch size is 512
19 |
20 | config.rec = "/train_tmp/glint360k"
21 | config.num_classes = 360232
22 | config.num_image = 17091657
23 | config.num_epoch = 20
24 | config.warmup_epoch = -1
25 | config.decay_epoch = [8, 12, 15, 18]
26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
27 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/configs/glint360k_r50.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.loss = "cosface"
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.1 # batch size is 512
19 |
20 | config.rec = "/train_tmp/glint360k"
21 | config.num_classes = 360232
22 | config.num_image = 17091657
23 | config.num_epoch = 20
24 | config.warmup_epoch = -1
25 | config.decay_epoch = [8, 12, 15, 18]
26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
27 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.loss = "arcface"
9 | config.network = "mbf"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 2e-4
17 | config.batch_size = 128
18 | config.lr = 0.1 # batch size is 512
19 |
20 | config.rec = "/train_tmp/ms1m-retinaface-t1"
21 | config.num_classes = 93431
22 | config.num_image = 5179510
23 | config.num_epoch = 30
24 | config.warmup_epoch = -1
25 | config.decay_epoch = [10, 20, 25]
26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
27 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.loss = "arcface"
9 | config.network = "r18"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.1 # batch size is 512
19 |
20 | config.rec = "/train_tmp/ms1m-retinaface-t1"
21 | config.num_classes = 93431
22 | config.num_image = 5179510
23 | config.num_epoch = 25
24 | config.warmup_epoch = -1
25 | config.decay_epoch = [10, 16, 22]
26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
27 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.loss = "arcface"
9 | config.network = "r2060"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 64
18 | config.lr = 0.1 # batch size is 512
19 |
20 | config.rec = "/train_tmp/ms1m-retinaface-t1"
21 | config.num_classes = 93431
22 | config.num_image = 5179510
23 | config.num_epoch = 25
24 | config.warmup_epoch = -1
25 | config.decay_epoch = [10, 16, 22]
26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
27 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.loss = "arcface"
9 | config.network = "r34"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.1 # batch size is 512
19 |
20 | config.rec = "/train_tmp/ms1m-retinaface-t1"
21 | config.num_classes = 93431
22 | config.num_image = 5179510
23 | config.num_epoch = 25
24 | config.warmup_epoch = -1
25 | config.decay_epoch = [10, 16, 22]
26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
27 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.loss = "arcface"
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.1 # batch size is 512
19 |
20 | config.rec = "/train_tmp/ms1m-retinaface-t1"
21 | config.num_classes = 93431
22 | config.num_image = 5179510
23 | config.num_epoch = 25
24 | config.warmup_epoch = -1
25 | config.decay_epoch = [10, 16, 22]
26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
27 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/docs/eval.md:
--------------------------------------------------------------------------------
1 | ## Eval on ICCV2021-MFR
2 |
3 | coming soon.
4 |
5 |
6 | ## Eval IJBC
7 | You can evaluate IJB-C with either PyTorch or ONNX.
8 |
9 |
10 | 1. Eval IJB-C with ONNX
11 | ```shell
12 | CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50
13 | ```
14 |
15 | 2. Eval IJB-C with PyTorch
16 | ```shell
17 | CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \
18 | --model-prefix ms1mv3_arcface_r50/backbone.pth \
19 | --image-path IJB_release/IJBC \
20 | --result-dir ms1mv3_arcface_r50 \
21 | --batch-size 128 \
22 | --job ms1mv3_arcface_r50 \
23 | --target IJBC \
24 | --network iresnet50
25 | ```
26 |
27 | ## Inference
28 |
29 | ```shell
30 | python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50
31 | ```
32 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/options/test_options.py:
--------------------------------------------------------------------------------
1 | """This script contains the test options for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | from .base_options import BaseOptions
5 |
6 |
7 | class TestOptions(BaseOptions):
8 | """This class includes test options.
9 |
10 | It also includes shared options defined in BaseOptions.
11 | """
12 |
13 | def initialize(self, parser):
14 | parser = BaseOptions.initialize(self, parser) # define shared options
15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
16 | parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')
17 | parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.')
18 |
19 | # Dropout and Batchnorm have different behavior during training and test.
20 | self.isTrain = False
21 | return parser
22 |
--------------------------------------------------------------------------------
/SadTalker/scripts/test.sh:
--------------------------------------------------------------------------------
1 | # ### some test command before commit.
2 | # python inference.py --preprocess crop --size 256
3 | # python inference.py --preprocess crop --size 512
4 |
5 | # python inference.py --preprocess extcrop --size 256
6 | # python inference.py --preprocess extcrop --size 512
7 |
8 | # python inference.py --preprocess resize --size 256
9 | # python inference.py --preprocess resize --size 512
10 |
11 | # python inference.py --preprocess full --size 256
12 | # python inference.py --preprocess full --size 512
13 |
14 | # python inference.py --preprocess extfull --size 256
15 | # python inference.py --preprocess extfull --size 512
16 |
17 | python inference.py --preprocess full --size 256 --enhancer gfpgan
18 | python inference.py --preprocess full --size 512 --enhancer gfpgan
19 |
20 | python inference.py --preprocess full --size 256 --enhancer gfpgan --still
21 | python inference.py --preprocess full --size 512 --enhancer gfpgan --still
22 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
2 | from .mobilefacenet import get_mbf
3 |
4 |
5 | def get_model(name, **kwargs):
6 | # resnet
7 | if name == "r18":
8 | return iresnet18(False, **kwargs)
9 | elif name == "r34":
10 | return iresnet34(False, **kwargs)
11 | elif name == "r50":
12 | return iresnet50(False, **kwargs)
13 | elif name == "r100":
14 | return iresnet100(False, **kwargs)
15 | elif name == "r200":
16 | return iresnet200(False, **kwargs)
17 | elif name == "r2060":
18 | from .iresnet2060 import iresnet2060
19 | return iresnet2060(False, **kwargs)
20 | elif name == "mbf":
21 | fp16 = kwargs.get("fp16", False)
22 | num_features = kwargs.get("num_features", 512)
23 | return get_mbf(fp16=fp16, num_features=num_features)
24 | else:
25 | raise ValueError()
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os, sys
3 |
4 | SadTalkerPath = os.path.abspath(os.path.join(os.path.dirname(__file__), "SadTalker"))
5 | sys.path.append(SadTalkerPath)
6 |
7 | from .nodes.ShowText import ShowText
8 | from .nodes.ShowVideo import ShowVideo
9 | from .nodes.ShowAudio import ShowAudio
10 | from .nodes.LoadRefVideo import LoadRefVideo
11 | from .nodes.SadTalkerNode import SadTalkerNode
12 |
13 | WEB_DIRECTORY = "./web"
14 |
15 | NODE_CLASS_MAPPINGS = {
16 | "SadTalker": SadTalkerNode,
17 | "ShowVideo": ShowVideo,
18 | "ShowText": ShowText,
19 | "ShowAudio": ShowAudio,
20 | "LoadRefVideo": LoadRefVideo,
21 | }
22 |
23 | NODE_DISPLAY_NAME_MAPPINGS = {
24 | "SadTalker": "🦚 SadTalker",
25 | "ShowVideo": "🎥 Show Video",
26 | "LoadRefVideo": "🎥 Load Ref Video",
27 | "ShowText": "💬 Show Text",
28 | "ShowAudio": "🔊 Show Audio",
29 | }
30 |
31 |
32 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
33 |
--------------------------------------------------------------------------------
/SadTalker/src/facerender/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 | np.allclose(npa, npb, atol=atol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
/nodes/ShowText.py:
--------------------------------------------------------------------------------
1 | class ShowText:
2 | @classmethod
3 | def INPUT_TYPES(s):
4 | return {
5 | "required": {
6 | "text": ("STRING", {"forceInput": True}),
7 | },
8 | "hidden": {
9 | "unique_id": "UNIQUE_ID",
10 | "extra_pnginfo": "EXTRA_PNGINFO",
11 | },
12 | }
13 |
14 | INPUT_IS_LIST = True
15 | RETURN_TYPES = ("STRING",)
16 | FUNCTION = "notify"
17 | OUTPUT_NODE = True
18 | OUTPUT_IS_LIST = (True,)
19 |
20 | CATEGORY = "SadTalker"
21 |
22 | def notify(self, text, unique_id=None, extra_pnginfo=None):
23 | if unique_id and extra_pnginfo and "workflow" in extra_pnginfo[0]:
24 | workflow = extra_pnginfo[0]["workflow"]
25 | node = next(
26 | (x for x in workflow["nodes"] if str(x["id"]) == unique_id[0]), None
27 | )
28 | if node:
29 | node["widgets_values"] = [text]
30 | return {"ui": {"text": text}, "result": (text,)}
31 |
--------------------------------------------------------------------------------
/nodes/ShowAudio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torchaudio # type: ignore
3 | import folder_paths # type: ignore
4 |
5 |
6 | class ShowAudio:
7 | SUPPORTED_FORMATS = (".wav", ".mp3", ".ogg", ".flac", ".aiff", ".aif")
8 |
9 | @classmethod
10 | def INPUT_TYPES(s):
11 | input_dir = folder_paths.get_input_directory()
12 | files = [
13 | f
14 | for f in os.listdir(input_dir)
15 | if (
16 | os.path.isfile(os.path.join(input_dir, f))
17 | and f.endswith(ShowAudio.SUPPORTED_FORMATS)
18 | )
19 | ]
20 | return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
21 |
22 | CATEGORY = "SadTalker"
23 |
24 | RETURN_TYPES = ("AUDIO",)
25 | FUNCTION = "load"
26 |
27 | def load(self, audio):
28 | audio_path = folder_paths.get_annotated_filepath(audio)
29 | waveform, sample_rate = torchaudio.load(audio_path)
30 | audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
31 | return (audio,)
32 |
--------------------------------------------------------------------------------
/nodes/LoadRefVideo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import folder_paths # type: ignore
3 | try:
4 | from moviepy import VideoFileClip # type: ignore
5 | except ImportError:
6 | from moviepy.editor import VideoFileClip
7 |
8 | input_path = folder_paths.get_input_directory()
9 |
10 |
11 | class LoadRefVideo:
12 | @classmethod
13 | def INPUT_TYPES(s):
14 | files = [
15 | f
16 | for f in os.listdir(input_path)
17 | if os.path.isfile(os.path.join(input_path, f))
18 | and f.split(".")[-1] in ["mp4", "webm", "mkv", "avi"]
19 | ]
20 | return {
21 | "required": {
22 | "video": (files,),
23 | }
24 | }
25 |
26 | CATEGORY = "SadTalker"
27 | RETURN_TYPES = ("VIDEO", "VIDEOSTRING")
28 | RETURN_NAMES = ("video", "video_path")
29 | OUTPUT_NODE = False
30 | FUNCTION = "load_video"
31 |
32 | def load_video(self, video):
33 | video_path = os.path.join(input_path, video)
34 | video_clip = VideoFileClip(video_path)
35 | return (
36 | video_clip,
37 | video_path,
38 | )
39 |
--------------------------------------------------------------------------------
/nodes/ShowVideo.py:
--------------------------------------------------------------------------------
1 | class ShowVideo:
2 | @classmethod
3 | def INPUT_TYPES(s):
4 | return {
5 | "required": {
6 | "show_video_path": ("STRING", {"forceInput": True}),
7 | },
8 | "hidden": {
9 | "unique_id": "UNIQUE_ID",
10 | "extra_pnginfo": "EXTRA_PNGINFO",
11 | },
12 | }
13 |
14 | OUTPUT_NODE = True
15 | INPUT_IS_LIST = True
16 | RETURN_TYPES = ()
17 | OUTPUT_IS_LIST = (True,)
18 |
19 | FUNCTION = "generate"
20 | CATEGORY = "SadTalker"
21 |
22 | def generate(self, show_video_path, unique_id=None, extra_pnginfo=None):
23 | if unique_id and extra_pnginfo and "workflow" in extra_pnginfo[0]:
24 | workflow = extra_pnginfo[0]["workflow"]
25 | node = next(
26 | (x for x in workflow["nodes"] if str(x["id"]) == unique_id[0]), None
27 | )
28 | if node:
29 | node["widgets_values"] = [show_video_path]
30 |
31 | # return the video path
32 | return {
33 | "ui": {"show_video_path": show_video_path},
34 | "result": (show_video_path,),
35 | }
36 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 |
7 | from backbones import get_model
8 |
9 |
10 | @torch.no_grad()
11 | def inference(weight, name, img):
12 | if img is None:
13 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8)
14 | else:
15 | img = cv2.imread(img)
16 | img = cv2.resize(img, (112, 112))
17 |
18 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
19 | img = np.transpose(img, (2, 0, 1))
20 | img = torch.from_numpy(img).unsqueeze(0).float()
21 | img.div_(255).sub_(0.5).div_(0.5)
22 | net = get_model(name, fp16=False)
23 | net.load_state_dict(torch.load(weight))
24 | net.eval()
25 | feat = net(img).numpy()
26 | print(feat)
27 |
28 |
29 | if __name__ == "__main__":
30 | parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
31 | parser.add_argument('--network', type=str, default='r50', help='backbone network')
32 | parser.add_argument('--weight', type=str, default='')
33 | parser.add_argument('--img', type=str, default=None)
34 | args = parser.parse_args()
35 | inference(args.weight, args.network, args.img)
36 |
--------------------------------------------------------------------------------
/SadTalker/cog.yaml:
--------------------------------------------------------------------------------
1 | build:
2 | gpu: true
3 | cuda: "11.3"
4 | python_version: "3.8"
5 | system_packages:
6 | - "ffmpeg"
7 | - "libgl1-mesa-glx"
8 | - "libglib2.0-0"
9 | python_packages:
10 | - "torch==1.12.1"
11 | - "torchvision==0.13.1"
12 | - "torchaudio==0.12.1"
13 | - "joblib==1.1.0"
14 | - "scikit-image==0.19.3"
15 | - "basicsr==1.4.2"
16 | - "facexlib==0.3.0"
17 | - "resampy==0.3.1"
18 | - "pydub==0.25.1"
19 | - "scipy==1.10.1"
20 | - "kornia==0.6.8"
21 | - "face_alignment==1.3.5"
22 | - "imageio==2.19.3"
23 | - "imageio-ffmpeg==0.4.7"
24 | - "librosa==0.9.2" #
25 | - "tqdm==4.65.0"
26 | - "yacs==0.1.8"
27 | - "gfpgan==1.3.8"
28 | - "dlib-bin==19.24.1"
29 | - "av==10.0.0"
30 | - "trimesh==3.9.20"
31 | run:
32 | - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/s3fd-619a316812.pth" "https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth"
33 | - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/2DFAN4-cd938726ad.zip" "https://www.adrianbulat.com/downloads/python-fan/2DFAN4-cd938726ad.zip"
34 |
35 | predict: "predict.py:Predictor"
36 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/options/inference_options.py:
--------------------------------------------------------------------------------
1 | from face3d.options.base_options import BaseOptions
2 |
3 |
4 | class InferenceOptions(BaseOptions):
5 | """This class includes test options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser) # define shared options
12 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
13 | parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')
14 |
15 | parser.add_argument('--input_dir', type=str, help='the folder of the input files')
16 | parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files')
17 | parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients')
18 | parser.add_argument('--save_split_files', action='store_true', help='save split files or not')
19 | parser.add_argument('--inference_batch_size', type=int, default=8)
20 |
21 | # Dropout and Batchnorm have different behavior during training and test.
22 | self.isTrain = False
23 | return parser
24 |
--------------------------------------------------------------------------------
/SadTalker/src/config/facerender.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 | common_params:
3 | num_kp: 15
4 | image_channel: 3
5 | feature_channel: 32
6 | estimate_jacobian: False # True
7 | kp_detector_params:
8 | temperature: 0.1
9 | block_expansion: 32
10 | max_features: 1024
11 | scale_factor: 0.25 # 0.25
12 | num_blocks: 5
13 | reshape_channel: 16384 # 16384 = 1024 * 16
14 | reshape_depth: 16
15 | he_estimator_params:
16 | block_expansion: 64
17 | max_features: 2048
18 | num_bins: 66
19 | generator_params:
20 | block_expansion: 64
21 | max_features: 512
22 | num_down_blocks: 2
23 | reshape_channel: 32
24 | reshape_depth: 16 # 512 = 32 * 16
25 | num_resblocks: 6
26 | estimate_occlusion_map: True
27 | dense_motion_params:
28 | block_expansion: 32
29 | max_features: 1024
30 | num_blocks: 5
31 | reshape_depth: 16
32 | compress: 4
33 | discriminator_params:
34 | scales: [1]
35 | block_expansion: 32
36 | max_features: 512
37 | num_blocks: 4
38 | sn: True
39 | mapping_params:
40 | coeff_nc: 70
41 | descriptor_nc: 1024
42 | layer: 3
43 | num_kp: 15
44 | num_bins: 66
45 |
46 |
--------------------------------------------------------------------------------
/SadTalker/src/config/facerender_still.yaml:
--------------------------------------------------------------------------------
1 | model_params:
2 | common_params:
3 | num_kp: 15
4 | image_channel: 3
5 | feature_channel: 32
6 | estimate_jacobian: False # True
7 | kp_detector_params:
8 | temperature: 0.1
9 | block_expansion: 32
10 | max_features: 1024
11 | scale_factor: 0.25 # 0.25
12 | num_blocks: 5
13 | reshape_channel: 16384 # 16384 = 1024 * 16
14 | reshape_depth: 16
15 | he_estimator_params:
16 | block_expansion: 64
17 | max_features: 2048
18 | num_bins: 66
19 | generator_params:
20 | block_expansion: 64
21 | max_features: 512
22 | num_down_blocks: 2
23 | reshape_channel: 32
24 | reshape_depth: 16 # 512 = 32 * 16
25 | num_resblocks: 6
26 | estimate_occlusion_map: True
27 | dense_motion_params:
28 | block_expansion: 32
29 | max_features: 1024
30 | num_blocks: 5
31 | reshape_depth: 16
32 | compress: 4
33 | discriminator_params:
34 | scales: [1]
35 | block_expansion: 32
36 | max_features: 512
37 | num_blocks: 4
38 | sn: True
39 | mapping_params:
40 | coeff_nc: 73
41 | descriptor_nc: 1024
42 | layer: 3
43 | num_kp: 15
44 | num_bins: 66
45 |
46 |
--------------------------------------------------------------------------------
/SadTalker/src/config/auido2pose.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
3 | EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
4 | TRAIN_BATCH_SIZE: 64
5 | EVAL_BATCH_SIZE: 1
6 | EXP: True
7 | EXP_DIM: 64
8 | FRAME_LEN: 32
9 | COEFF_LEN: 73
10 | NUM_CLASSES: 46
11 | AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12 | COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
13 | DEBUG: True
14 |
15 |
16 | MODEL:
17 | AUDIOENCODER:
18 | LEAKY_RELU: True
19 | NORM: 'IN'
20 | DISCRIMINATOR:
21 | LEAKY_RELU: False
22 | INPUT_CHANNELS: 6
23 | CVAE:
24 | AUDIO_EMB_IN_SIZE: 512
25 | AUDIO_EMB_OUT_SIZE: 6
26 | SEQ_LEN: 32
27 | LATENT_SIZE: 64
28 | ENCODER_LAYER_SIZES: [192, 128]
29 | DECODER_LAYER_SIZES: [128, 192]
30 |
31 |
32 | TRAIN:
33 | MAX_EPOCH: 150
34 | GENERATOR:
35 | LR: 1.0e-4
36 | DISCRIMINATOR:
37 | LR: 1.0e-4
38 | LOSS:
39 | LAMBDA_REG: 1
40 | LAMBDA_LANDMARKS: 0
41 | LAMBDA_VERTICES: 0
42 | LAMBDA_GAN_MOTION: 0.7
43 | LAMBDA_GAN_COEFF: 0
44 | LAMBDA_KL: 1
45 |
46 | TAG:
47 | NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
48 |
49 |
50 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/utils/utils_logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 |
5 |
6 | class AverageMeter(object):
7 | """Computes and stores the average and current value
8 | """
9 |
10 | def __init__(self):
11 | self.val = None
12 | self.avg = None
13 | self.sum = None
14 | self.count = None
15 | self.reset()
16 |
17 | def reset(self):
18 | self.val = 0
19 | self.avg = 0
20 | self.sum = 0
21 | self.count = 0
22 |
23 | def update(self, val, n=1):
24 | self.val = val
25 | self.sum += val * n
26 | self.count += n
27 | self.avg = self.sum / self.count
28 |
29 |
30 | def init_logging(rank, models_root):
31 | if rank == 0:
32 | log_root = logging.getLogger()
33 | log_root.setLevel(logging.INFO)
34 | formatter = logging.Formatter("Training: %(asctime)s-%(message)s")
35 | handler_file = logging.FileHandler(os.path.join(models_root, "training.log"))
36 | handler_stream = logging.StreamHandler(sys.stdout)
37 | handler_file.setFormatter(formatter)
38 | handler_stream.setFormatter(formatter)
39 | log_root.addHandler(handler_file)
40 | log_root.addHandler(handler_stream)
41 | log_root.info('rank_id: %d' % rank)
42 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | def get_loss(name):
6 | if name == "cosface":
7 | return CosFace()
8 | elif name == "arcface":
9 | return ArcFace()
10 | else:
11 | raise ValueError()
12 |
13 |
14 | class CosFace(nn.Module):
15 | def __init__(self, s=64.0, m=0.40):
16 | super(CosFace, self).__init__()
17 | self.s = s
18 | self.m = m
19 |
20 | def forward(self, cosine, label):
21 | index = torch.where(label != -1)[0]
22 | m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
23 | m_hot.scatter_(1, label[index, None], self.m)
24 | cosine[index] -= m_hot
25 | ret = cosine * self.s
26 | return ret
27 |
28 |
29 | class ArcFace(nn.Module):
30 | def __init__(self, s=64.0, m=0.5):
31 | super(ArcFace, self).__init__()
32 | self.s = s
33 | self.m = m
34 |
35 | def forward(self, cosine: torch.Tensor, label):
36 | index = torch.where(label != -1)[0]
37 | m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
38 | m_hot.scatter_(1, label[index, None], self.m)
39 | cosine.acos_()
40 | cosine[index] += m_hot
41 | cosine.cos_().mul_(self.s)
42 | return cosine
43 |
--------------------------------------------------------------------------------
/SadTalker/src/audio2exp_models/audio2exp.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import torch
3 | from torch import nn
4 |
5 |
6 | class Audio2Exp(nn.Module):
7 | def __init__(self, netG, cfg, device, prepare_training_loss=False):
8 | super(Audio2Exp, self).__init__()
9 | self.cfg = cfg
10 | self.device = device
11 | self.netG = netG.to(device)
12 |
13 | def test(self, batch):
14 |
15 | mel_input = batch['indiv_mels'] # bs T 1 80 16
16 | bs = mel_input.shape[0]
17 | T = mel_input.shape[1]
18 |
19 | exp_coeff_pred = []
20 |
21 | for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
22 |
23 | current_mel_input = mel_input[:,i:i+10]
24 |
25 | #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
26 | ref = batch['ref'][:, :, :64][:, i:i+10]
27 | ratio = batch['ratio_gt'][:, i:i+10] #bs T
28 |
29 | audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
30 |
31 | curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
32 |
33 | exp_coeff_pred += [curr_exp_coeff_pred]
34 |
35 | # BS x T x 64
36 | results_dict = {
37 | 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
38 | }
39 | return results_dict
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SadTalker/src/config/auido2exp.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
3 | EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
4 | TRAIN_BATCH_SIZE: 32
5 | EVAL_BATCH_SIZE: 32
6 | EXP: True
7 | EXP_DIM: 64
8 | FRAME_LEN: 32
9 | COEFF_LEN: 73
10 | NUM_CLASSES: 46
11 | AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12 | COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
13 | LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
14 | DEBUG: True
15 | NUM_REPEATS: 2
16 | T: 40
17 |
18 |
19 | MODEL:
20 | FRAMEWORK: V2
21 | AUDIOENCODER:
22 | LEAKY_RELU: True
23 | NORM: 'IN'
24 | DISCRIMINATOR:
25 | LEAKY_RELU: False
26 | INPUT_CHANNELS: 6
27 | CVAE:
28 | AUDIO_EMB_IN_SIZE: 512
29 | AUDIO_EMB_OUT_SIZE: 128
30 | SEQ_LEN: 32
31 | LATENT_SIZE: 256
32 | ENCODER_LAYER_SIZES: [192, 1024]
33 | DECODER_LAYER_SIZES: [1024, 192]
34 |
35 |
36 | TRAIN:
37 | MAX_EPOCH: 300
38 | GENERATOR:
39 | LR: 2.0e-5
40 | DISCRIMINATOR:
41 | LR: 1.0e-5
42 | LOSS:
43 | W_FEAT: 0
44 | W_COEFF_EXP: 2
45 | W_LM: 1.0e-2
46 | W_LM_MOUTH: 0
47 | W_REG: 0
48 | W_SYNC: 0
49 | W_COLOR: 0
50 | W_EXPRESSION: 0
51 | W_LIPREADING: 0.01
52 | W_LIPREADING_VV: 0
53 | W_EYE_BLINK: 4
54 |
55 | TAG:
56 | NAME: small_dataset
57 |
58 |
59 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/util/generate_list.py:
--------------------------------------------------------------------------------
1 | """This script is to generate training list files for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import os
5 |
6 | # save path to training data
7 | def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''):
8 | save_path = os.path.join(save_folder, mode)
9 | if not os.path.isdir(save_path):
10 | os.makedirs(save_path)
11 | with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd:
12 | fd.writelines([i + '\n' for i in lms_list])
13 |
14 | with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd:
15 | fd.writelines([i + '\n' for i in imgs_list])
16 |
17 | with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd:
18 | fd.writelines([i + '\n' for i in msks_list])
19 |
20 | # check if the path is valid
21 | def check_list(rlms_list, rimgs_list, rmsks_list):
22 | lms_list, imgs_list, msks_list = [], [], []
23 | for i in range(len(rlms_list)):
24 | flag = 'false'
25 | lm_path = rlms_list[i]
26 | im_path = rimgs_list[i]
27 | msk_path = rmsks_list[i]
28 | if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path):
29 | flag = 'true'
30 | lms_list.append(rlms_list[i])
31 | imgs_list.append(rimgs_list[i])
32 | msks_list.append(rmsks_list[i])
33 | print(i, rlms_list[i], flag)
34 | return lms_list, imgs_list, msks_list
35 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Comfyui-SadTalker
2 |
3 | ### Copyright
4 |
5 | - [SadTalker](https://github.com/OpenTalker/SadTalker)
6 |
7 | ### Notice
8 |
9 | - This is my first ComfyUI plugin, so the nodes may still have problems. If you run into difficulties, please open an issue.
10 | - Developed and tested with ComfyUi-aki-v1.3.7z (newer versions have not been adapted yet).
11 |
12 | ### Examples
13 |
14 | 
15 |
16 | ### Problems You May Have
17 |
18 | - If you get an error that audio does not exist or is undefined, upgrade ComfyUI.
19 | - If the video displays abnormally, reload your browser.
20 | - If you get an error like FileNotFoundError: [Errno 2] No such file or directory: 'F:\\\\ComfyUI-aki-v1.3\\output\\6b5735c4-ef50-4933-a34d-ccbad3b5936d.mp4', download ffmpeg and add the ffmpeg/bin directory to your PATH environment variable (a quick check is sketched below).
21 |
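For the ffmpeg issue above, a minimal sanity check (an illustrative sketch, not part of the plugin code) is to confirm that an `ffmpeg` executable is reachable from `PATH`:

```python
import shutil
import subprocess

# Look up ffmpeg on PATH; None means the ffmpeg/bin directory has not been added yet.
ffmpeg_path = shutil.which("ffmpeg")
if ffmpeg_path is None:
    print("ffmpeg not found on PATH - download ffmpeg and add its bin directory to PATH.")
else:
    # Run "ffmpeg -version" to confirm the binary actually executes.
    result = subprocess.run([ffmpeg_path, "-version"], capture_output=True, text=True)
    print(result.stdout.splitlines()[0])
```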
22 | ### Install
23 |
24 | 1. Place the following checkpoints in ...custom_nodes\Comfyui-SadTalker\SadTalker\checkpoints\
25 | - SadTalker_V0.0.2_256.safetensors(691MB)
26 | - SadTalker_V0.0.2_512.safetensors(691MB)
27 | - mapping_00109-model.pth.tar(148MB)
28 | - mapping_00229-model.pth.tar(148MB)
29 | 2. Place the following weights in ...(comfyui root)\gfpgan\weights\ (a small check script is sketched after this list)
30 | - alignment_WFLW_4HG.pth(184MB)
31 | - detection_Resnet50_Final.pth(104MB)
32 | - GFPGANv1.4.pth(332MB)
33 | - parsing_parsenet.pth(81.3MB)
34 |
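To confirm the model files landed in the right folders, a small illustrative script like the one below can be used. The `comfyui_root` value is a placeholder for your own ComfyUI installation directory; the relative paths are taken from the list above.

```python
import os

# Placeholder: set this to the root folder of your ComfyUI installation.
comfyui_root = r"path\to\ComfyUI-aki-v1.3"

# Relative paths of the model files listed in the install steps above.
expected_files = [
    r"custom_nodes\Comfyui-SadTalker\SadTalker\checkpoints\SadTalker_V0.0.2_256.safetensors",
    r"custom_nodes\Comfyui-SadTalker\SadTalker\checkpoints\SadTalker_V0.0.2_512.safetensors",
    r"custom_nodes\Comfyui-SadTalker\SadTalker\checkpoints\mapping_00109-model.pth.tar",
    r"custom_nodes\Comfyui-SadTalker\SadTalker\checkpoints\mapping_00229-model.pth.tar",
    r"gfpgan\weights\alignment_WFLW_4HG.pth",
    r"gfpgan\weights\detection_Resnet50_Final.pth",
    r"gfpgan\weights\GFPGANv1.4.pth",
    r"gfpgan\weights\parsing_parsenet.pth",
]

for rel_path in expected_files:
    full_path = os.path.join(comfyui_root, rel_path)
    status = "ok" if os.path.isfile(full_path) else "MISSING"
    print(f"{status:8s} {rel_path}")
```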
35 | ### Contact Me
36 |
37 |
38 |
--------------------------------------------------------------------------------
/SadTalker/src/facerender/modules/mapping.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class MappingNet(nn.Module):
9 | def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins):
10 | super( MappingNet, self).__init__()
11 |
12 | self.layer = layer
13 | nonlinearity = nn.LeakyReLU(0.1)
14 |
15 | self.first = nn.Sequential(
16 | torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))
17 |
18 | for i in range(layer):
19 | net = nn.Sequential(nonlinearity,
20 | torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
21 | setattr(self, 'encoder' + str(i), net)
22 |
23 | self.pooling = nn.AdaptiveAvgPool1d(1)
24 | self.output_nc = descriptor_nc
25 |
26 | self.fc_roll = nn.Linear(descriptor_nc, num_bins)
27 | self.fc_pitch = nn.Linear(descriptor_nc, num_bins)
28 | self.fc_yaw = nn.Linear(descriptor_nc, num_bins)
29 | self.fc_t = nn.Linear(descriptor_nc, 3)
30 | self.fc_exp = nn.Linear(descriptor_nc, 3*num_kp)
31 |
32 | def forward(self, input_3dmm):
33 | out = self.first(input_3dmm)
34 | for i in range(self.layer):
35 | model = getattr(self, 'encoder' + str(i))
36 | out = model(out) + out[:,:,3:-3]
37 | out = self.pooling(out)
38 | out = out.view(out.shape[0], -1)
39 | #print('out:', out.shape)
40 |
41 | yaw = self.fc_yaw(out)
42 | pitch = self.fc_pitch(out)
43 | roll = self.fc_roll(out)
44 | t = self.fc_t(out)
45 | exp = self.fc_exp(out)
46 |
47 | return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/docs/install.md:
--------------------------------------------------------------------------------
1 | ## v1.8.0
2 | ### Linux and Windows
3 | ```shell
4 | # CUDA 11.0
5 | pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
6 |
7 | # CUDA 10.2
8 | pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0
9 |
10 | # CPU only
11 | pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
12 |
13 | ```
14 |
15 |
16 | ## v1.7.1
17 | ### Linux and Windows
18 | ```shell
19 | # CUDA 11.0
20 | pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
21 |
22 | # CUDA 10.2
23 | pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
24 |
25 | # CUDA 10.1
26 | pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
27 |
28 | # CUDA 9.2
29 | pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
30 |
31 | # CPU only
32 | pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
33 | ```
34 |
35 |
36 | ## v1.6.0
37 |
38 | ### Linux and Windows
39 | ```shell
40 | # CUDA 10.2
41 | pip install torch==1.6.0 torchvision==0.7.0
42 |
43 | # CUDA 10.1
44 | pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
45 |
46 | # CUDA 9.2
47 | pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
48 |
49 | # CPU only
50 | pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
51 | ```
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/configs/base.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.loss = "arcface"
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = "ms1mv3_arcface_r50"
12 |
13 | config.dataset = "ms1m-retinaface-t1"
14 | config.embedding_size = 512
15 | config.sample_rate = 1
16 | config.fp16 = False
17 | config.momentum = 0.9
18 | config.weight_decay = 5e-4
19 | config.batch_size = 128
20 | config.lr = 0.1 # batch size is 512
21 |
22 | if config.dataset == "emore":
23 | config.rec = "/train_tmp/faces_emore"
24 | config.num_classes = 85742
25 | config.num_image = 5822653
26 | config.num_epoch = 16
27 | config.warmup_epoch = -1
28 | config.decay_epoch = [8, 14, ]
29 | config.val_targets = ["lfw", ]
30 |
31 | elif config.dataset == "ms1m-retinaface-t1":
32 | config.rec = "/train_tmp/ms1m-retinaface-t1"
33 | config.num_classes = 93431
34 | config.num_image = 5179510
35 | config.num_epoch = 25
36 | config.warmup_epoch = -1
37 | config.decay_epoch = [11, 17, 22]
38 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
39 |
40 | elif config.dataset == "glint360k":
41 | config.rec = "/train_tmp/glint360k"
42 | config.num_classes = 360232
43 | config.num_image = 17091657
44 | config.num_epoch = 20
45 | config.warmup_epoch = -1
46 | config.decay_epoch = [8, 12, 15, 18]
47 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
48 |
49 | elif config.dataset == "webface":
50 | config.rec = "/train_tmp/faces_webface_112x112"
51 | config.num_classes = 10572
52 | config.num_image = "forget"
53 | config.num_epoch = 34
54 | config.warmup_epoch = -1
55 | config.decay_epoch = [20, 28, 32]
56 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
57 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/visualize.py:
--------------------------------------------------------------------------------
1 | # check the sync of 3dmm feature and the audio
2 | import cv2
3 | import numpy as np
4 | from src.face3d.models.bfm import ParametricFaceModel
5 | from src.face3d.models.facerecon_model import FaceReconModel
6 | import torch
7 | import subprocess, platform
8 | import scipy.io as scio
9 | from tqdm import tqdm
10 |
11 | # draft
12 | def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, exp_dim=64):
13 |
14 | coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm']
15 |
16 | coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm']
17 |
18 | coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257
19 |
20 | coeff_full[:, 80:144] = coeff_pred[:, 0:64]
21 |     coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3 dim rotation (euler angles)
22 |     coeff_full[:, 254:] = coeff_pred[:, 67:] # 3 dim translation
23 |
24 | tmp_video_path = '/tmp/face3dtmp.mp4'
25 |
26 | facemodel = FaceReconModel(args)
27 |
28 | video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224))
29 |
30 | for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'):
31 | cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device)
32 |
33 | facemodel.forward(cur_coeff_full, device)
34 |
35 | predicted_landmark = facemodel.pred_lm # TODO.
36 | predicted_landmark = predicted_landmark.cpu().numpy().squeeze()
37 |
38 | rendered_img = facemodel.pred_face
39 | rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0)
40 | out_img = rendered_img[:, :, :3].astype(np.uint8)
41 |
42 | video.write(np.uint8(out_img[:,:,::-1]))
43 |
44 | video.release()
45 |
46 | command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path)
47 | subprocess.call(command, shell=platform.system() != 'Windows')
48 |
49 |
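50 | # Coefficient layout assumed by the slicing above (standard Deep3DFaceRecon 257-dim vector):
51 | # 0:80 identity, 80:144 expression, 144:224 texture, 224:227 rotation, 227:254 gamma, 254:257 translation.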
--------------------------------------------------------------------------------
/web/js/showVideo.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../../scripts/app.js";
2 | import { api } from "../../../scripts/api.js";
3 | import { ComfyWidgets } from "../../../scripts/widgets.js";
4 |
5 | let video;
6 |
7 | const ext = {
8 | name: "SadTalker.ShowVideo",
9 |
10 | async beforeRegisterNodeDef (nodeType, nodeData, app) {
11 | if (nodeData.name === "ShowVideo") {
12 | function populate (videoPath) {
13 |
14 | const filePath = videoPath[0]
15 | if (video && filePath) {
16 | const fileName = filePath.split('\\').pop();
17 | video.src = api.apiURL(`/view?filename=${fileName}`);
18 | video.play();
19 | }
20 | }
21 |
22 | // When the node is executed we will be sent the input video path, display this in the widget
23 | const onExecuted = nodeType.prototype.onExecuted;
24 | nodeType.prototype.onExecuted = function (message) {
25 | onExecuted?.apply(this, arguments);
26 | populate.call(this, message.show_video_path);
27 | };
28 |
29 | const onConfigure = nodeType.prototype.onConfigure;
30 | nodeType.prototype.onConfigure = function () {
31 | onConfigure?.apply(this, arguments);
32 | if (this.widgets_values?.length) {
33 | populate.call(this, this.widgets_values);
34 | }
35 | };
36 | }
37 | },
38 |
39 | loadedGraphNode (node, app) {
40 | if (node.type == "ShowVideo") {
41 | if (video) return;
42 | const container = document.createElement("div");
43 | container.style.background = "rgba(0,0,0,0.25)";
44 | container.style.textAlign = "center";
45 | video = document.createElement("video")
46 | video.style.height = video.style.width = "100%";
47 | video.controls = true
48 | video.classList.add("comfy-video")
49 | video.setAttribute("name", "media")
50 | container.replaceChildren(video);
51 | node.addDOMWidget("video", "VIDEO", container)
52 | }
53 | }
54 | };
55 |
56 |
57 |
58 | app.registerExtension(ext);
59 |
--------------------------------------------------------------------------------
/web/js/showText.js:
--------------------------------------------------------------------------------
1 | import { app } from "../../../scripts/app.js";
2 | import { ComfyWidgets } from "../../../scripts/widgets.js";
3 |
4 | const ext = {
5 | name: "SadTalker.ShowText",
6 |
7 | async beforeRegisterNodeDef (nodeType, nodeData, app) {
8 | if (nodeData.name === "ShowText") {
9 | function populate (text) {
10 | if (this.widgets) {
11 | for (let i = 1; i < this.widgets.length; i++) {
12 | this.widgets[i].onRemove?.();
13 | }
14 | this.widgets.length = 1;
15 | }
16 |
17 | const v = [...text];
18 | if (!v[0]) {
19 | v.shift();
20 | }
21 | for (const list of v) {
22 | const w = ComfyWidgets["STRING"](this, "text", ["STRING", { multiline: true }], app).widget;
23 | w.inputEl.readOnly = true;
24 | w.inputEl.style.opacity = 0.6;
25 | w.value = list;
26 | }
27 |
28 | requestAnimationFrame(() => {
29 | const sz = this.computeSize();
30 | if (sz[0] < this.size[0]) {
31 | sz[0] = this.size[0];
32 | }
33 | if (sz[1] < this.size[1]) {
34 | sz[1] = this.size[1];
35 | }
36 | this.onResize?.(sz);
37 | app.graph.setDirtyCanvas(true, false);
38 | });
39 | }
40 |
41 | // When the node is executed we will be sent the input text, display this in the widget
42 | const onExecuted = nodeType.prototype.onExecuted;
43 | nodeType.prototype.onExecuted = function (message) {
44 | onExecuted?.apply(this, arguments);
45 | populate.call(this, message.text);
46 | };
47 |
48 | const onConfigure = nodeType.prototype.onConfigure;
49 | nodeType.prototype.onConfigure = function () {
50 | onConfigure?.apply(this, arguments);
51 | if (this.widgets_values?.length) {
52 | populate.call(this, this.widgets_values);
53 | }
54 | };
55 | }
56 |
57 | // This fires for every node definition so only log once
58 | // delete ext.beforeRegisterNodeDef;
59 | },
60 | };
61 |
62 | app.registerExtension(ext);
63 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/data/image_folder.py:
--------------------------------------------------------------------------------
1 | """A modified image folder class
2 |
3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4 | so that this class can load images from both current directory and its subdirectories.
5 | """
6 | import numpy as np
7 | import torch.utils.data as data
8 |
9 | from PIL import Image
10 | import os
11 | import os.path
12 |
13 | IMG_EXTENSIONS = [
14 | '.jpg', '.JPG', '.jpeg', '.JPEG',
15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16 | '.tif', '.TIF', '.tiff', '.TIFF',
17 | ]
18 |
19 |
20 | def is_image_file(filename):
21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22 |
23 |
24 | def make_dataset(dir, max_dataset_size=float("inf")):
25 | images = []
26 | assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
27 |
28 | for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
29 | for fname in fnames:
30 | if is_image_file(fname):
31 | path = os.path.join(root, fname)
32 | images.append(path)
33 | return images[:min(max_dataset_size, len(images))]
34 |
35 |
36 | def default_loader(path):
37 | return Image.open(path).convert('RGB')
38 |
39 |
40 | class ImageFolder(data.Dataset):
41 |
42 | def __init__(self, root, transform=None, return_paths=False,
43 | loader=default_loader):
44 | imgs = make_dataset(root)
45 | if len(imgs) == 0:
46 | raise(RuntimeError("Found 0 images in: " + root + "\n"
47 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
48 |
49 | self.root = root
50 | self.imgs = imgs
51 | self.transform = transform
52 | self.return_paths = return_paths
53 | self.loader = loader
54 |
55 | def __getitem__(self, index):
56 | path = self.imgs[index]
57 | img = self.loader(path)
58 | if self.transform is not None:
59 | img = self.transform(img)
60 | if self.return_paths:
61 | return img, path
62 | else:
63 | return img
64 |
65 | def __len__(self):
66 | return len(self.imgs)
67 |
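68 | # Usage sketch (the directory below is illustrative, not a shipped dataset):
69 | # dataset = ImageFolder('./examples', return_paths=True)
70 | # first_img, first_path = dataset[0]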
--------------------------------------------------------------------------------
/SadTalker/src/audio2pose_models/res_unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from src.audio2pose_models.networks import ResidualConv, Upsample
4 |
5 |
6 | class ResUnet(nn.Module):
7 | def __init__(self, channel=1, filters=[32, 64, 128, 256]):
8 | super(ResUnet, self).__init__()
9 |
10 | self.input_layer = nn.Sequential(
11 | nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
12 | nn.BatchNorm2d(filters[0]),
13 | nn.ReLU(),
14 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
15 | )
16 | self.input_skip = nn.Sequential(
17 | nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
18 | )
19 |
20 | self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
21 | self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
22 |
23 | self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
24 |
25 | self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
26 | self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
27 |
28 | self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
29 | self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
30 |
31 | self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
32 | self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
33 |
34 | self.output_layer = nn.Sequential(
35 | nn.Conv2d(filters[0], 1, 1, 1),
36 | nn.Sigmoid(),
37 | )
38 |
39 | def forward(self, x):
40 | # Encode
41 | x1 = self.input_layer(x) + self.input_skip(x)
42 | x2 = self.residual_conv_1(x1)
43 | x3 = self.residual_conv_2(x2)
44 | # Bridge
45 | x4 = self.bridge(x3)
46 |
47 | # Decode
48 | x4 = self.upsample_1(x4)
49 | x5 = torch.cat([x4, x3], dim=1)
50 |
51 | x6 = self.up_residual_conv1(x5)
52 |
53 | x6 = self.upsample_2(x6)
54 | x7 = torch.cat([x6, x2], dim=1)
55 |
56 | x8 = self.up_residual_conv2(x7)
57 |
58 | x8 = self.upsample_3(x8)
59 | x9 = torch.cat([x8, x1], dim=1)
60 |
61 | x10 = self.up_residual_conv3(x9)
62 |
63 | output = self.output_layer(x10)
64 |
65 | return output
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/utils/plot.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 |
3 | import os
4 | from pathlib import Path
5 |
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import pandas as pd
9 | from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
10 | from prettytable import PrettyTable
11 | from sklearn.metrics import roc_curve, auc
12 |
13 | image_path = "/data/anxiang/IJB_release/IJBC"
14 | files = [
15 | "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy"
16 | ]
17 |
18 |
19 | def read_template_pair_list(path):
20 | pairs = pd.read_csv(path, sep=' ', header=None).values
21 |     t1 = pairs[:, 0].astype(int)
22 |     t2 = pairs[:, 1].astype(int)
23 |     label = pairs[:, 2].astype(int)
24 | return t1, t2, label
25 |
26 |
27 | p1, p2, label = read_template_pair_list(
28 | os.path.join('%s/meta' % image_path,
29 | '%s_template_pair_label.txt' % 'ijbc'))
30 |
31 | methods = []
32 | scores = []
33 | for file in files:
34 | methods.append(file.split('/')[-2])
35 | scores.append(np.load(file))
36 |
37 | methods = np.array(methods)
38 | scores = dict(zip(methods, scores))
39 | colours = dict(
40 | zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
41 | x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
42 | tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
43 | fig = plt.figure()
44 | for method in methods:
45 | fpr, tpr, _ = roc_curve(label, scores[method])
46 | roc_auc = auc(fpr, tpr)
47 | fpr = np.flipud(fpr)
48 | tpr = np.flipud(tpr) # select largest tpr at same fpr
49 | plt.plot(fpr,
50 | tpr,
51 | color=colours[method],
52 | lw=1,
53 | label=('[%s (AUC = %0.4f %%)]' %
54 | (method.split('-')[-1], roc_auc * 100)))
55 | tpr_fpr_row = []
56 | tpr_fpr_row.append("%s-%s" % (method, "IJBC"))
57 | for fpr_iter in np.arange(len(x_labels)):
58 | _, min_index = min(
59 | list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
60 | tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
61 | tpr_fpr_table.add_row(tpr_fpr_row)
62 | plt.xlim([10 ** -6, 0.1])
63 | plt.ylim([0.3, 1.0])
64 | plt.grid(linestyle='--', linewidth=1)
65 | plt.xticks(x_labels)
66 | plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
67 | plt.xscale('log')
68 | plt.xlabel('False Positive Rate')
69 | plt.ylabel('True Positive Rate')
70 | plt.title('ROC on IJB')
71 | plt.legend(loc="lower right")
72 | print(tpr_fpr_table)
73 |
--------------------------------------------------------------------------------
/SadTalker/check_ffmpeg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import subprocess
4 | from pathlib import Path
5 | import winreg
6 |
7 | def check_ffmpeg_path():
8 | path_list = os.environ['Path'].split(';')
9 | ffmpeg_found = False
10 |
11 | for path in path_list:
12 | if 'ffmpeg' in path.lower() and 'bin' in path.lower():
13 | ffmpeg_found = True
14 | print("FFmpeg already installed, skipping...")
15 | break
16 |
17 | return ffmpeg_found
18 |
19 | def add_ffmpeg_path_to_user_variable():
20 | ffmpeg_bin_path = Path('.\\ffmpeg\\bin')
21 | if ffmpeg_bin_path.is_dir():
22 | abs_path = str(ffmpeg_bin_path.resolve())
23 |
24 | try:
25 | key = winreg.OpenKey(
26 | winreg.HKEY_CURRENT_USER,
27 | r"Environment",
28 | 0,
29 | winreg.KEY_READ | winreg.KEY_WRITE
30 | )
31 |
32 | try:
33 | current_path, _ = winreg.QueryValueEx(key, "Path")
34 | if abs_path not in current_path:
35 | new_path = f"{current_path};{abs_path}"
36 | winreg.SetValueEx(key, "Path", 0, winreg.REG_EXPAND_SZ, new_path)
37 | print(f"Added FFmpeg path to user variable 'Path': {abs_path}")
38 | else:
39 | print("FFmpeg path already exists in the user variable 'Path'.")
40 | finally:
41 | winreg.CloseKey(key)
42 | except WindowsError:
43 | print("Error: Unable to modify user variable 'Path'.")
44 | sys.exit(1)
45 |
46 | else:
47 | print("Error: ffmpeg\\bin folder not found in the current path.")
48 | sys.exit(1)
49 |
50 | def get_default_browser():
51 | browser = ""
52 | try:
53 | cmd = r'reg query HKEY_CURRENT_USER\Software\Microsoft\Windows\Shell\Associations\UrlAssociations\http\UserChoice /v ProgId'
54 | output = subprocess.check_output(cmd, shell=True).decode()
55 | browser = output.split()[-1].split('\\')[-1]
56 | except Exception as e:
57 | print(f"Error: {e}")
58 | browser = "Unknown"
59 | return browser
60 |
61 | def main():
62 | if not check_ffmpeg_path():
63 | add_ffmpeg_path_to_user_variable()
64 | default_browser = get_default_browser()
65 |     if "chrome" not in default_browser.lower() and "edge" not in default_browser.lower() and "firefox" not in default_browser.lower():
66 |         print("The default browser does not meet the requirements and may affect usage; please switch to Chrome or Edge.")
67 |
68 | if __name__ == "__main__":
69 | main()
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/torch2onnx.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import onnx
3 | import torch
4 |
5 |
6 | def convert_onnx(net, path_module, output, opset=11, simplify=False):
7 | assert isinstance(net, torch.nn.Module)
8 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
9 | img = img.astype(np.float64)
10 | img = (img / 255.0 - 0.5) / 0.5 # torch style norm
11 | img = img.transpose((2, 0, 1))
12 | img = torch.from_numpy(img).unsqueeze(0).float()
13 |
14 | weight = torch.load(path_module)
15 | net.load_state_dict(weight)
16 | net.eval()
17 | torch.onnx.export(
18 | net,
19 | img,
20 | output,
21 | keep_initializers_as_inputs=False,
22 | verbose=False,
23 | opset_version=opset,
24 | )
25 | model = onnx.load(output)
26 | graph = model.graph
27 | graph.input[0].type.tensor_type.shape.dim[0].dim_param = "None"
28 | if simplify:
29 | from onnxsim import simplify
30 |
31 | model, check = simplify(model)
32 | assert check, "Simplified ONNX model could not be validated"
33 | onnx.save(model, output)
34 |
35 |
36 | if __name__ == "__main__":
37 | import os
38 | import argparse
39 | from backbones import get_model
40 |
41 | parser = argparse.ArgumentParser(description="ArcFace PyTorch to onnx")
42 | parser.add_argument("input", type=str, help="input backbone.pth file or path")
43 | parser.add_argument("--output", type=str, default=None, help="output onnx path")
44 | parser.add_argument("--network", type=str, default=None, help="backbone network")
45 | parser.add_argument("--simplify", type=bool, default=False, help="onnx simplify")
46 | args = parser.parse_args()
47 | input_file = args.input
48 | if os.path.isdir(input_file):
49 | input_file = os.path.join(input_file, "backbone.pth")
50 | assert os.path.exists(input_file)
51 | model_name = os.path.basename(os.path.dirname(input_file)).lower()
52 | params = model_name.split("_")
53 | if len(params) >= 3 and params[1] in ("arcface", "cosface"):
54 | if args.network is None:
55 | args.network = params[2]
56 | assert args.network is not None
57 | print(args)
58 | backbone_onnx = get_model(args.network, dropout=0)
59 |
60 | output_path = args.output
61 | if output_path is None:
62 | output_path = os.path.join(os.path.dirname(__file__), "onnx")
63 | if not os.path.exists(output_path):
64 | os.makedirs(output_path)
65 | assert os.path.isdir(output_path)
66 | output_file = os.path.join(output_path, "%s.onnx" % model_name)
67 | convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify)
68 |
--------------------------------------------------------------------------------
/SadTalker/src/utils/init_path.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 | def init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'):
5 |
6 | if old_version:
7 | #### load all the checkpoint of `pth`
8 | sadtalker_paths = {
9 | 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'),
10 | 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'),
11 | 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'),
12 | 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'),
13 | 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth')
14 | }
15 |
16 | use_safetensor = False
17 | elif len(glob.glob(os.path.join(checkpoint_dir, '*.safetensors'))):
18 | print('using safetensor as default')
19 | sadtalker_paths = {
20 | "checkpoint":os.path.join(checkpoint_dir, 'SadTalker_V0.0.2_'+str(size)+'.safetensors'),
21 | }
22 | use_safetensor = True
23 | else:
24 |         print("WARNING: the new model weights are released as safetensors and may need to be downloaded manually; falling back to the old checkpoints for this run!")
25 | use_safetensor = False
26 |
27 | sadtalker_paths = {
28 | 'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'),
29 | 'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'),
30 | 'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'),
31 | 'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'),
32 | 'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth')
33 | }
34 |
35 | sadtalker_paths['dir_of_BFM_fitting'] = os.path.join(config_dir) # , 'BFM_Fitting'
36 | sadtalker_paths['audio2pose_yaml_path'] = os.path.join(config_dir, 'auido2pose.yaml')
37 | sadtalker_paths['audio2exp_yaml_path'] = os.path.join(config_dir, 'auido2exp.yaml')
38 | sadtalker_paths['use_safetensor'] = use_safetensor # os.path.join(config_dir, 'auido2exp.yaml')
39 |
40 | if 'full' in preprocess:
41 | sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar')
42 | sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender_still.yaml')
43 | else:
44 | sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar')
45 | sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml')
46 |
47 | return sadtalker_paths
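48 | 
49 | # Usage sketch (directories are illustrative; point them at the real checkpoint/config folders):
50 | # paths = init_path('./checkpoints', './src/config', size=256, preprocess='crop')
51 | # print(paths['use_safetensor'], paths['facerender_yaml'])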
--------------------------------------------------------------------------------
/SadTalker/src/utils/paste_pic.py:
--------------------------------------------------------------------------------
1 | import cv2, os
2 | import numpy as np
3 | from tqdm import tqdm
4 | import uuid
5 |
6 | from src.utils.videoio import save_video_with_watermark
7 |
8 |
9 | def paste_pic(
10 | video_path,
11 | pic_path,
12 | crop_info,
13 | new_audio_path,
14 | full_video_path,
15 | extended_crop=False,
16 | ):
17 |
18 | if not os.path.isfile(pic_path):
19 | raise ValueError("pic_path must be a valid path to video/image file")
20 | elif pic_path.split(".")[-1] in ["jpg", "png", "jpeg"]:
21 | # loader for first frame
22 | full_img = cv2.imread(pic_path)
23 | else:
24 | # loader for videos
25 | video_stream = cv2.VideoCapture(pic_path)
26 | fps = video_stream.get(cv2.CAP_PROP_FPS)
27 | full_frames = []
28 | while 1:
29 | still_reading, frame = video_stream.read()
30 | if not still_reading:
31 | video_stream.release()
32 | break
33 |             break  # only the first frame is needed as the full background image
34 | full_img = frame
35 | frame_h = full_img.shape[0]
36 | frame_w = full_img.shape[1]
37 |
38 | video_stream = cv2.VideoCapture(video_path)
39 | fps = video_stream.get(cv2.CAP_PROP_FPS)
40 | crop_frames = []
41 | while 1:
42 | still_reading, frame = video_stream.read()
43 | if not still_reading:
44 | video_stream.release()
45 | break
46 | crop_frames.append(frame)
47 |
48 | if len(crop_info) != 3:
49 | print("you didn't crop the image")
50 | return
51 | else:
52 | r_w, r_h = crop_info[0]
53 | clx, cly, crx, cry = crop_info[1]
54 | lx, ly, rx, ry = crop_info[2]
55 | lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
56 | # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
57 | # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
58 |
59 | if extended_crop:
60 | oy1, oy2, ox1, ox2 = cly, cry, clx, crx
61 | else:
62 | oy1, oy2, ox1, ox2 = cly + ly, cly + ry, clx + lx, clx + rx
63 |
64 | tmp_path = str(uuid.uuid4()) + ".mp4"
65 | out_tmp = cv2.VideoWriter(
66 | tmp_path, cv2.VideoWriter_fourcc(*"MP4V"), fps, (frame_w, frame_h)
67 | )
68 | for crop_frame in tqdm(crop_frames, "seamlessClone:"):
69 | p = cv2.resize(crop_frame.astype(np.uint8), (ox2 - ox1, oy2 - oy1))
70 |
71 | mask = 255 * np.ones(p.shape, p.dtype)
72 | location = ((ox1 + ox2) // 2, (oy1 + oy2) // 2)
73 | gen_img = cv2.seamlessClone(p, full_img, mask, location, cv2.NORMAL_CLONE)
74 | out_tmp.write(gen_img)
75 |
76 | out_tmp.release()
77 |
78 | save_video_with_watermark(
79 | tmp_path, new_audio_path, full_video_path, watermark=False
80 | )
81 | os.remove(tmp_path)
82 |
--------------------------------------------------------------------------------
/SadTalker/src/utils/videoio.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import uuid
3 | import os, subprocess
4 | import cv2 # type: ignore
5 | import logging
6 | import folder_paths # type: ignore
7 |
8 | output_directory = folder_paths.get_output_directory()
9 |
10 |
11 | def load_video_to_cv2(input_path):
12 | video_stream = cv2.VideoCapture(input_path)
13 | fps = video_stream.get(cv2.CAP_PROP_FPS)
14 | full_frames = []
15 | while True:
16 | still_reading, frame = video_stream.read()
17 | if not still_reading:
18 | video_stream.release()
19 | break
20 | full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
21 | return full_frames, fps
22 |
23 |
24 | def save_video_with_watermark(video, audio, save_path, watermark=False, target_fps=30):
25 | temp_video = os.path.join(output_directory, str(uuid.uuid4()) + "_video.mp4")
26 |
27 | cmd = (r'ffmpeg -y -hide_banner -loglevel error -i "%s" -r %d -qscale 0 "%s"') % (
28 | video,
29 | target_fps,
30 | temp_video,
31 | )
32 | os.system(cmd)
33 |
34 | temp_file = os.path.join(output_directory, str(uuid.uuid4()) + ".mp4")
35 | cmd = (
36 | r'ffmpeg -i "%s" -i "%s" '
37 | r"-c:v libx264 -b:v 5000k -maxrate 5000k -bufsize 10000k "
38 | r'-c:a aac -b:a 122k -ac 2 -ar 44100 -strict experimental "%s" '
39 | ) % (temp_video, audio, temp_file)
40 |
41 | os.system(cmd)
42 |
43 | try:
44 | if not watermark:
45 |             # If no watermark is needed, copy the temporary file straight to the save path
46 | shutil.copy(temp_file, save_path)
47 | else:
48 |             # Get the watermark image path
49 | dir_path = os.path.dirname(os.path.realpath(__file__))
50 | watermark_path = os.path.join(dir_path, "../../docs/sadtalker_logo.png")
51 |
52 |             # Build the FFmpeg command that overlays the watermark
53 | cmd = [
54 | "ffmpeg",
55 | "-y",
56 | "-hide_banner",
57 | "-loglevel",
58 | "error",
59 | "-i",
60 | temp_file,
61 | "-i",
62 | watermark_path,
63 | "-filter_complex",
64 | "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10",
65 | "-c:v",
66 | "copy",
67 | "-c:a",
68 | "aac",
69 | "-b:a",
70 | "128k",
71 | "-ar",
72 | "16000",
73 | "-shortest",
74 | save_path,
75 | ]
76 |
77 |             subprocess.run(cmd)  # cmd is an argv list; os.system() only accepts a string
78 |     except Exception as e:
79 |         logging.error("Error: %s", e)
80 | 
81 |     # Make sure the temporary video files are removed
82 | if os.path.exists(temp_video):
83 | os.remove(temp_video)
84 | if os.path.exists(temp_file):
85 | os.remove(temp_file)
86 |
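87 | # Usage sketch (file names are placeholders):
88 | # frames, fps = load_video_to_cv2('rendered.mp4')
89 | # save_video_with_watermark('rendered.mp4', 'speech.wav', 'result.mp4', watermark=False)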
--------------------------------------------------------------------------------
/SadTalker/scripts/download_models.sh:
--------------------------------------------------------------------------------
1 | mkdir -p ./checkpoints
2 |
3 | # legacy download links
4 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth
5 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth
6 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth
7 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar
8 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat
9 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth
10 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar
11 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar
12 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip
13 | # unzip -n ./checkpoints/hub.zip -d ./checkpoints/
14 |
15 |
16 | #### download the new links.
17 | wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar
18 | wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar
19 | wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_256.safetensors -O ./checkpoints/SadTalker_V0.0.2_256.safetensors
20 | wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_512.safetensors -O ./checkpoints/SadTalker_V0.0.2_512.safetensors
21 |
22 |
23 | # wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip
24 | # unzip -n ./checkpoints/BFM_Fitting.zip -d ./checkpoints/
25 |
26 | ### enhancer
27 | mkdir -p ./gfpgan/weights
28 | wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth -O ./gfpgan/weights/alignment_WFLW_4HG.pth
29 | wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth -O ./gfpgan/weights/detection_Resnet50_Final.pth
30 | wget -nc https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -O ./gfpgan/weights/GFPGANv1.4.pth
31 | wget -nc https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth -O ./gfpgan/weights/parsing_parsenet.pth
32 |
33 |
--------------------------------------------------------------------------------
/SadTalker/src/audio2pose_models/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | class ConvNormRelu(nn.Module):
6 | def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
7 | kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
8 | super().__init__()
9 | if kernel_size is None:
10 | if downsample:
11 | kernel_size, stride, padding = 4, 2, 1
12 | else:
13 | kernel_size, stride, padding = 3, 1, 1
14 |
15 | if conv_type == '2d':
16 | self.conv = nn.Conv2d(
17 | in_channels,
18 | out_channels,
19 | kernel_size,
20 | stride,
21 | padding,
22 | bias=False,
23 | )
24 | if norm == 'BN':
25 | self.norm = nn.BatchNorm2d(out_channels)
26 | elif norm == 'IN':
27 | self.norm = nn.InstanceNorm2d(out_channels)
28 | else:
29 | raise NotImplementedError
30 | elif conv_type == '1d':
31 | self.conv = nn.Conv1d(
32 | in_channels,
33 | out_channels,
34 | kernel_size,
35 | stride,
36 | padding,
37 | bias=False,
38 | )
39 | if norm == 'BN':
40 | self.norm = nn.BatchNorm1d(out_channels)
41 | elif norm == 'IN':
42 | self.norm = nn.InstanceNorm1d(out_channels)
43 | else:
44 | raise NotImplementedError
45 | nn.init.kaiming_normal_(self.conv.weight)
46 |
47 | self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
48 |
49 | def forward(self, x):
50 | x = self.conv(x)
51 | if isinstance(self.norm, nn.InstanceNorm1d):
52 | x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C]
53 | else:
54 | x = self.norm(x)
55 | x = self.act(x)
56 | return x
57 |
58 |
59 | class PoseSequenceDiscriminator(nn.Module):
60 | def __init__(self, cfg):
61 | super().__init__()
62 | self.cfg = cfg
63 | leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
64 |
65 | self.seq = nn.Sequential(
66 | ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64
67 | ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32
68 | ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16
69 | nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16
70 | )
71 |
72 | def forward(self, x):
73 | x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
74 | x = self.seq(x)
75 | x = x.squeeze(1)
76 | return x
--------------------------------------------------------------------------------
/SadTalker/src/audio2pose_models/audio_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | class Conv2d(nn.Module):
6 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
7 | super().__init__(*args, **kwargs)
8 | self.conv_block = nn.Sequential(
9 | nn.Conv2d(cin, cout, kernel_size, stride, padding),
10 | nn.BatchNorm2d(cout)
11 | )
12 | self.act = nn.ReLU()
13 | self.residual = residual
14 |
15 | def forward(self, x):
16 | out = self.conv_block(x)
17 | if self.residual:
18 | out += x
19 | return self.act(out)
20 |
21 | class AudioEncoder(nn.Module):
22 | def __init__(self, wav2lip_checkpoint, device):
23 | super(AudioEncoder, self).__init__()
24 |
25 | self.audio_encoder = nn.Sequential(
26 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
27 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
28 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
29 |
30 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
31 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
32 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
33 |
34 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
35 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
36 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
37 |
38 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
39 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
40 |
41 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
42 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
43 |
44 | #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
45 | # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
46 | # state_dict = self.audio_encoder.state_dict()
47 |
48 | # for k,v in wav2lip_state_dict.items():
49 | # if 'audio_encoder' in k:
50 | # state_dict[k.replace('module.audio_encoder.', '')] = v
51 | # self.audio_encoder.load_state_dict(state_dict)
52 |
53 |
54 | def forward(self, audio_sequences):
55 | # audio_sequences = (B, T, 1, 80, 16)
56 | B = audio_sequences.size(0)
57 |
58 | audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
59 |
60 | audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
61 | dim = audio_embedding.shape[1]
62 | audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
63 |
64 | return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
65 |
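66 | # Shape sketch (illustrative): a batch of 2 clips with 10 mel chunks of size 80x16 each.
67 | # Weight loading is commented out above, so the checkpoint argument can be None here.
68 | # enc = AudioEncoder(wav2lip_checkpoint=None, device='cpu')
69 | # emb = enc(torch.randn(2, 10, 1, 80, 16))   # -> (2, 10, 512)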
--------------------------------------------------------------------------------
/SadTalker/src/facerender/modules/discriminator.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | from src.facerender.modules.util import kp2gaussian
4 | import torch
5 |
6 |
7 | class DownBlock2d(nn.Module):
8 | """
9 | Simple block for processing video (encoder).
10 | """
11 |
12 | def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
13 | super(DownBlock2d, self).__init__()
14 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
15 |
16 | if sn:
17 | self.conv = nn.utils.spectral_norm(self.conv)
18 |
19 | if norm:
20 | self.norm = nn.InstanceNorm2d(out_features, affine=True)
21 | else:
22 | self.norm = None
23 | self.pool = pool
24 |
25 | def forward(self, x):
26 | out = x
27 | out = self.conv(out)
28 | if self.norm:
29 | out = self.norm(out)
30 | out = F.leaky_relu(out, 0.2)
31 | if self.pool:
32 | out = F.avg_pool2d(out, (2, 2))
33 | return out
34 |
35 |
36 | class Discriminator(nn.Module):
37 | """
38 | Discriminator similar to Pix2Pix
39 | """
40 |
41 | def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
42 | sn=False, **kwargs):
43 | super(Discriminator, self).__init__()
44 |
45 | down_blocks = []
46 | for i in range(num_blocks):
47 | down_blocks.append(
48 | DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),
49 | min(max_features, block_expansion * (2 ** (i + 1))),
50 | norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
51 |
52 | self.down_blocks = nn.ModuleList(down_blocks)
53 | self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
54 | if sn:
55 | self.conv = nn.utils.spectral_norm(self.conv)
56 |
57 | def forward(self, x):
58 | feature_maps = []
59 | out = x
60 |
61 | for down_block in self.down_blocks:
62 | feature_maps.append(down_block(out))
63 | out = feature_maps[-1]
64 | prediction_map = self.conv(out)
65 |
66 | return feature_maps, prediction_map
67 |
68 |
69 | class MultiScaleDiscriminator(nn.Module):
70 | """
71 | Multi-scale (scale) discriminator
72 | """
73 |
74 | def __init__(self, scales=(), **kwargs):
75 | super(MultiScaleDiscriminator, self).__init__()
76 | self.scales = scales
77 | discs = {}
78 | for scale in scales:
79 | discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
80 | self.discs = nn.ModuleDict(discs)
81 |
82 | def forward(self, x):
83 | out_dict = {}
84 | for scale, disc in self.discs.items():
85 | scale = str(scale).replace('-', '.')
86 | key = 'prediction_' + scale
87 | feature_maps, prediction_map = disc(x[key])
88 | out_dict['feature_maps_' + scale] = feature_maps
89 | out_dict['prediction_map_' + scale] = prediction_map
90 | return out_dict
91 |
--------------------------------------------------------------------------------
/SadTalker/src/audio2exp_models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 | class Conv2d(nn.Module):
6 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
7 | super().__init__(*args, **kwargs)
8 | self.conv_block = nn.Sequential(
9 | nn.Conv2d(cin, cout, kernel_size, stride, padding),
10 | nn.BatchNorm2d(cout)
11 | )
12 | self.act = nn.ReLU()
13 | self.residual = residual
14 | self.use_act = use_act
15 |
16 | def forward(self, x):
17 | out = self.conv_block(x)
18 | if self.residual:
19 | out += x
20 |
21 | if self.use_act:
22 | return self.act(out)
23 | else:
24 | return out
25 |
26 | class SimpleWrapperV2(nn.Module):
27 | def __init__(self) -> None:
28 | super().__init__()
29 | self.audio_encoder = nn.Sequential(
30 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
31 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
32 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
33 |
34 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
35 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
36 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
37 |
38 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
39 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
40 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
41 |
42 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
43 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
44 |
45 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
46 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
47 | )
48 |
49 | #### load the pre-trained audio_encoder
50 | #self.audio_encoder = self.audio_encoder.to(device)
51 | '''
52 | wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
53 | state_dict = self.audio_encoder.state_dict()
54 |
55 | for k,v in wav2lip_state_dict.items():
56 | if 'audio_encoder' in k:
57 | print('init:', k)
58 | state_dict[k.replace('module.audio_encoder.', '')] = v
59 | self.audio_encoder.load_state_dict(state_dict)
60 | '''
61 |
62 | self.mapping1 = nn.Linear(512+64+1, 64)
63 | #self.mapping2 = nn.Linear(30, 64)
64 | #nn.init.constant_(self.mapping1.weight, 0.)
65 | nn.init.constant_(self.mapping1.bias, 0.)
66 |
67 | def forward(self, x, ref, ratio):
68 | x = self.audio_encoder(x).view(x.size(0), -1)
69 | ref_reshape = ref.reshape(x.size(0), -1)
70 | ratio = ratio.reshape(x.size(0), -1)
71 |
72 | y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
73 |         out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # residual
74 | return out
75 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 |     -- <set_input>: unpack data from dataset and apply preprocessing.
7 |     -- <forward>: produce intermediate results.
8 |     -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 |     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 | -- self.model_names (str list): define networks used in our training.
14 | -- self.visual_names (str list): specify the images that you want to display and save.
15 |     -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 |
21 | import importlib
22 | from src.face3d.models.base_model import BaseModel
23 |
24 |
25 | def find_model_using_name(model_name):
26 | """Import the module "models/[model_name]_model.py".
27 |
28 | In the file, the class called DatasetNameModel() will
29 | be instantiated. It has to be a subclass of BaseModel,
30 | and it is case-insensitive.
31 | """
32 | model_filename = "face3d.models." + model_name + "_model"
33 | modellib = importlib.import_module(model_filename)
34 | model = None
35 | target_model_name = model_name.replace('_', '') + 'model'
36 | for name, cls in modellib.__dict__.items():
37 | if name.lower() == target_model_name.lower() \
38 | and issubclass(cls, BaseModel):
39 | model = cls
40 |
41 | if model is None:
42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 | exit(0)
44 |
45 | return model
46 |
47 |
48 | def get_option_setter(model_name):
49 | """Return the static method of the model class."""
50 | model_class = find_model_using_name(model_name)
51 | return model_class.modify_commandline_options
52 |
53 |
54 | def create_model(opt):
55 | """Create a model given the option.
56 |
57 |     This function wraps the selected model class.
58 | This is the main interface between this package and 'train.py'/'test.py'
59 |
60 | Example:
61 | >>> from models import create_model
62 | >>> model = create_model(opt)
63 | """
64 | model = find_model_using_name(opt.model)
65 | instance = model(opt)
66 | print("model [%s] was created" % type(instance).__name__)
67 | return instance
68 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 |
5 |
6 | class HTML:
7 | """This HTML class allows us to save images and write texts into a single HTML file.
8 |
9 |     It consists of functions such as <add_header> (add a text header to the HTML file),
10 |     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 | """
13 |
14 | def __init__(self, web_dir, title, refresh=0):
15 | """Initialize the HTML classes
16 |
17 | Parameters:
18 |             web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 |             title (str)   -- the webpage name
20 |             refresh (int) -- how often the website refresh itself; if 0, no refreshing
21 |         """
22 |         self.title = title
23 |         self.web_dir = web_dir
24 |         self.img_dir = os.path.join(self.web_dir, 'images')
25 |         if not os.path.exists(self.web_dir):
26 |             os.makedirs(self.web_dir)
27 |         if not os.path.exists(self.img_dir):
28 |             os.makedirs(self.img_dir)
29 | 
30 |         self.doc = dominate.document(title=title)
31 |         if refresh > 0:
32 | with self.doc.head:
33 | meta(http_equiv="refresh", content=str(refresh))
34 |
35 | def get_image_dir(self):
36 | """Return the directory that stores images"""
37 | return self.img_dir
38 |
39 | def add_header(self, text):
40 | """Insert a header to the HTML file
41 |
42 | Parameters:
43 | text (str) -- the header text
44 | """
45 | with self.doc:
46 | h3(text)
47 |
48 | def add_images(self, ims, txts, links, width=400):
49 | """add images to the HTML file
50 |
51 | Parameters:
52 | ims (str list) -- a list of image paths
53 | txts (str list) -- a list of image names shown on the website
54 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55 | """
56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57 | self.doc.add(self.t)
58 | with self.t:
59 | with tr():
60 | for im, txt, link in zip(ims, txts, links):
61 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 | with p():
63 | with a(href=os.path.join('images', link)):
64 | img(style="width:%dpx" % width, src=os.path.join('images', im))
65 | br()
66 | p(txt)
67 |
68 | def save(self):
69 | """save the current content to the HMTL file"""
70 | html_file = '%s/index.html' % self.web_dir
71 | f = open(html_file, 'wt')
72 | f.write(self.doc.render())
73 | f.close()
74 |
75 |
76 | if __name__ == '__main__': # we show an example usage here.
77 | html = HTML('web/', 'test_html')
78 | html.add_header('hello world')
79 |
80 | ims, txts, links = [], [], []
81 | for n in range(4):
82 | ims.append('image_%d.png' % n)
83 | txts.append('text_%d' % n)
84 | links.append('image_%d.png' % n)
85 | html.add_images(ims, txts, links)
86 | html.save()
87 |
--------------------------------------------------------------------------------
/SadTalker/src/facerender/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 |     Note that, as all modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
38 | of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
55 | original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 | Useful when you have customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/utils/utils_amp.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 | import torch
4 |
5 | # The old lexicographic version check ('1.10' < '1.9' is True) broke on newer PyTorch,
6 | # and torch._six was removed; collections.abc.Iterable works everywhere, so use it directly.
7 | import collections.abc
8 | 
9 | Iterable = collections.abc.Iterable
10 | 
11 | from torch.cuda.amp import GradScaler
12 |
13 |
14 | class _MultiDeviceReplicator(object):
15 | """
16 | Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
17 | """
18 |
19 | def __init__(self, master_tensor: torch.Tensor) -> None:
20 | assert master_tensor.is_cuda
21 | self.master = master_tensor
22 | self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
23 |
24 | def get(self, device) -> torch.Tensor:
25 | retval = self._per_device_tensors.get(device, None)
26 | if retval is None:
27 | retval = self.master.to(device=device, non_blocking=True, copy=True)
28 | self._per_device_tensors[device] = retval
29 | return retval
30 |
31 |
32 | class MaxClipGradScaler(GradScaler):
33 | def __init__(self, init_scale, max_scale: float, growth_interval=100):
34 | GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval)
35 | self.max_scale = max_scale
36 |
37 | def scale_clip(self):
38 | if self.get_scale() == self.max_scale:
39 | self.set_growth_factor(1)
40 | elif self.get_scale() < self.max_scale:
41 | self.set_growth_factor(2)
42 | elif self.get_scale() > self.max_scale:
43 | self._scale.fill_(self.max_scale)
44 | self.set_growth_factor(1)
45 |
46 | def scale(self, outputs):
47 | """
48 | Multiplies ('scales') a tensor or list of tensors by the scale factor.
49 |
50 | Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
51 | unmodified.
52 |
53 | Arguments:
54 | outputs (Tensor or iterable of Tensors): Outputs to scale.
55 | """
56 | if not self._enabled:
57 | return outputs
58 | self.scale_clip()
59 | # Short-circuit for the common case.
60 | if isinstance(outputs, torch.Tensor):
61 | assert outputs.is_cuda
62 | if self._scale is None:
63 | self._lazy_init_scale_growth_tracker(outputs.device)
64 | assert self._scale is not None
65 | return outputs * self._scale.to(device=outputs.device, non_blocking=True)
66 |
67 | # Invoke the more complex machinery only if we're treating multiple outputs.
68 | stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale
69 |
70 | def apply_scale(val):
71 | if isinstance(val, torch.Tensor):
72 | assert val.is_cuda
73 | if len(stash) == 0:
74 | if self._scale is None:
75 | self._lazy_init_scale_growth_tracker(val.device)
76 | assert self._scale is not None
77 | stash.append(_MultiDeviceReplicator(self._scale))
78 | return val * stash[0].get(val.device)
79 | elif isinstance(val, Iterable):
80 | iterable = map(apply_scale, val)
81 | if isinstance(val, list) or isinstance(val, tuple):
82 | return type(val)(iterable)
83 | else:
84 | return iterable
85 | else:
86 | raise ValueError("outputs must be a Tensor or an iterable of Tensors")
87 |
88 | return apply_scale(outputs)
89 |
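90 | # Usage sketch (values are illustrative; `loss` and `optimizer` are placeholders):
91 | # behaves like torch.cuda.amp.GradScaler, but never lets the loss scale grow past max_scale.
92 | # scaler = MaxClipGradScaler(init_scale=2 ** 10, max_scale=2 ** 14)
93 | # scaler.scale(loss).backward()
94 | # scaler.step(optimizer); scaler.update()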
--------------------------------------------------------------------------------
/SadTalker/src/face3d/data/template_dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset class template
2 |
3 | This module provides a template for users to implement custom datasets.
4 | You can specify '--dataset_mode template' to use this dataset.
5 | The class name should be consistent with both the filename and its dataset_mode option.
6 | The filename should be <dataset_mode>_dataset.py
7 | The class name should be <DatasetMode>Dataset.py
8 | You need to implement the following functions:
9 |     -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10 | -- <__init__>: Initialize this dataset class.
11 | -- <__getitem__>: Return a data point and its metadata information.
12 | -- <__len__>: Return the number of images.
13 | """
14 | from data.base_dataset import BaseDataset, get_transform
15 | # from data.image_folder import make_dataset
16 | # from PIL import Image
17 |
18 |
19 | class TemplateDataset(BaseDataset):
20 | """A template dataset class for you to implement custom datasets."""
21 | @staticmethod
22 | def modify_commandline_options(parser, is_train):
23 | """Add new dataset-specific options, and rewrite default values for existing options.
24 |
25 | Parameters:
26 | parser -- original option parser
27 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28 |
29 | Returns:
30 | the modified parser.
31 | """
32 | parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33 | parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
34 | return parser
35 |
36 | def __init__(self, opt):
37 | """Initialize this dataset class.
38 |
39 | Parameters:
40 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41 |
42 | A few things can be done here.
43 | - save the options (have been done in BaseDataset)
44 | - get image paths and meta information of the dataset.
45 | - define the image transformation.
46 | """
47 | # save the option and dataset root
48 | BaseDataset.__init__(self, opt)
49 | # get the image paths of your dataset;
50 | self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51 | # define the default transform function. You can use ; You can also define your custom transform function
52 | self.transform = get_transform(opt)
53 |
54 | def __getitem__(self, index):
55 | """Return a data point and its metadata information.
56 |
57 | Parameters:
58 | index -- a random integer for data indexing
59 |
60 | Returns:
61 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
62 |
63 | Step 1: get a random image path: e.g., path = self.image_paths[index]
64 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65 | Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
66 | Step 4: return a data point as a dictionary.
67 | """
68 | path = 'temp' # needs to be a string
69 | data_A = None # needs to be a tensor
70 | data_B = None # needs to be a tensor
71 | return {'data_A': data_A, 'data_B': data_B, 'path': path}
72 |
73 | def __len__(self):
74 | """Return the total number of images."""
75 | return len(self.image_paths)
76 |
--------------------------------------------------------------------------------
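A minimal concrete dataset built from the template above might look like the sketch below. It is illustrative only: the class name, the ".png"-folder layout and the use of opt.dataroot are assumptions, not part of this repository.

import os
from PIL import Image
from data.base_dataset import BaseDataset, get_transform

class PngFolderDataset(BaseDataset):
    """Hypothetical dataset that loads every .png found under opt.dataroot."""

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)        # stores the options and dataset root
        self.image_paths = sorted(
            os.path.join(opt.dataroot, f)
            for f in os.listdir(opt.dataroot) if f.endswith('.png')
        )
        self.transform = get_transform(opt)    # default transform pipeline

    def __getitem__(self, index):
        path = self.image_paths[index]
        image = Image.open(path).convert('RGB')
        return {'data_A': self.transform(image), 'path': path}

    def __len__(self):
        return len(self.image_paths)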
/SadTalker/src/face3d/options/train_options.py:
--------------------------------------------------------------------------------
1 | """This script contains the training options for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | from .base_options import BaseOptions
5 | from util import util
6 |
7 | class TrainOptions(BaseOptions):
8 | """This class includes training options.
9 |
10 | It also includes shared options defined in BaseOptions.
11 | """
12 |
13 | def initialize(self, parser):
14 | parser = BaseOptions.initialize(self, parser)
15 | # dataset parameters
16 | # for train
17 | parser.add_argument('--data_root', type=str, default='./', help='dataset root')
18 | parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set')
19 | parser.add_argument('--batch_size', type=int, default=32)
20 | parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]')
21 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
22 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
23 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
24 | parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]')
25 | parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether to use data augmentation')
26 |
27 | # for val
28 | parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set')
29 | parser.add_argument('--batch_size_val', type=int, default=32)
30 |
31 |
32 | # visualization parameters
33 | parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen')
34 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
35 |
36 | # network saving and loading parameters
37 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
38 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
39 | parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq')
40 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
41 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
42 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
43 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
44 | parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
45 |
46 | # training parameters
47 | parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate')
48 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
49 | parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')
50 | parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epochs')
51 |
52 | self.isTrain = True
53 | return parser
54 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/util/preprocess.py:
--------------------------------------------------------------------------------
1 | """This script contains the image preprocessing code for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import numpy as np
5 | from scipy.io import loadmat
6 | from PIL import Image
7 | import cv2
8 | import os
9 | from skimage import transform as trans
10 | import torch
11 | import warnings
12 |
13 | warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
14 | warnings.filterwarnings("ignore", category=FutureWarning)
15 |
16 |
17 | # calculating least square problem for image alignment
18 | def POS(xp, x):
19 | npts = xp.shape[1]
20 |
21 | A = np.zeros([2 * npts, 8])
22 |
23 | A[0 : 2 * npts - 1 : 2, 0:3] = x.transpose()
24 | A[0 : 2 * npts - 1 : 2, 3] = 1
25 |
26 | A[1 : 2 * npts : 2, 4:7] = x.transpose()
27 | A[1 : 2 * npts : 2, 7] = 1
28 |
29 | b = np.reshape(xp.transpose(), [2 * npts, 1])
30 |
31 | k, _, _, _ = np.linalg.lstsq(A, b)
32 |
33 | R1 = k[0:3]
34 | R2 = k[4:7]
35 | sTx = k[3]
36 | sTy = k[7]
37 | s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2
38 | t = np.stack([sTx, sTy], axis=0)
39 |
40 | return t, s
41 |
42 |
43 | # resize and crop images for face reconstruction
44 | def resize_n_crop_img(img, lm, t, s, target_size=224.0, mask=None):
45 | w0, h0 = img.size
46 | w = (w0 * s).astype(np.int32)
47 | h = (h0 * s).astype(np.int32)
48 | left = (w / 2 - target_size / 2 + float((t[0] - w0 / 2) * s)).astype(np.int32)
49 | right = left + target_size
50 | up = (h / 2 - target_size / 2 + float((h0 / 2 - t[1]) * s)).astype(np.int32)
51 | below = up + target_size
52 |
53 | img = img.resize((w, h), resample=Image.BICUBIC)
54 | img = img.crop((left, up, right, below))
55 |
56 | if mask is not None:
57 | mask = mask.resize((w, h), resample=Image.BICUBIC)
58 | mask = mask.crop((left, up, right, below))
59 |
60 | lm = np.stack([lm[:, 0] - t[0] + w0 / 2, lm[:, 1] - t[1] + h0 / 2], axis=1) * s
61 | lm = lm - np.reshape(
62 | np.array([(w / 2 - target_size / 2), (h / 2 - target_size / 2)]), [1, 2]
63 | )
64 |
65 | return img, lm, mask
66 |
67 |
68 | # utils for face reconstruction
69 | def extract_5p(lm):
70 | lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
71 | lm5p = np.stack(
72 | [
73 | lm[lm_idx[0], :],
74 | np.mean(lm[lm_idx[[1, 2]], :], 0),
75 | np.mean(lm[lm_idx[[3, 4]], :], 0),
76 | lm[lm_idx[5], :],
77 | lm[lm_idx[6], :],
78 | ],
79 | axis=0,
80 | )
81 | lm5p = lm5p[[1, 2, 0, 3, 4], :]
82 | return lm5p
83 |
84 |
85 | # utils for face reconstruction
86 | def align_img(img, lm, lm3D, mask=None, target_size=224.0, rescale_factor=102.0):
87 | """
88 | Return:
89 | transparams --numpy.array (raw_W, raw_H, scale, tx, ty)
90 | img_new --PIL.Image (target_size, target_size, 3)
91 | lm_new --numpy.array (68, 2), y direction is opposite to v direction
92 | mask_new --PIL.Image (target_size, target_size)
93 |
94 | Parameters:
95 | img --PIL.Image (raw_H, raw_W, 3)
96 | lm --numpy.array (68, 2), y direction is opposite to v direction
97 | lm3D --numpy.array (5, 3)
98 | mask --PIL.Image (raw_H, raw_W, 3)
99 | """
100 |
101 | w0, h0 = img.size
102 | if lm.shape[0] != 5:
103 | lm5p = extract_5p(lm)
104 | else:
105 | lm5p = lm
106 |
107 | # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face
108 | t, s = POS(lm5p.transpose(), lm3D.transpose())
109 | s = rescale_factor / s
110 |
111 | # processing the image
112 | img_new, lm_new, mask_new = resize_n_crop_img(
113 | img, lm, t, s, target_size=target_size, mask=mask
114 | )
115 | trans_params = np.array([w0, h0, s, t[0].item(), t[1].item()], dtype=np.float32)
116 |
117 | return trans_params, img_new, lm_new, mask_new
118 |
--------------------------------------------------------------------------------
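POS solves a small orthographic least-squares problem: given 2D image landmarks xp (2 x N) and the matching 3D template points x (3 x N), it recovers a global scale and a 2D translation. A self-contained check with synthetic points (the transform values are arbitrary, and the import assumes the repository root is on PYTHONPATH):

import numpy as np
from src.face3d.util.preprocess import POS

# five synthetic 3D landmarks projected with a known scale and translation
x = np.random.rand(3, 5)
s_true, t_true = 2.0, np.array([10.0, -4.0])
xp = s_true * x[:2, :] + t_true[:, None]   # orthographic projection, as POS assumes

t_est, s_est = POS(xp, x)
print(round(float(s_est), 4))              # ~2.0
print(t_est.ravel())                       # ~[10. -4.]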
/SadTalker/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
162 | examples/results/*
163 | gfpgan/*
164 | checkpoints/*
165 | assets/*
166 | results/*
167 | Dockerfile
168 | start_docker.sh
169 | start.sh
170 |
171 | checkpoints
172 |
173 | # Mac
174 | .DS_Store
175 |
--------------------------------------------------------------------------------
/SadTalker/src/audio2pose_models/audio2pose.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from src.audio2pose_models.cvae import CVAE
4 | from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
5 | from src.audio2pose_models.audio_encoder import AudioEncoder
6 |
7 | class Audio2Pose(nn.Module):
8 | def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
9 | super().__init__()
10 | self.cfg = cfg
11 | self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
12 | self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
13 | self.device = device
14 |
15 | self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
16 | self.audio_encoder.eval()
17 | for param in self.audio_encoder.parameters():
18 | param.requires_grad = False
19 |
20 | self.netG = CVAE(cfg)
21 | self.netD_motion = PoseSequenceDiscriminator(cfg)
22 |
23 |
24 | def forward(self, x):
25 |
26 | batch = {}
27 | coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
28 | batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
29 | batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6
30 | batch['class'] = x['class'].squeeze(0).cuda() # bs
31 | indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
32 |
33 | # forward
34 | audio_emb_list = []
35 | audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
36 | batch['audio_emb'] = audio_emb
37 | batch = self.netG(batch)
38 |
39 | pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
40 | pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6
41 | pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6
42 |
43 | batch['pose_pred'] = pose_pred
44 | batch['pose_gt'] = pose_gt
45 |
46 | return batch
47 |
48 | def test(self, x):
49 |
50 | batch = {}
51 | ref = x['ref'] #bs 1 70
52 | batch['ref'] = x['ref'][:,0,-6:]
53 | batch['class'] = x['class']
54 | bs = ref.shape[0]
55 |
56 | indiv_mels= x['indiv_mels'] # bs T 1 80 16
57 | indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
58 | num_frames = x['num_frames']
59 | num_frames = int(num_frames) - 1
60 |
61 | #
62 | div = num_frames//self.seq_len
63 | re = num_frames%self.seq_len
64 | audio_emb_list = []
65 | pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
66 | device=batch['ref'].device)]
67 |
68 | for i in range(div):
69 | z = torch.randn(bs, self.latent_dim).to(ref.device)
70 | batch['z'] = z
71 | audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
72 | batch['audio_emb'] = audio_emb
73 | batch = self.netG.test(batch)
74 | pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
75 |
76 | if re != 0:
77 | z = torch.randn(bs, self.latent_dim).to(ref.device)
78 | batch['z'] = z
79 | audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
80 | if audio_emb.shape[1] != self.seq_len:
81 | pad_dim = self.seq_len-audio_emb.shape[1]
82 | pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
83 | audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
84 | batch['audio_emb'] = audio_emb
85 | batch = self.netG.test(batch)
86 | pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
87 |
88 | pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
89 | batch['pose_motion_pred'] = pose_motion_pred
90 |
91 | pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
92 |
93 | batch['pose_pred'] = pose_pred
94 | return batch
95 |
--------------------------------------------------------------------------------
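The test() path above predicts poses window by window: num_frames (minus the reference frame) is split into full seq_len chunks, and the remainder is handled by re-running the last seq_len frames and keeping only the final `re` predictions. A short sketch of that index arithmetic (the seq_len value is an assumption standing in for cfg.MODEL.CVAE.SEQ_LEN):

seq_len = 32                 # assumed value of cfg.MODEL.CVAE.SEQ_LEN
num_frames = 70 - 1          # the reference frame is excluded, as in the code above

div, re = num_frames // seq_len, num_frames % seq_len
windows = [(i * seq_len, (i + 1) * seq_len) for i in range(div)]
if re != 0:
    # the last window is right-aligned; only its trailing `re` outputs are kept,
    # so no frame is predicted twice
    windows.append((num_frames - seq_len, num_frames))
print(windows)               # [(0, 32), (32, 64), (37, 69)]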
/SadTalker/src/face3d/util/test_mean_face.txt:
--------------------------------------------------------------------------------
1 | -5.228591537475585938e+01
2 | 2.078247070312500000e-01
3 | -5.064269638061523438e+01
4 | -1.315765380859375000e+01
5 | -4.952939224243164062e+01
6 | -2.592591094970703125e+01
7 | -4.793047332763671875e+01
8 | -3.832135772705078125e+01
9 | -4.512159729003906250e+01
10 | -5.059623336791992188e+01
11 | -3.917720794677734375e+01
12 | -6.043736648559570312e+01
13 | -2.929953765869140625e+01
14 | -6.861183166503906250e+01
15 | -1.719801330566406250e+01
16 | -7.572736358642578125e+01
17 | -1.961936950683593750e+00
18 | -7.862001037597656250e+01
19 | 1.467941284179687500e+01
20 | -7.607844543457031250e+01
21 | 2.744073486328125000e+01
22 | -6.915261840820312500e+01
23 | 3.855677795410156250e+01
24 | -5.950350570678710938e+01
25 | 4.478240966796875000e+01
26 | -4.867547225952148438e+01
27 | 4.714337158203125000e+01
28 | -3.800830078125000000e+01
29 | 4.940315246582031250e+01
30 | -2.496297454833984375e+01
31 | 5.117234802246093750e+01
32 | -1.241538238525390625e+01
33 | 5.190507507324218750e+01
34 | 8.244247436523437500e-01
35 | -4.150688934326171875e+01
36 | 2.386329650878906250e+01
37 | -3.570307159423828125e+01
38 | 3.017010498046875000e+01
39 | -2.790358734130859375e+01
40 | 3.212951660156250000e+01
41 | -1.941773223876953125e+01
42 | 3.156523132324218750e+01
43 | -1.138106536865234375e+01
44 | 2.841992187500000000e+01
45 | 5.993263244628906250e+00
46 | 2.895182800292968750e+01
47 | 1.343590545654296875e+01
48 | 3.189880371093750000e+01
49 | 2.203153991699218750e+01
50 | 3.302221679687500000e+01
51 | 2.992478942871093750e+01
52 | 3.099150085449218750e+01
53 | 3.628388977050781250e+01
54 | 2.765748596191406250e+01
55 | -1.933914184570312500e+00
56 | 1.405374145507812500e+01
57 | -2.153038024902343750e+00
58 | 5.772636413574218750e+00
59 | -2.270050048828125000e+00
60 | -2.121643066406250000e+00
61 | -2.218330383300781250e+00
62 | -1.068978118896484375e+01
63 | -1.187252044677734375e+01
64 | -1.997912597656250000e+01
65 | -6.879402160644531250e+00
66 | -2.143579864501953125e+01
67 | -1.227821350097656250e+00
68 | -2.193494415283203125e+01
69 | 4.623237609863281250e+00
70 | -2.152721405029296875e+01
71 | 9.721397399902343750e+00
72 | -1.953671264648437500e+01
73 | -3.648714447021484375e+01
74 | 9.811126708984375000e+00
75 | -3.130242919921875000e+01
76 | 1.422447967529296875e+01
77 | -2.212834930419921875e+01
78 | 1.493019866943359375e+01
79 | -1.500880432128906250e+01
80 | 1.073588562011718750e+01
81 | -2.095037078857421875e+01
82 | 9.054298400878906250e+00
83 | -3.050099182128906250e+01
84 | 8.704177856445312500e+00
85 | 1.173237609863281250e+01
86 | 1.054329681396484375e+01
87 | 1.856353759765625000e+01
88 | 1.535009765625000000e+01
89 | 2.893331909179687500e+01
90 | 1.451992797851562500e+01
91 | 3.452944946289062500e+01
92 | 1.065280151367187500e+01
93 | 2.875990295410156250e+01
94 | 8.654792785644531250e+00
95 | 1.942100524902343750e+01
96 | 9.422447204589843750e+00
97 | -2.204488372802734375e+01
98 | -3.983994293212890625e+01
99 | -1.324458312988281250e+01
100 | -3.467377471923828125e+01
101 | -6.749649047851562500e+00
102 | -3.092894744873046875e+01
103 | -9.183349609375000000e-01
104 | -3.196458435058593750e+01
105 | 4.220649719238281250e+00
106 | -3.090406036376953125e+01
107 | 1.089889526367187500e+01
108 | -3.497008514404296875e+01
109 | 1.874589538574218750e+01
110 | -4.065438079833984375e+01
111 | 1.124106597900390625e+01
112 | -4.438417816162109375e+01
113 | 5.181709289550781250e+00
114 | -4.649170684814453125e+01
115 | -1.158607482910156250e+00
116 | -4.680406951904296875e+01
117 | -7.918922424316406250e+00
118 | -4.671575164794921875e+01
119 | -1.452505493164062500e+01
120 | -4.416526031494140625e+01
121 | -2.005007171630859375e+01
122 | -3.997841644287109375e+01
123 | -1.054919433593750000e+01
124 | -3.849683380126953125e+01
125 | -1.051826477050781250e+00
126 | -3.794863128662109375e+01
127 | 6.412681579589843750e+00
128 | -3.804645538330078125e+01
129 | 1.627674865722656250e+01
130 | -4.039697265625000000e+01
131 | 6.373878479003906250e+00
132 | -4.087213897705078125e+01
133 | -8.551712036132812500e-01
134 | -4.157129669189453125e+01
135 | -1.014953613281250000e+01
136 | -4.128469085693359375e+01
137 |
--------------------------------------------------------------------------------
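The file above is a flat list of 136 numbers, i.e. the (x, y) coordinates of 68 mean-face landmarks. detect_lm68.py loads it with np.loadtxt and reshapes it, as in this short sketch (the relative path assumes the repository root as working directory):

import numpy as np

mean_face = np.loadtxt('src/face3d/util/test_mean_face.txt')   # 136 values
mean_face = mean_face.reshape([68, 2])                         # 68 landmarks, (x, y)
print(mean_face.shape)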
/SadTalker/src/face3d/extract_kp_videos.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import time
4 | import glob
5 | import argparse
6 | import face_alignment
7 | import numpy as np
8 | from PIL import Image
9 | from tqdm import tqdm
10 | from itertools import cycle
11 |
12 | from torch.multiprocessing import Pool, Process, set_start_method
13 |
14 | class KeypointExtractor():
15 | def __init__(self, device):
16 | self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
17 | device=device)
18 |
19 | def extract_keypoint(self, images, name=None, info=True):
20 | if isinstance(images, list):
21 | keypoints = []
22 | if info:
23 | i_range = tqdm(images,desc='landmark Det:')
24 | else:
25 | i_range = images
26 |
27 | for image in i_range:
28 | current_kp = self.extract_keypoint(image)
29 | if np.mean(current_kp) == -1 and keypoints:
30 | keypoints.append(keypoints[-1])
31 | else:
32 | keypoints.append(current_kp[None])
33 |
34 | keypoints = np.concatenate(keypoints, 0)
35 | np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
36 | return keypoints
37 | else:
38 | while True:
39 | try:
40 | keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
41 | break
42 | except RuntimeError as e:
43 | if str(e).startswith('CUDA'):
44 | print("Warning: out of memory, sleep for 1s")
45 | time.sleep(1)
46 | else:
47 | print(e)
48 | break
49 | except TypeError:
50 | print('No face detected in this image')
51 | shape = [68, 2]
52 | keypoints = -1. * np.ones(shape)
53 | break
54 | if name is not None:
55 | np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
56 | return keypoints
57 |
58 | def read_video(filename):
59 | frames = []
60 | cap = cv2.VideoCapture(filename)
61 | while cap.isOpened():
62 | ret, frame = cap.read()
63 | if ret:
64 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
65 | frame = Image.fromarray(frame)
66 | frames.append(frame)
67 | else:
68 | break
69 | cap.release()
70 | return frames
71 |
72 | def run(data):
73 | filename, opt, device = data
74 | os.environ['CUDA_VISIBLE_DEVICES'] = device
75 | kp_extractor = KeypointExtractor(device='cuda')  # __init__ requires a device; CUDA_VISIBLE_DEVICES above selects the GPU
76 | images = read_video(filename)
77 | name = filename.split('/')[-2:]
78 | os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
79 | kp_extractor.extract_keypoint(
80 | images,
81 | name=os.path.join(opt.output_dir, name[-2], name[-1])
82 | )
83 |
84 | if __name__ == '__main__':
85 | set_start_method('spawn')
86 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
87 | parser.add_argument('--input_dir', type=str, help='the folder of the input files')
88 | parser.add_argument('--output_dir', type=str, help='the folder of the output files')
89 | parser.add_argument('--device_ids', type=str, default='0,1')
90 | parser.add_argument('--workers', type=int, default=4)
91 |
92 | opt = parser.parse_args()
93 | filenames = list()
94 | VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
95 | VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
96 | extensions = VIDEO_EXTENSIONS
97 |
98 | for ext in extensions:
99 | # accumulate matches for every extension instead of overwriting the list
100 | print(f'{opt.input_dir}/*.{ext}')
101 | filenames += sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
102 | print('Total number of videos:', len(filenames))
103 | pool = Pool(opt.workers)
104 | args_list = cycle([opt])
105 | device_ids = opt.device_ids.split(",")
106 | device_ids = cycle(device_ids)
107 | for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
108 | pass
109 |
--------------------------------------------------------------------------------
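Besides the multi-process CLI entry point, KeypointExtractor can also be used directly on a single image. A hedged sketch (the device string and the example image path are assumptions):

from PIL import Image
from src.face3d.extract_kp_videos import KeypointExtractor

extractor = KeypointExtractor(device='cuda')          # 'cpu' also works, just slower
image = Image.open('examples/qq.jpg').convert('RGB')  # any portrait image
keypoints = extractor.extract_keypoint(image)         # (68, 2) array, or all -1 if no face is found
print(keypoints.shape)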
/SadTalker/src/face3d/models/arcface_torch/dataset.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import os
3 | import queue as Queue
4 | import threading
5 |
6 | import mxnet as mx
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import DataLoader, Dataset
10 | from torchvision import transforms
11 |
12 |
13 | class BackgroundGenerator(threading.Thread):
14 | def __init__(self, generator, local_rank, max_prefetch=6):
15 | super(BackgroundGenerator, self).__init__()
16 | self.queue = Queue.Queue(max_prefetch)
17 | self.generator = generator
18 | self.local_rank = local_rank
19 | self.daemon = True
20 | self.start()
21 |
22 | def run(self):
23 | torch.cuda.set_device(self.local_rank)
24 | for item in self.generator:
25 | self.queue.put(item)
26 | self.queue.put(None)
27 |
28 | def next(self):
29 | next_item = self.queue.get()
30 | if next_item is None:
31 | raise StopIteration
32 | return next_item
33 |
34 | def __next__(self):
35 | return self.next()
36 |
37 | def __iter__(self):
38 | return self
39 |
40 |
41 | class DataLoaderX(DataLoader):
42 |
43 | def __init__(self, local_rank, **kwargs):
44 | super(DataLoaderX, self).__init__(**kwargs)
45 | self.stream = torch.cuda.Stream(local_rank)
46 | self.local_rank = local_rank
47 |
48 | def __iter__(self):
49 | self.iter = super(DataLoaderX, self).__iter__()
50 | self.iter = BackgroundGenerator(self.iter, self.local_rank)
51 | self.preload()
52 | return self
53 |
54 | def preload(self):
55 | self.batch = next(self.iter, None)
56 | if self.batch is None:
57 | return None
58 | with torch.cuda.stream(self.stream):
59 | for k in range(len(self.batch)):
60 | self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
61 |
62 | def __next__(self):
63 | torch.cuda.current_stream().wait_stream(self.stream)
64 | batch = self.batch
65 | if batch is None:
66 | raise StopIteration
67 | self.preload()
68 | return batch
69 |
70 |
71 | class MXFaceDataset(Dataset):
72 | def __init__(self, root_dir, local_rank):
73 | super(MXFaceDataset, self).__init__()
74 | self.transform = transforms.Compose(
75 | [transforms.ToPILImage(),
76 | transforms.RandomHorizontalFlip(),
77 | transforms.ToTensor(),
78 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
79 | ])
80 | self.root_dir = root_dir
81 | self.local_rank = local_rank
82 | path_imgrec = os.path.join(root_dir, 'train.rec')
83 | path_imgidx = os.path.join(root_dir, 'train.idx')
84 | self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
85 | s = self.imgrec.read_idx(0)
86 | header, _ = mx.recordio.unpack(s)
87 | if header.flag > 0:
88 | self.header0 = (int(header.label[0]), int(header.label[1]))
89 | self.imgidx = np.array(range(1, int(header.label[0])))
90 | else:
91 | self.imgidx = np.array(list(self.imgrec.keys))
92 |
93 | def __getitem__(self, index):
94 | idx = self.imgidx[index]
95 | s = self.imgrec.read_idx(idx)
96 | header, img = mx.recordio.unpack(s)
97 | label = header.label
98 | if not isinstance(label, numbers.Number):
99 | label = label[0]
100 | label = torch.tensor(label, dtype=torch.long)
101 | sample = mx.image.imdecode(img).asnumpy()
102 | if self.transform is not None:
103 | sample = self.transform(sample)
104 | return sample, label
105 |
106 | def __len__(self):
107 | return len(self.imgidx)
108 |
109 |
110 | class SyntheticDataset(Dataset):
111 | def __init__(self, local_rank):
112 | super(SyntheticDataset, self).__init__()
113 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
114 | img = np.transpose(img, (2, 0, 1))
115 | img = torch.from_numpy(img).squeeze(0).float()
116 | img = ((img / 255) - 0.5) / 0.5
117 | self.img = img
118 | self.label = 1
119 |
120 | def __getitem__(self, index):
121 | return self.img, self.label
122 |
123 | def __len__(self):
124 | return 1000000
125 |
--------------------------------------------------------------------------------
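DataLoaderX is a prefetching loader: a background thread pulls batches from the wrapped iterator and copies them onto the target GPU on a side CUDA stream while the previous batch is still being consumed. A usage sketch with the bundled SyntheticDataset (requires a CUDA device; mxnet must also be installed since it is imported at module level):

from src.face3d.models.arcface_torch.dataset import SyntheticDataset, DataLoaderX

local_rank = 0                                   # GPU index
dataset = SyntheticDataset(local_rank)
loader = DataLoaderX(local_rank=local_rank, dataset=dataset,
                     batch_size=64, num_workers=2)

images, labels = next(iter(loader))
print(images.shape, images.device)               # torch.Size([64, 3, 112, 112]) cuda:0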
/SadTalker/src/face3d/util/detect_lm68.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from scipy.io import loadmat
5 | import tensorflow as tf
6 | from util.preprocess import align_for_lm
7 | from shutil import move
8 |
9 | mean_face = np.loadtxt('util/test_mean_face.txt')
10 | mean_face = mean_face.reshape([68, 2])
11 |
12 | def save_label(labels, save_path):
13 | np.savetxt(save_path, labels)
14 |
15 | def draw_landmarks(img, landmark, save_name):
16 | landmark = landmark
17 | lm_img = np.zeros([img.shape[0], img.shape[1], 3])
18 | lm_img[:] = img.astype(np.float32)
19 | landmark = np.round(landmark).astype(np.int32)
20 |
21 | for i in range(len(landmark)):
22 | for j in range(-1, 1):
23 | for k in range(-1, 1):
24 | if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \
25 | img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \
26 | landmark[i, 0]+k > 0 and \
27 | landmark[i, 0]+k < img.shape[1]:
28 | lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k,
29 | :] = np.array([0, 0, 255])
30 | lm_img = lm_img.astype(np.uint8)
31 |
32 | cv2.imwrite(save_name, lm_img)
33 |
34 |
35 | def load_data(img_name, txt_name):
36 | return cv2.imread(img_name), np.loadtxt(txt_name)
37 |
38 | # create tensorflow graph for landmark detector
39 | def load_lm_graph(graph_filename):
40 | with tf.gfile.GFile(graph_filename, 'rb') as f:
41 | graph_def = tf.GraphDef()
42 | graph_def.ParseFromString(f.read())
43 |
44 | with tf.Graph().as_default() as graph:
45 | tf.import_graph_def(graph_def, name='net')
46 | img_224 = graph.get_tensor_by_name('net/input_imgs:0')
47 | output_lm = graph.get_tensor_by_name('net/lm:0')
48 | lm_sess = tf.Session(graph=graph)
49 |
50 | return lm_sess,img_224,output_lm
51 |
52 | # landmark detection
53 | def detect_68p(img_path,sess,input_op,output_op):
54 | print('detecting landmarks......')
55 | names = [i for i in sorted(os.listdir(
56 | img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]
57 | vis_path = os.path.join(img_path, 'vis')
58 | remove_path = os.path.join(img_path, 'remove')
59 | save_path = os.path.join(img_path, 'landmarks')
60 | if not os.path.isdir(vis_path):
61 | os.makedirs(vis_path)
62 | if not os.path.isdir(remove_path):
63 | os.makedirs(remove_path)
64 | if not os.path.isdir(save_path):
65 | os.makedirs(save_path)
66 |
67 | for i in range(0, len(names)):
68 | name = names[i]
69 | print('%05d' % (i), ' ', name)
70 | full_image_name = os.path.join(img_path, name)
71 | txt_name = '.'.join(name.split('.')[:-1]) + '.txt'
72 | full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image
73 |
74 | # if an image does not have detected 5 facial landmarks, remove it from the training list
75 | if not os.path.isfile(full_txt_name):
76 | move(full_image_name, os.path.join(remove_path, name))
77 | continue
78 |
79 | # load data
80 | img, five_points = load_data(full_image_name, full_txt_name)
81 | input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection
82 |
83 | # if the alignment fails, remove corresponding image from the training list
84 | if scale == 0:
85 | move(full_txt_name, os.path.join(
86 | remove_path, txt_name))
87 | move(full_image_name, os.path.join(remove_path, name))
88 | continue
89 |
90 | # detect landmarks
91 | input_img = np.reshape(
92 | input_img, [1, 224, 224, 3]).astype(np.float32)
93 | landmark = sess.run(
94 | output_op, feed_dict={input_op: input_img})
95 |
96 | # transform back to original image coordinate
97 | landmark = landmark.reshape([68, 2]) + mean_face
98 | landmark[:, 1] = 223 - landmark[:, 1]
99 | landmark = landmark / scale
100 | landmark[:, 0] = landmark[:, 0] + bbox[0]
101 | landmark[:, 1] = landmark[:, 1] + bbox[1]
102 | landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1]
103 |
104 | if i % 100 == 0:
105 | draw_landmarks(img, landmark, os.path.join(vis_path, name))
106 | save_label(landmark, os.path.join(save_path, txt_name))
107 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/data/flist_dataset.py:
--------------------------------------------------------------------------------
1 | """This script defines the custom dataset for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import os.path
5 | from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
6 | from data.image_folder import make_dataset
7 | from PIL import Image
8 | import random
9 | import util.util as util
10 | import numpy as np
11 | import json
12 | import torch
13 | from scipy.io import loadmat, savemat
14 | import pickle
15 | from util.preprocess import align_img, estimate_norm
16 | from util.load_mats import load_lm3d
17 |
18 |
19 | def default_flist_reader(flist):
20 | """
21 | flist format: impath label\nimpath label\n ...(same to caffe's filelist)
22 | """
23 | imlist = []
24 | with open(flist, 'r') as rf:
25 | for line in rf.readlines():
26 | impath = line.strip()
27 | imlist.append(impath)
28 |
29 | return imlist
30 |
31 | def jason_flist_reader(flist):
32 | with open(flist, 'r') as fp:
33 | info = json.load(fp)
34 | return info
35 |
36 | def parse_label(label):
37 | return torch.tensor(np.array(label).astype(np.float32))
38 |
39 |
40 | class FlistDataset(BaseDataset):
41 | """
42 | It requires one directory to host the training images, e.g. '/path/to/data/train'.
43 | You can train the model with the dataset flag '--dataroot /path/to/data'.
44 | """
45 |
46 | def __init__(self, opt):
47 | """Initialize this dataset class.
48 |
49 | Parameters:
50 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
51 | """
52 | BaseDataset.__init__(self, opt)
53 |
54 | self.lm3d_std = load_lm3d(opt.bfm_folder)
55 |
56 | msk_names = default_flist_reader(opt.flist)
57 | self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
58 |
59 | self.size = len(self.msk_paths)
60 | self.opt = opt
61 |
62 | self.name = 'train' if opt.isTrain else 'val'
63 | if '_' in opt.flist:
64 | self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
65 |
66 |
67 | def __getitem__(self, index):
68 | """Return a data point and its metadata information.
69 |
70 | Parameters:
71 | index (int) -- a random integer for data indexing
72 |
73 | Returns a dictionary that contains A, B, A_paths and B_paths
74 | img (tensor) -- an image in the input domain
75 | msk (tensor) -- its corresponding attention mask
76 | lm (tensor) -- its corresponding 3d landmarks
77 | im_paths (str) -- image paths
78 | aug_flag (bool) -- a flag used to tell whether its raw or augmented
79 | """
80 | msk_path = self.msk_paths[index % self.size]  # make sure index is within the range
81 | img_path = msk_path.replace('mask/', '')
82 | lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
83 |
84 | raw_img = Image.open(img_path).convert('RGB')
85 | raw_msk = Image.open(msk_path).convert('RGB')
86 | raw_lm = np.loadtxt(lm_path).astype(np.float32)
87 |
88 | _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
89 |
90 | aug_flag = self.opt.use_aug and self.opt.isTrain
91 | if aug_flag:
92 | img, lm, msk = self._augmentation(img, lm, self.opt, msk)
93 |
94 | _, H = img.size
95 | M = estimate_norm(lm, H)
96 | transform = get_transform()
97 | img_tensor = transform(img)
98 | msk_tensor = transform(msk)[:1, ...]
99 | lm_tensor = parse_label(lm)
100 | M_tensor = parse_label(M)
101 |
102 |
103 | return {'imgs': img_tensor,
104 | 'lms': lm_tensor,
105 | 'msks': msk_tensor,
106 | 'M': M_tensor,
107 | 'im_paths': img_path,
108 | 'aug_flag': aug_flag,
109 | 'dataset': self.name}
110 |
111 | def _augmentation(self, img, lm, opt, msk=None):
112 | affine, affine_inv, flip = get_affine_mat(opt, img.size)
113 | img = apply_img_affine(img, affine_inv)
114 | lm = apply_lm_affine(lm, affine, flip, img.size)
115 | if msk is not None:
116 | msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
117 | return img, lm, msk
118 |
119 |
120 |
121 |
122 | def __len__(self):
123 | """Return the total number of images in the dataset.
124 | """
125 | return self.size
126 |
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from kornia.geometry import warp_affine
5 | import torch.nn.functional as F
6 |
7 | def resize_n_crop(image, M, dsize=112):
8 | # image: (b, c, h, w)
9 | # M : (b, 2, 3)
10 | return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True)
11 |
12 | ### perceptual level loss
13 | class PerceptualLoss(nn.Module):
14 | def __init__(self, recog_net, input_size=112):
15 | super(PerceptualLoss, self).__init__()
16 | self.recog_net = recog_net
17 | self.preprocess = lambda x: 2 * x - 1
18 | self.input_size=input_size
19 | def forward(self, imageA, imageB, M):
20 | """
21 | 1 - cosine distance
22 | Parameters:
23 | imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order
24 | imageB --same as imageA
25 | """
26 |
27 | imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))
28 | imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))
29 |
30 | # freeze bn
31 | self.recog_net.eval()
32 |
33 | id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)
34 | id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)
35 | cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
36 | # assert torch.sum((cosine_d > 1).float()) == 0
37 | return torch.sum(1 - cosine_d) / cosine_d.shape[0]
38 |
39 | def perceptual_loss(id_featureA, id_featureB):
40 | cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
41 | # assert torch.sum((cosine_d > 1).float()) == 0
42 | return torch.sum(1 - cosine_d) / cosine_d.shape[0]
43 |
44 | ### image level loss
45 | def photo_loss(imageA, imageB, mask, eps=1e-6):
46 | """
47 | l2 norm (with sqrt; to ensure backward stability, use eps, otherwise NaN may occur)
48 | Parameters:
49 | imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
50 | imageB --same as imageA
51 | """
52 | loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask
53 | loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device))
54 | return loss
55 |
56 | def landmark_loss(predict_lm, gt_lm, weight=None):
57 | """
58 | weighted mse loss
59 | Parameters:
60 | predict_lm --torch.tensor (B, 68, 2)
61 | gt_lm --torch.tensor (B, 68, 2)
62 | weight --numpy.array (1, 68)
63 | """
64 | if not weight:
65 | weight = np.ones([68])
66 | weight[28:31] = 20
67 | weight[-8:] = 20
68 | weight = np.expand_dims(weight, 0)
69 | weight = torch.tensor(weight).to(predict_lm.device)
70 | loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight
71 | loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])
72 | return loss
73 |
74 |
75 | ### regularization
76 | def reg_loss(coeffs_dict, opt=None):
77 | """
78 | l2 norm without the sqrt, from yu's implementation (mse)
79 | tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss
80 | Parameters:
81 | coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans
82 |
83 | """
84 | # coefficient regularization to ensure plausible 3d faces
85 | if opt:
86 | w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex
87 | else:
88 | w_id, w_exp, w_tex = 1, 1, 1
89 | creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \
90 | w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \
91 | w_tex * torch.sum(coeffs_dict['tex'] ** 2)
92 | creg_loss = creg_loss / coeffs_dict['id'].shape[0]
93 |
94 | # gamma regularization to ensure a nearly-monochromatic light
95 | gamma = coeffs_dict['gamma'].reshape([-1, 3, 9])
96 | gamma_mean = torch.mean(gamma, dim=1, keepdims=True)
97 | gamma_loss = torch.mean((gamma - gamma_mean) ** 2)
98 |
99 | return creg_loss, gamma_loss
100 |
101 | def reflectance_loss(texture, mask):
102 | """
103 | minimize texture variance (mse), albedo regularization to ensure a uniform skin albedo
104 | Parameters:
105 | texture --torch.tensor, (B, N, 3)
106 | mask --torch.tensor, (N), 1 or 0
107 |
108 | """
109 | mask = mask.reshape([1, mask.shape[0], 1])
110 | texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask)
111 | loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask))
112 | return loss
113 |
114 |
--------------------------------------------------------------------------------
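A shape-level sketch of the image and landmark losses above, using random tensors purely to show the expected argument shapes (kornia must be installed, since it is imported at module level):

import torch
from src.face3d.models.losses import photo_loss, landmark_loss

imgA = torch.rand(2, 3, 224, 224)     # (B, 3, H, W), values in (0, 1)
imgB = torch.rand(2, 3, 224, 224)
mask = torch.ones(2, 1, 224, 224)     # skin/attention mask
print(photo_loss(imgA, imgB, mask))   # masked per-pixel L2, scalar tensor

pred_lm = torch.rand(2, 68, 2) * 224  # (B, 68, 2) landmark coordinates
gt_lm = torch.rand(2, 68, 2) * 224
print(landmark_loss(pred_lm, gt_lm))  # weighted MSE, scalar tensor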
/SadTalker/src/audio2pose_models/networks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 | class ResidualConv(nn.Module):
6 | def __init__(self, input_dim, output_dim, stride, padding):
7 | super(ResidualConv, self).__init__()
8 |
9 | self.conv_block = nn.Sequential(
10 | nn.BatchNorm2d(input_dim),
11 | nn.ReLU(),
12 | nn.Conv2d(
13 | input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
14 | ),
15 | nn.BatchNorm2d(output_dim),
16 | nn.ReLU(),
17 | nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
18 | )
19 | self.conv_skip = nn.Sequential(
20 | nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
21 | nn.BatchNorm2d(output_dim),
22 | )
23 |
24 | def forward(self, x):
25 |
26 | return self.conv_block(x) + self.conv_skip(x)
27 |
28 |
29 | class Upsample(nn.Module):
30 | def __init__(self, input_dim, output_dim, kernel, stride):
31 | super(Upsample, self).__init__()
32 |
33 | self.upsample = nn.ConvTranspose2d(
34 | input_dim, output_dim, kernel_size=kernel, stride=stride
35 | )
36 |
37 | def forward(self, x):
38 | return self.upsample(x)
39 |
40 |
41 | class Squeeze_Excite_Block(nn.Module):
42 | def __init__(self, channel, reduction=16):
43 | super(Squeeze_Excite_Block, self).__init__()
44 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
45 | self.fc = nn.Sequential(
46 | nn.Linear(channel, channel // reduction, bias=False),
47 | nn.ReLU(inplace=True),
48 | nn.Linear(channel // reduction, channel, bias=False),
49 | nn.Sigmoid(),
50 | )
51 |
52 | def forward(self, x):
53 | b, c, _, _ = x.size()
54 | y = self.avg_pool(x).view(b, c)
55 | y = self.fc(y).view(b, c, 1, 1)
56 | return x * y.expand_as(x)
57 |
58 |
59 | class ASPP(nn.Module):
60 | def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
61 | super(ASPP, self).__init__()
62 |
63 | self.aspp_block1 = nn.Sequential(
64 | nn.Conv2d(
65 | in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
66 | ),
67 | nn.ReLU(inplace=True),
68 | nn.BatchNorm2d(out_dims),
69 | )
70 | self.aspp_block2 = nn.Sequential(
71 | nn.Conv2d(
72 | in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
73 | ),
74 | nn.ReLU(inplace=True),
75 | nn.BatchNorm2d(out_dims),
76 | )
77 | self.aspp_block3 = nn.Sequential(
78 | nn.Conv2d(
79 | in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
80 | ),
81 | nn.ReLU(inplace=True),
82 | nn.BatchNorm2d(out_dims),
83 | )
84 |
85 | self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
86 | self._init_weights()
87 |
88 | def forward(self, x):
89 | x1 = self.aspp_block1(x)
90 | x2 = self.aspp_block2(x)
91 | x3 = self.aspp_block3(x)
92 | out = torch.cat([x1, x2, x3], dim=1)
93 | return self.output(out)
94 |
95 | def _init_weights(self):
96 | for m in self.modules():
97 | if isinstance(m, nn.Conv2d):
98 | nn.init.kaiming_normal_(m.weight)
99 | elif isinstance(m, nn.BatchNorm2d):
100 | m.weight.data.fill_(1)
101 | m.bias.data.zero_()
102 |
103 |
104 | class Upsample_(nn.Module):
105 | def __init__(self, scale=2):
106 | super(Upsample_, self).__init__()
107 |
108 | self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
109 |
110 | def forward(self, x):
111 | return self.upsample(x)
112 |
113 |
114 | class AttentionBlock(nn.Module):
115 | def __init__(self, input_encoder, input_decoder, output_dim):
116 | super(AttentionBlock, self).__init__()
117 |
118 | self.conv_encoder = nn.Sequential(
119 | nn.BatchNorm2d(input_encoder),
120 | nn.ReLU(),
121 | nn.Conv2d(input_encoder, output_dim, 3, padding=1),
122 | nn.MaxPool2d(2, 2),
123 | )
124 |
125 | self.conv_decoder = nn.Sequential(
126 | nn.BatchNorm2d(input_decoder),
127 | nn.ReLU(),
128 | nn.Conv2d(input_decoder, output_dim, 3, padding=1),
129 | )
130 |
131 | self.conv_attn = nn.Sequential(
132 | nn.BatchNorm2d(output_dim),
133 | nn.ReLU(),
134 | nn.Conv2d(output_dim, 1, 1),
135 | )
136 |
137 | def forward(self, x1, x2):
138 | out = self.conv_encoder(x1) + self.conv_decoder(x2)
139 | out = self.conv_attn(out)
140 | return out * x2
--------------------------------------------------------------------------------
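A quick smoke test of the building blocks above on a random feature map (the shapes are arbitrary):

import torch
from src.audio2pose_models.networks import ResidualConv, ASPP

x = torch.randn(2, 32, 64, 64)                     # (batch, channels, H, W)
block = ResidualConv(32, 64, stride=2, padding=1)  # downsamples by 2
aspp = ASPP(in_dims=64, out_dims=64)               # multi-dilation context module

y = block(x)
z = aspp(y)
print(y.shape, z.shape)                            # (2, 64, 32, 32) for both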
/SadTalker/src/face3d/util/load_mats.py:
--------------------------------------------------------------------------------
1 | """This script is to load 3D face model for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import numpy as np
5 | from PIL import Image
6 | from scipy.io import loadmat, savemat
7 | from array import array
8 | import os.path as osp
9 |
10 | # load expression basis
11 | def LoadExpBasis(bfm_folder='BFM'):
12 | n_vertex = 53215
13 | Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb')
14 | exp_dim = array('i')
15 | exp_dim.fromfile(Expbin, 1)
16 | expMU = array('f')
17 | expPC = array('f')
18 | expMU.fromfile(Expbin, 3*n_vertex)
19 | expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex)
20 | Expbin.close()
21 |
22 | expPC = np.array(expPC)
23 | expPC = np.reshape(expPC, [exp_dim[0], -1])
24 | expPC = np.transpose(expPC)
25 |
26 | expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt'))
27 |
28 | return expPC, expEV
29 |
30 |
31 | # transfer original BFM09 to our face model
32 | def transferBFM09(bfm_folder='BFM'):
33 | print('Transfer BFM09 to BFM_model_front......')
34 | original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat'))
35 | shapePC = original_BFM['shapePC'] # shape basis
36 | shapeEV = original_BFM['shapeEV'] # corresponding eigen value
37 | shapeMU = original_BFM['shapeMU'] # mean face
38 | texPC = original_BFM['texPC'] # texture basis
39 | texEV = original_BFM['texEV'] # eigen value
40 | texMU = original_BFM['texMU'] # mean texture
41 |
42 | expPC, expEV = LoadExpBasis(bfm_folder)
43 |
44 | # transfer BFM09 to our face model
45 |
46 | idBase = shapePC*np.reshape(shapeEV, [-1, 199])
47 | idBase = idBase/1e5 # unify the scale to decimeter
48 | idBase = idBase[:, :80] # use only first 80 basis
49 |
50 | exBase = expPC*np.reshape(expEV, [-1, 79])
51 | exBase = exBase/1e5 # unify the scale to decimeter
52 | exBase = exBase[:, :64] # use only first 64 basis
53 |
54 | texBase = texPC*np.reshape(texEV, [-1, 199])
55 | texBase = texBase[:, :80] # use only first 80 basis
56 |
57 | # our face model is cropped along face landmarks and contains only 35709 vertex.
58 | # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex.
59 | # thus we select corresponding vertex to get our face model.
60 |
61 | index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat'))
62 | index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215)
63 |
64 | index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat'))
65 | index_shape = index_shape['trimIndex'].astype(
66 | np.int32) - 1 # starts from 0 (to 53490)
67 | index_shape = index_shape[index_exp]
68 |
69 | idBase = np.reshape(idBase, [-1, 3, 80])
70 | idBase = idBase[index_shape, :, :]
71 | idBase = np.reshape(idBase, [-1, 80])
72 |
73 | texBase = np.reshape(texBase, [-1, 3, 80])
74 | texBase = texBase[index_shape, :, :]
75 | texBase = np.reshape(texBase, [-1, 80])
76 |
77 | exBase = np.reshape(exBase, [-1, 3, 64])
78 | exBase = exBase[index_exp, :, :]
79 | exBase = np.reshape(exBase, [-1, 64])
80 |
81 | meanshape = np.reshape(shapeMU, [-1, 3])/1e5
82 | meanshape = meanshape[index_shape, :]
83 | meanshape = np.reshape(meanshape, [1, -1])
84 |
85 | meantex = np.reshape(texMU, [-1, 3])
86 | meantex = meantex[index_shape, :]
87 | meantex = np.reshape(meantex, [1, -1])
88 |
89 | # other info contains triangles, region used for computing photometric loss,
90 | # region used for skin texture regularization, and 68 landmarks index etc.
91 | other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat'))
92 | frontmask2_idx = other_info['frontmask2_idx']
93 | skinmask = other_info['skinmask']
94 | keypoints = other_info['keypoints']
95 | point_buf = other_info['point_buf']
96 | tri = other_info['tri']
97 | tri_mask2 = other_info['tri_mask2']
98 |
99 | # save our face model
100 | savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase,
101 | 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask})
102 |
103 |
104 | # load landmarks for standard face, which is used for image preprocessing
105 | def load_lm3d(bfm_folder):
106 |
107 | Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat'))
108 | Lm3D = Lm3D['lm']
109 |
110 | # calculate 5 facial landmarks using 68 landmarks
111 | lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
112 | Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean(
113 | Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0)
114 | Lm3D = Lm3D[[1, 2, 0, 3, 4], :]
115 |
116 | return Lm3D
117 |
118 |
119 | if __name__ == '__main__':
120 | transferBFM09()
--------------------------------------------------------------------------------
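load_lm3d only needs the folder that contains similarity_Lm3D_all.mat; this repository ships that file under src/config, so a usage sketch (the path relative to the SadTalker root is an assumption) is:

from src.face3d.util.load_mats import load_lm3d

lm3d_std = load_lm3d('src/config')   # folder holding similarity_Lm3D_all.mat
print(lm3d_std.shape)                # (5, 3): five canonical facial landmarks in 3D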
/SadTalker/src/generate_batch.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from tqdm import tqdm
4 | import torch
5 | import numpy as np
6 | import random
7 | import scipy.io as scio
8 | import src.utils.audio as audio
9 |
10 | def crop_pad_audio(wav, audio_length):
11 | if len(wav) > audio_length:
12 | wav = wav[:audio_length]
13 | elif len(wav) < audio_length:
14 | wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0)
15 | return wav
16 |
17 | def parse_audio_length(audio_length, sr, fps):
18 | bit_per_frames = sr / fps
19 |
20 | num_frames = int(audio_length / bit_per_frames)
21 | audio_length = int(num_frames * bit_per_frames)
22 |
23 | return audio_length, num_frames
24 |
25 | def generate_blink_seq(num_frames):
26 | ratio = np.zeros((num_frames,1))
27 | frame_id = 0
28 | while frame_id in range(num_frames):
29 | start = 80
30 | if frame_id+start+9<=num_frames - 1:
31 | ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5]
32 | frame_id = frame_id+start+9
33 | else:
34 | break
35 | return ratio
36 |
37 | def generate_blink_seq_randomly(num_frames):
38 | ratio = np.zeros((num_frames,1))
39 | if num_frames<=20:
40 | return ratio
41 | frame_id = 0
42 | while frame_id in range(num_frames):
43 | start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70)))
44 | if frame_id+start+5<=num_frames - 1:
45 | ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5]
46 | frame_id = frame_id+start+5
47 | else:
48 | break
49 | return ratio
50 |
51 | def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True):
52 |
53 | syncnet_mel_step_size = 16
54 | fps = 25
55 |
56 | pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]
57 | audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
58 |
59 |
60 | if idlemode:
61 | num_frames = int(length_of_audio * 25)
62 | indiv_mels = np.zeros((num_frames, 80, 16))
63 | else:
64 | wav = audio.load_wav(audio_path, 16000)
65 | wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
66 | wav = crop_pad_audio(wav, wav_length)
67 | orig_mel = audio.melspectrogram(wav).T
68 | spec = orig_mel.copy() # nframes 80
69 | indiv_mels = []
70 |
71 | for i in tqdm(range(num_frames), 'mel:'):
72 | start_frame_num = i-2
73 | start_idx = int(80. * (start_frame_num / float(fps)))
74 | end_idx = start_idx + syncnet_mel_step_size
75 | seq = list(range(start_idx, end_idx))
76 | seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ]
77 | m = spec[seq, :]
78 | indiv_mels.append(m.T)
79 | indiv_mels = np.asarray(indiv_mels) # T 80 16
80 |
81 | ratio = generate_blink_seq_randomly(num_frames) # T
82 | source_semantics_path = first_coeff_path
83 | source_semantics_dict = scio.loadmat(source_semantics_path)
84 | ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70
85 | ref_coeff = np.repeat(ref_coeff, num_frames, axis=0)
86 |
87 | if ref_eyeblink_coeff_path is not None:
88 | ratio[:num_frames] = 0
89 | refeyeblink_coeff_dict = scio.loadmat(ref_eyeblink_coeff_path)
90 | refeyeblink_coeff = refeyeblink_coeff_dict['coeff_3dmm'][:,:64]
91 | refeyeblink_num_frames = refeyeblink_coeff.shape[0]
92 | if refeyeblink_num_frames < num_frames:
--------------------------------------------------------------------------------
/SadTalker/src/face3d/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 |
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 | -- <__len__>: return the size of dataset.
7 | -- <__getitem__>: get a data point from data loader.
8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 |
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import numpy as np
14 | import importlib
15 | import torch.utils.data
16 | from face3d.data.base_dataset import BaseDataset
17 |
18 |
19 | def find_dataset_using_name(dataset_name):
20 | """Import the module "data/[dataset_name]_dataset.py".
21 |
22 | In the file, the class called DatasetNameDataset() will
23 | be instantiated. It has to be a subclass of BaseDataset,
24 | and it is case-insensitive.
25 | """
26 | dataset_filename = "data." + dataset_name + "_dataset"
27 | datasetlib = importlib.import_module(dataset_filename)
28 |
29 | dataset = None
30 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
31 | for name, cls in datasetlib.__dict__.items():
32 | if name.lower() == target_dataset_name.lower() \
33 | and issubclass(cls, BaseDataset):
34 | dataset = cls
35 |
36 | if dataset is None:
37 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
38 |
39 | return dataset
40 |
41 |
42 | def get_option_setter(dataset_name):
43 | """Return the static method of the dataset class."""
44 | dataset_class = find_dataset_using_name(dataset_name)
45 | return dataset_class.modify_commandline_options
46 |
47 |
48 | def create_dataset(opt, rank=0):
49 | """Create a dataset given the option.
50 |
51 | This function wraps the class CustomDatasetDataLoader.
52 | This is the main interface between this package and 'train.py'/'test.py'
53 |
54 | Example:
55 | >>> from data import create_dataset
56 | >>> dataset = create_dataset(opt)
57 | """
58 | data_loader = CustomDatasetDataLoader(opt, rank=rank)
59 | dataset = data_loader.load_data()
60 | return dataset
61 |
62 | class CustomDatasetDataLoader():
63 | """Wrapper class of Dataset class that performs multi-threaded data loading"""
64 |
65 | def __init__(self, opt, rank=0):
66 | """Initialize this class
67 |
68 | Step 1: create a dataset instance given the name [dataset_mode]
69 | Step 2: create a multi-threaded data loader.
70 | """
71 | self.opt = opt
72 | dataset_class = find_dataset_using_name(opt.dataset_mode)
73 | self.dataset = dataset_class(opt)
74 | self.sampler = None
75 | print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
76 | if opt.use_ddp and opt.isTrain:
77 | world_size = opt.world_size
78 | self.sampler = torch.utils.data.distributed.DistributedSampler(
79 | self.dataset,
80 | num_replicas=world_size,
81 | rank=rank,
82 | shuffle=not opt.serial_batches
83 | )
84 | self.dataloader = torch.utils.data.DataLoader(
85 | self.dataset,
86 | sampler=self.sampler,
87 | num_workers=int(opt.num_threads / world_size),
88 | batch_size=int(opt.batch_size / world_size),
89 | drop_last=True)
90 | else:
91 | self.dataloader = torch.utils.data.DataLoader(
92 | self.dataset,
93 | batch_size=opt.batch_size,
94 | shuffle=(not opt.serial_batches) and opt.isTrain,
95 | num_workers=int(opt.num_threads),
96 | drop_last=True
97 | )
98 |
99 | def set_epoch(self, epoch):
100 | self.dataset.current_epoch = epoch
101 | if self.sampler is not None:
102 | self.sampler.set_epoch(epoch)
103 |
104 | def load_data(self):
105 | return self
106 |
107 | def __len__(self):
108 | """Return the number of data in the dataset"""
109 | return min(len(self.dataset), self.opt.max_dataset_size)
110 |
111 | def __iter__(self):
112 | """Return a batch of data"""
113 | for i, data in enumerate(self.dataloader):
114 | if i * self.opt.batch_size >= self.opt.max_dataset_size:
115 | break
116 | yield data
117 |
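
A minimal usage sketch for the loader above, assuming a hypothetical `flist` dataset mode (i.e. a `data/flist_dataset.py` defining a `FlistDataset` subclass of `BaseDataset`) and that the packages are importable from the working directory; the option object is hand-built here instead of coming from the options parser, and the extra attributes are illustrative assumptions.

```python
# Hedged sketch: drive create_dataset() with a hand-built option object.
from types import SimpleNamespace

from face3d.data import create_dataset  # import path assumed

opt = SimpleNamespace(
    dataset_mode="flist",          # resolved to data/flist_dataset.py -> FlistDataset
    isTrain=True,
    use_ddp=False,                 # take the single-process DataLoader branch
    serial_batches=False,          # shuffle while training
    batch_size=4,
    num_threads=2,
    max_dataset_size=float("inf"),
    # ...plus whatever FlistDataset.__init__ itself requires (file lists, data roots, ...)
)

dataset = create_dataset(opt)      # wraps CustomDatasetDataLoader
print("samples:", len(dataset))
for batch in dataset:              # __iter__ yields batches from the underlying DataLoader
    break
```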
--------------------------------------------------------------------------------
/SadTalker/src/facerender/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 |     """A thread-safe future implementation. Used only as a one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During the replication, as data parallel will trigger a callback for each module, all slave devices should
60 |     call `register_slave(identifier)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
62 |     and passed to a registered callback.
63 |     - After receiving the messages, the master device should gather the information and determine the message to be passed
64 |     back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def __getstate__(self):
79 | return {'master_callback': self._master_callback}
80 |
81 | def __setstate__(self, state):
82 | self.__init__(state['master_callback'])
83 |
84 | def register_slave(self, identifier):
85 | """
86 |         Register a slave device.
87 |
88 | Args:
89 |             identifier: an identifier, usually the device id.
90 |
91 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
92 |
93 | """
94 | if self._activated:
95 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
96 | self._activated = False
97 | self._registry.clear()
98 | future = FutureResult()
99 | self._registry[identifier] = _MasterRegistry(future)
100 | return SlavePipe(identifier, self._queue, future)
101 |
102 | def run_master(self, master_msg):
103 | """
104 | Main entry for the master device in each forward pass.
105 |         The messages are first collected from each device (including the master device), and then
106 |         a callback is invoked to compute the message to be sent back to each device
107 |         (including the master device).
108 |
109 | Args:
110 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112 |
113 | Returns: the message to be sent back to the master device.
114 |
115 | """
116 | self._activated = True
117 |
118 | intermediates = [(0, master_msg)]
119 | for i in range(self.nr_slaves):
120 | intermediates.append(self._queue.get())
121 |
122 | results = self._master_callback(intermediates)
123 |         assert results[0][0] == 0, 'The first result should belong to the master.'
124 |
125 | for i, res in results:
126 | if i == 0:
127 | continue
128 | self._registry[i].result.put(res)
129 |
130 | for i in range(self.nr_slaves):
131 | assert self._queue.get() is True
132 |
133 | return results[0][1]
134 |
135 | @property
136 | def nr_slaves(self):
137 | return len(self._registry)
138 |
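
`FutureResult` is the building block of the pipe described above. A minimal sketch of the one-to-one handshake follows, assuming comm.py is importable under the repo path shown (adjust the import to your installation):

```python
# Minimal sketch: one thread blocks in get() until another thread put()s a value.
import threading

from src.facerender.sync_batchnorm.comm import FutureResult  # import path assumed

future = FutureResult()

def producer():
    future.put(42)               # notifies the consumer waiting inside get()

t = threading.Thread(target=producer)
t.start()
print(future.get())              # -> 42; the slot is cleared so the pipe can be reused
t.join()
```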
--------------------------------------------------------------------------------
/SadTalker/src/utils/audio.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import librosa.filters
3 | import numpy as np
4 | # import tensorflow as tf
5 | from scipy import signal
6 | from scipy.io import wavfile
7 | from src.utils.hparams import hparams as hp
8 |
9 | def load_wav(path, sr):
10 | return librosa.core.load(path, sr=sr)[0]
11 |
12 | def save_wav(wav, path, sr):
13 | wav *= 32767 / max(0.01, np.max(np.abs(wav)))
14 | #proposed by @dsmiller
15 | wavfile.write(path, sr, wav.astype(np.int16))
16 |
17 | def save_wavenet_wav(wav, path, sr):
18 | librosa.output.write_wav(path, wav, sr=sr)
19 |
20 | def preemphasis(wav, k, preemphasize=True):
21 | if preemphasize:
22 | return signal.lfilter([1, -k], [1], wav)
23 | return wav
24 |
25 | def inv_preemphasis(wav, k, inv_preemphasize=True):
26 | if inv_preemphasize:
27 | return signal.lfilter([1], [1, -k], wav)
28 | return wav
29 |
30 | def get_hop_size():
31 | hop_size = hp.hop_size
32 | if hop_size is None:
33 | assert hp.frame_shift_ms is not None
34 | hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
35 | return hop_size
36 |
37 | def linearspectrogram(wav):
38 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
39 | S = _amp_to_db(np.abs(D)) - hp.ref_level_db
40 |
41 | if hp.signal_normalization:
42 | return _normalize(S)
43 | return S
44 |
45 | def melspectrogram(wav):
46 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
47 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
48 |
49 | if hp.signal_normalization:
50 | return _normalize(S)
51 | return S
52 |
53 | def _lws_processor():
54 | import lws
55 | return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
56 |
57 | def _stft(y):
58 | if hp.use_lws:
59 |         return _lws_processor().stft(y).T
60 | else:
61 | return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
62 |
63 | ##########################################################
64 | #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
65 | def num_frames(length, fsize, fshift):
66 | """Compute number of time frames of spectrogram
67 | """
68 | pad = (fsize - fshift)
69 | if length % fshift == 0:
70 | M = (length + pad * 2 - fsize) // fshift + 1
71 | else:
72 | M = (length + pad * 2 - fsize) // fshift + 2
73 | return M
74 |
75 |
76 | def pad_lr(x, fsize, fshift):
77 | """Compute left and right padding
78 | """
79 | M = num_frames(len(x), fsize, fshift)
80 | pad = (fsize - fshift)
81 | T = len(x) + 2 * pad
82 | r = (M - 1) * fshift + fsize - T
83 | return pad, pad + r
84 | ##########################################################
85 | #Librosa correct padding
86 | def librosa_pad_lr(x, fsize, fshift):
87 | return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
88 |
89 | # Conversions
90 | _mel_basis = None
91 |
92 | def _linear_to_mel(spectrogram):
93 | global _mel_basis
94 | if _mel_basis is None:
95 | _mel_basis = _build_mel_basis()
96 |     return np.dot(_mel_basis, spectrogram)
97 |
98 | def _build_mel_basis():
99 | assert hp.fmax <= hp.sample_rate // 2
100 | return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
101 | fmin=hp.fmin, fmax=hp.fmax)
102 |
103 | def _amp_to_db(x):
104 | min_level = np.exp(hp.min_level_db / 20 * np.log(10))
105 | return 20 * np.log10(np.maximum(min_level, x))
106 |
107 | def _db_to_amp(x):
108 | return np.power(10.0, (x) * 0.05)
109 |
110 | def _normalize(S):
111 | if hp.allow_clipping_in_normalization:
112 | if hp.symmetric_mels:
113 | return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
114 | -hp.max_abs_value, hp.max_abs_value)
115 | else:
116 | return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
117 |
118 | assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
119 | if hp.symmetric_mels:
120 | return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
121 | else:
122 | return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
123 |
124 | def _denormalize(D):
125 | if hp.allow_clipping_in_normalization:
126 | if hp.symmetric_mels:
127 | return (((np.clip(D, -hp.max_abs_value,
128 | hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
129 | + hp.min_level_db)
130 | else:
131 | return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
132 |
133 | if hp.symmetric_mels:
134 | return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
135 | else:
136 | return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
137 |
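
A short sketch of the typical call path through this module, assuming the hparams object referenced above defines the usual fields (sample_rate, n_fft, hop/frame sizes, mel settings) and that a wav file exists at the illustrative path:

```python
# Hedged sketch: load a wav at the configured sample rate and compute its mel spectrogram.
from src.utils import audio                      # import path assumed
from src.utils.hparams import hparams as hp

wav = audio.load_wav("examples/driven_audio.wav", hp.sample_rate)  # illustrative path
mel = audio.melspectrogram(wav)                  # ndarray of shape (num_mels, num_frames)
print(mel.shape, mel.min(), mel.max())
```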
--------------------------------------------------------------------------------
/SadTalker/src/face3d/util/nvdiffrast.py:
--------------------------------------------------------------------------------
1 | """This script is the differentiable renderer for Deep3DFaceRecon_pytorch
2 | Note: the antialiasing step is missing in the current version.
3 | """
4 | import pytorch3d.ops
5 | import torch
6 | import torch.nn.functional as F
7 | import kornia
8 | from kornia.geometry.camera import pixel2cam
9 | import numpy as np
10 | from typing import List
11 | from scipy.io import loadmat
12 | from torch import nn
13 |
14 | from pytorch3d.structures import Meshes
15 | from pytorch3d.renderer import (
16 | look_at_view_transform,
17 | FoVPerspectiveCameras,
18 | DirectionalLights,
19 | RasterizationSettings,
20 | MeshRenderer,
21 | MeshRasterizer,
22 | SoftPhongShader,
23 | TexturesUV,
24 | )
25 |
26 | # def ndc_projection(x=0.1, n=1.0, f=50.0):
27 | # return np.array([[n/x, 0, 0, 0],
28 | # [ 0, n/-x, 0, 0],
29 | # [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
30 | # [ 0, 0, -1, 0]]).astype(np.float32)
31 |
32 | class MeshRenderer(nn.Module):
33 | def __init__(self,
34 | rasterize_fov,
35 | znear=0.1,
36 | zfar=10,
37 | rasterize_size=224):
38 | super(MeshRenderer, self).__init__()
39 |
40 | # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
41 | # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
42 | # torch.diag(torch.tensor([1., -1, -1, 1])))
43 | self.rasterize_size = rasterize_size
44 | self.fov = rasterize_fov
45 | self.znear = znear
46 | self.zfar = zfar
47 |
48 | self.rasterizer = None
49 |
50 | def forward(self, vertex, tri, feat=None):
51 | """
52 | Return:
53 | mask -- torch.tensor, size (B, 1, H, W)
54 | depth -- torch.tensor, size (B, 1, H, W)
55 | features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None
56 |
57 | Parameters:
58 | vertex -- torch.tensor, size (B, N, 3)
59 | tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
60 | feat(optional) -- torch.tensor, size (B, N ,C), features
61 | """
62 | device = vertex.device
63 | rsize = int(self.rasterize_size)
64 | # ndc_proj = self.ndc_proj.to(device)
65 | # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v
66 | if vertex.shape[-1] == 3:
67 | vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)
68 | vertex[..., 0] = -vertex[..., 0]
69 |
70 |
71 | # vertex_ndc = vertex @ ndc_proj.t()
72 | if self.rasterizer is None:
73 | self.rasterizer = MeshRasterizer()
74 | print("create rasterizer on device cuda:%d"%device.index)
75 |
76 | # ranges = None
77 | # if isinstance(tri, List) or len(tri.shape) == 3:
78 | # vum = vertex_ndc.shape[1]
79 | # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device)
80 | # fstartidx = torch.cumsum(fnum, dim=0) - fnum
81 | # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu()
82 | # for i in range(tri.shape[0]):
83 | # tri[i] = tri[i] + i*vum
84 | # vertex_ndc = torch.cat(vertex_ndc, dim=0)
85 | # tri = torch.cat(tri, dim=0)
86 |
87 |         # for range_mode vertex: [B*N, 4], tri: [B*M, 3], for instance_mode vertex: [B, N, 4], tri: [M, 3]
88 | tri = tri.type(torch.int32).contiguous()
89 |
90 | # rasterize
91 | cameras = FoVPerspectiveCameras(
92 | device=device,
93 | fov=self.fov,
94 | znear=self.znear,
95 | zfar=self.zfar,
96 | )
97 |
98 | raster_settings = RasterizationSettings(
99 | image_size=rsize
100 | )
101 |
102 | # print(vertex.shape, tri.shape)
103 | mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0).repeat((vertex.shape[0],1,1)))
104 |
105 | fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings)
106 | rast_out = fragments.pix_to_face.squeeze(-1)
107 | depth = fragments.zbuf
108 |
109 | # render depth
110 | depth = depth.permute(0, 3, 1, 2)
111 | mask = (rast_out > 0).float().unsqueeze(1)
112 | depth = mask * depth
113 |
114 |
115 | image = None
116 | if feat is not None:
117 | attributes = feat.reshape(-1,3)[mesh.faces_packed()]
118 | image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face,
119 | fragments.bary_coords,
120 | attributes)
121 | # print(image.shape)
122 | image = image.squeeze(-2).permute(0, 3, 1, 2)
123 | image = mask * image
124 |
125 | return mask, depth, image
126 |
127 |
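
A hedged sketch of calling the renderer above on dummy geometry; it assumes pytorch3d is installed, a CUDA device is available (the lazy rasterizer print formats `device.index`), and the FoV/near/far values are illustrative only:

```python
# Hedged sketch: render a random mesh with the MeshRenderer defined above.
import torch
from src.face3d.util.nvdiffrast import MeshRenderer   # import path assumed

device = torch.device("cuda:0")
renderer = MeshRenderer(rasterize_fov=12.6, znear=5.0, zfar=15.0, rasterize_size=224)

B, N, M = 1, 1000, 1800                                # illustrative mesh sizes
vertex = torch.rand(B, N, 3, device=device) * 0.2      # (B, N, 3) vertices
vertex[..., 2] += 10.0                                 # keep the mesh between znear and zfar
tri = torch.randint(0, N, (M, 3), device=device)       # (M, 3) triangle indices
feat = torch.rand(B, N, 3, device=device)              # per-vertex colors

mask, depth, image = renderer(vertex, tri, feat)       # (B,1,H,W), (B,1,H,W), (B,3,H,W)
```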
--------------------------------------------------------------------------------
/SadTalker/src/face3d/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 |
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 |
12 |
13 | class BaseDataset(data.Dataset, ABC):
14 | """This class is an abstract base class (ABC) for datasets.
15 |
16 | To create a subclass, you need to implement the following four functions:
17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18 | -- <__len__>: return the size of dataset.
19 | -- <__getitem__>: get a data point.
20 |     -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.
21 | """
22 |
23 | def __init__(self, opt):
24 | """Initialize the class; save the options in the class
25 |
26 | Parameters:
27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28 | """
29 | self.opt = opt
30 | # self.root = opt.dataroot
31 | self.current_epoch = 0
32 |
33 | @staticmethod
34 | def modify_commandline_options(parser, is_train):
35 | """Add new dataset-specific options, and rewrite default values for existing options.
36 |
37 | Parameters:
38 | parser -- original option parser
39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40 |
41 | Returns:
42 | the modified parser.
43 | """
44 | return parser
45 |
46 | @abstractmethod
47 | def __len__(self):
48 | """Return the total number of images in the dataset."""
49 | return 0
50 |
51 | @abstractmethod
52 | def __getitem__(self, index):
53 | """Return a data point and its metadata information.
54 |
55 | Parameters:
56 |             index -- a random integer for data indexing
57 |
58 | Returns:
59 |             a dictionary of data with their names. It usually contains the data itself and its metadata information.
60 | """
61 | pass
62 |
63 |
64 | def get_transform(grayscale=False):
65 | transform_list = []
66 | if grayscale:
67 | transform_list.append(transforms.Grayscale(1))
68 | transform_list += [transforms.ToTensor()]
69 | return transforms.Compose(transform_list)
70 |
71 | def get_affine_mat(opt, size):
72 | shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
73 | w, h = size
74 |
75 | if 'shift' in opt.preprocess:
76 | shift_pixs = int(opt.shift_pixs)
77 | shift_x = random.randint(-shift_pixs, shift_pixs)
78 | shift_y = random.randint(-shift_pixs, shift_pixs)
79 | if 'scale' in opt.preprocess:
80 | scale = 1 + opt.scale_delta * (2 * random.random() - 1)
81 | if 'rot' in opt.preprocess:
82 | rot_angle = opt.rot_angle * (2 * random.random() - 1)
83 | rot_rad = -rot_angle * np.pi/180
84 | if 'flip' in opt.preprocess:
85 | flip = random.random() > 0.5
86 |
87 | shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
88 | flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
89 | shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
90 | rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
91 | scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
92 | shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
93 |
94 | affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
95 | affine_inv = np.linalg.inv(affine)
96 | return affine, affine_inv, flip
97 |
98 | def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
99 |     return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=method)
100 |
101 | def apply_lm_affine(landmark, affine, flip, size):
102 | _, h = size
103 | lm = landmark.copy()
104 | lm[:, 1] = h - 1 - lm[:, 1]
105 | lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
106 | lm = lm @ np.transpose(affine)
107 | lm[:, :2] = lm[:, :2] / lm[:, 2:]
108 | lm = lm[:, :2]
109 | lm[:, 1] = h - 1 - lm[:, 1]
110 | if flip:
111 | lm_ = lm.copy()
112 | lm_[:17] = lm[16::-1]
113 | lm_[17:22] = lm[26:21:-1]
114 | lm_[22:27] = lm[21:16:-1]
115 | lm_[31:36] = lm[35:30:-1]
116 | lm_[36:40] = lm[45:41:-1]
117 | lm_[40:42] = lm[47:45:-1]
118 | lm_[42:46] = lm[39:35:-1]
119 | lm_[46:48] = lm[41:39:-1]
120 | lm_[48:55] = lm[54:47:-1]
121 | lm_[55:60] = lm[59:54:-1]
122 | lm_[60:65] = lm[64:59:-1]
123 | lm_[65:68] = lm[67:64:-1]
124 | lm = lm_
125 | return lm
126 |
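
For completeness, a toy subclass following the four-function contract described in the docstring above; the class name and contents are illustrative, not part of the repo:

```python
# Hedged sketch: a minimal BaseDataset subclass for plumbing tests only.
import numpy as np
from PIL import Image
from face3d.data.base_dataset import BaseDataset, get_transform  # import path assumed

class RandomFaceDataset(BaseDataset):
    """Returns random RGB 'images' plus their index; useful only for wiring checks."""

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)              # first call the base initializer
        self.transform = get_transform(grayscale=False)
        self.length = 16                             # fixed toy size

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
        return {"imgs": self.transform(Image.fromarray(img)), "index": index}
```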
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/backbones/mobilefacenet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
3 | Original author cavalleria
4 | '''
5 |
6 | import torch.nn as nn
7 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
8 | import torch
9 |
10 |
11 | class Flatten(Module):
12 | def forward(self, x):
13 | return x.view(x.size(0), -1)
14 |
15 |
16 | class ConvBlock(Module):
17 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
18 | super(ConvBlock, self).__init__()
19 | self.layers = nn.Sequential(
20 | Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
21 | BatchNorm2d(num_features=out_c),
22 | PReLU(num_parameters=out_c)
23 | )
24 |
25 | def forward(self, x):
26 | return self.layers(x)
27 |
28 |
29 | class LinearBlock(Module):
30 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
31 | super(LinearBlock, self).__init__()
32 | self.layers = nn.Sequential(
33 | Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
34 | BatchNorm2d(num_features=out_c)
35 | )
36 |
37 | def forward(self, x):
38 | return self.layers(x)
39 |
40 |
41 | class DepthWise(Module):
42 | def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
43 | super(DepthWise, self).__init__()
44 | self.residual = residual
45 | self.layers = nn.Sequential(
46 | ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
47 | ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
48 | LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
49 | )
50 |
51 | def forward(self, x):
52 | short_cut = None
53 | if self.residual:
54 | short_cut = x
55 | x = self.layers(x)
56 | if self.residual:
57 | output = short_cut + x
58 | else:
59 | output = x
60 | return output
61 |
62 |
63 | class Residual(Module):
64 | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
65 | super(Residual, self).__init__()
66 | modules = []
67 | for _ in range(num_block):
68 | modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
69 | self.layers = Sequential(*modules)
70 |
71 | def forward(self, x):
72 | return self.layers(x)
73 |
74 |
75 | class GDC(Module):
76 | def __init__(self, embedding_size):
77 | super(GDC, self).__init__()
78 | self.layers = nn.Sequential(
79 | LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
80 | Flatten(),
81 | Linear(512, embedding_size, bias=False),
82 | BatchNorm1d(embedding_size))
83 |
84 | def forward(self, x):
85 | return self.layers(x)
86 |
87 |
88 | class MobileFaceNet(Module):
89 | def __init__(self, fp16=False, num_features=512):
90 | super(MobileFaceNet, self).__init__()
91 | scale = 2
92 | self.fp16 = fp16
93 | self.layers = nn.Sequential(
94 | ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
95 | ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
96 | DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
97 | Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
98 | DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
99 | Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
100 | DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
101 | Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
102 | )
103 | self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
104 | self.features = GDC(num_features)
105 | self._initialize_weights()
106 |
107 | def _initialize_weights(self):
108 | for m in self.modules():
109 | if isinstance(m, nn.Conv2d):
110 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
111 | if m.bias is not None:
112 | m.bias.data.zero_()
113 | elif isinstance(m, nn.BatchNorm2d):
114 | m.weight.data.fill_(1)
115 | m.bias.data.zero_()
116 | elif isinstance(m, nn.Linear):
117 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
118 | if m.bias is not None:
119 | m.bias.data.zero_()
120 |
121 | def forward(self, x):
122 | with torch.cuda.amp.autocast(self.fp16):
123 | x = self.layers(x)
124 | x = self.conv_sep(x.float() if self.fp16 else x)
125 | x = self.features(x)
126 | return x
127 |
128 |
129 | def get_mbf(fp16, num_features):
130 | return MobileFaceNet(fp16, num_features)
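
A quick sanity-check sketch for the backbone above; ArcFace-style models expect 112x112 aligned face crops, and the import path is assumed from the arcface_torch layout:

```python
# Hedged sketch: build MobileFaceNet and extract a 512-d embedding from a dummy crop.
import torch
from backbones.mobilefacenet import get_mbf    # import path assumed

net = get_mbf(fp16=False, num_features=512).eval()
with torch.no_grad():
    emb = net(torch.randn(1, 3, 112, 112))     # 112x112 aligned face crop
print(emb.shape)                               # torch.Size([1, 512])
```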
--------------------------------------------------------------------------------
/nodes/SadTalkerNode.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import shutil
4 | import torch # type: ignore
5 | import numpy as np # type: ignore
6 | from PIL import Image # type: ignore
7 | import soundfile as sf # type: ignore
8 | from datetime import datetime
9 | import folder_paths # type: ignore
10 |
11 |
12 | from ..SadTalker.src.Pb import SadTalker
13 |
14 |
15 | class SadTalkerNode:
16 |
17 | def __init__(self):
18 | self.checkpoint_path = os.path.abspath(
19 | os.path.join(os.path.dirname(__file__), "..", "SadTalker", "checkpoints")
20 | )
21 | self.config_path = os.path.abspath(
22 | os.path.join(os.path.dirname(__file__), "..", "SadTalker", "src", "config")
23 | )
24 | self.output_dir = os.path.abspath(
25 | os.path.join(os.path.dirname(__file__), "..", "SadTalker", "output")
26 | )
27 | self.comfy_output_dir = folder_paths.get_output_directory()
28 | self.input_dir = os.path.abspath(
29 | os.path.join(os.path.dirname(__file__), "..", "SadTalker", "input")
30 | )
31 | os.makedirs(self.output_dir, exist_ok=True)
32 | os.makedirs(self.input_dir, exist_ok=True)
33 | os.makedirs(self.comfy_output_dir, exist_ok=True)
34 |
35 |         # Instantiate SadTalker
36 | self.sad_talker = SadTalker(
37 | checkpoint_path=self.checkpoint_path,
38 | config_path=self.config_path,
39 | lazy_load=True,
40 | )
41 |
42 | @classmethod
43 | def INPUT_TYPES(cls):
44 | return {
45 | "required": {
46 | "image": ("IMAGE",),
47 | "audio": ("AUDIO",),
48 | "poseStyle": ("INT", {"default": 0, "min": 0, "max": 46}),
49 | "faceModelResolution": (["256", "512"], {"default": "256"}),
50 | "preprocess": (
51 | ["crop", "resize", "full", "extcrop", "extfull"],
52 | {"default": "crop"},
53 | ),
54 | "stillMode": ("BOOLEAN", {"default": False}),
55 | "batchSizeInGeneration": (
56 | "INT",
57 | {"default": 2, "min": 0, "max": 10},
58 | ),
59 | "gfpganAsFaceEnhancer": ("BOOLEAN", {"default": False}),
60 | "useIdleMode": ("BOOLEAN", {"default": False}),
61 | "idleModeTime": ("INT", {"default": 5, "min": 1, "max": 90}),
62 | "useRefVideo": ("BOOLEAN", {"default": False}),
63 | "refInfo": (
64 | ["pose", "blink", "pose+blink", "all"],
65 | {"default": "pose"},
66 | ),
67 | },
68 | "optional": {
69 | "refVideo": ("VIDEOSTRING",),
70 | },
71 | }
72 |
73 | RETURN_TYPES = (
74 | "STRING",
75 | "STRING",
76 | )
77 | RETURN_NAMES = (
78 | "video_path",
79 | "show_video_path",
80 | )
81 | FUNCTION = "generate"
82 | CATEGORY = "SadTalker"
83 |
84 | def generate(
85 | self,
86 | image,
87 | audio,
88 | poseStyle,
89 | faceModelResolution,
90 | preprocess,
91 | stillMode,
92 | batchSizeInGeneration,
93 | gfpganAsFaceEnhancer,
94 | useIdleMode,
95 | idleModeTime,
96 | useRefVideo,
97 | refInfo,
98 | refVideo=None,
99 | ):
100 |
101 |         # Create a timestamped input directory
102 | timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
103 | input_dir = os.path.join(self.input_dir, timestamp)
104 | os.makedirs(input_dir, exist_ok=True)
105 |
106 |         # Save the image
107 | if isinstance(image, torch.Tensor):
108 | image = image.cpu().numpy()
109 | if image.ndim == 4:
110 | image = image[0]
111 | if image.ndim == 3:
112 | if image.shape[2] in {1, 3}:
113 | if image.shape[2] == 1:
114 | image = np.squeeze(image, axis=2)
115 | else:
116 | raise ValueError("Unexpected number of channels in image tensor")
117 |
118 | image = (image * 255).astype(np.uint8)
119 | image_path = os.path.join(input_dir, "input_image.png")
120 | Image.fromarray(image).save(image_path)
121 |
122 |         # Extract and save the audio
123 | waveform = audio["waveform"].cpu().numpy().squeeze()
124 | sample_rate = audio["sample_rate"]
125 | if waveform.ndim == 2:
126 | waveform = waveform.T
127 | audio_path = os.path.join(input_dir, "input_audio.wav")
128 | sf.write(audio_path, waveform, sample_rate)
129 |
130 | result = self.sad_talker.test(
131 | source_image=image_path,
132 | driven_audio=audio_path,
133 | preprocess=preprocess,
134 | still_mode=stillMode,
135 | use_enhancer=gfpganAsFaceEnhancer,
136 | batch_size=batchSizeInGeneration,
137 | size=int(faceModelResolution),
138 | pose_style=poseStyle,
139 | exp_scale=1.0,
140 | use_ref_video=useRefVideo,
141 | ref_video=refVideo,
142 | ref_info=refInfo,
143 | use_idle_mode=useIdleMode,
144 | length_of_audio=idleModeTime,
145 | use_blink=True,
146 | result_dir=self.output_dir,
147 | )
148 |
149 |         comfy_out = os.path.join(self.comfy_output_dir, timestamp + ".mp4")
150 |         shutil.copy(result, comfy_out)
151 | 
152 |         return (
153 |             result,
154 |             comfy_out,
155 | )
156 |
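
For orientation, a hedged sketch of how a ComfyUI package typically registers a node like this in its top-level `__init__.py`; the display name here is illustrative, and the actual repo may register additional nodes (ShowText, ShowVideo, ...) as well:

```python
# Hedged sketch of a ComfyUI node registration (__init__.py), not the repo's exact file.
from .nodes.SadTalkerNode import SadTalkerNode

NODE_CLASS_MAPPINGS = {
    "SadTalkerNode": SadTalkerNode,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "SadTalkerNode": "SadTalker (image + audio -> talking video)",
}

WEB_DIRECTORY = "./web"   # lets ComfyUI serve the bundled js extensions, if any
```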
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/utils/utils_callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from typing import List
4 |
5 | import torch
6 |
7 | from eval import verification
8 | from utils.utils_logging import AverageMeter
9 |
10 |
11 | class CallBackVerification(object):
12 | def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)):
13 | self.frequent: int = frequent
14 | self.rank: int = rank
15 | self.highest_acc: float = 0.0
16 | self.highest_acc_list: List[float] = [0.0] * len(val_targets)
17 | self.ver_list: List[object] = []
18 | self.ver_name_list: List[str] = []
19 |         if self.rank == 0:
20 | self.init_dataset(
21 | val_targets=val_targets, data_dir=rec_prefix, image_size=image_size
22 | )
23 |
24 | def ver_test(self, backbone: torch.nn.Module, global_step: int):
25 | results = []
26 | for i in range(len(self.ver_list)):
27 | acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
28 | self.ver_list[i], backbone, 10, 10
29 | )
30 | if acc2 > self.highest_acc_list[i]:
31 | self.highest_acc_list[i] = acc2
32 | results.append(acc2)
33 |
34 | def init_dataset(self, val_targets, data_dir, image_size):
35 | for name in val_targets:
36 | path = os.path.join(data_dir, name + ".bin")
37 | if os.path.exists(path):
38 | data_set = verification.load_bin(path, image_size)
39 | self.ver_list.append(data_set)
40 | self.ver_name_list.append(name)
41 |
42 | def __call__(self, num_update, backbone: torch.nn.Module):
43 |         if self.rank == 0 and num_update > 0 and num_update % self.frequent == 0:
44 | backbone.eval()
45 | self.ver_test(backbone, num_update)
46 | backbone.train()
47 |
48 |
49 | class CallBackLogging(object):
50 | def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None):
51 | self.frequent: int = frequent
52 | self.rank: int = rank
53 | self.time_start = time.time()
54 | self.total_step: int = total_step
55 | self.batch_size: int = batch_size
56 | self.world_size: int = world_size
57 | self.writer = writer
58 |
59 | self.init = False
60 | self.tic = 0
61 |
62 | def __call__(
63 | self,
64 | global_step: int,
65 | loss: AverageMeter,
66 | epoch: int,
67 | fp16: bool,
68 | learning_rate: float,
69 | grad_scaler: torch.cuda.amp.GradScaler,
70 | ):
71 | if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
72 | if self.init:
73 | try:
74 | speed: float = (
75 | self.frequent * self.batch_size / (time.time() - self.tic)
76 | )
77 | speed_total = speed * self.world_size
78 | except ZeroDivisionError:
79 | speed_total = float("inf")
80 |
81 | time_now = (time.time() - self.time_start) / 3600
82 | time_total = time_now / ((global_step + 1) / self.total_step)
83 | time_for_end = time_total - time_now
84 | if self.writer is not None:
85 | self.writer.add_scalar("time_for_end", time_for_end, global_step)
86 | self.writer.add_scalar("learning_rate", learning_rate, global_step)
87 | self.writer.add_scalar("loss", loss.avg, global_step)
88 | if fp16:
89 | msg = (
90 | "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d "
91 | "Fp16 Grad Scale: %2.f Required: %1.f hours"
92 | % (
93 | speed_total,
94 | loss.avg,
95 | learning_rate,
96 | epoch,
97 | global_step,
98 | grad_scaler.get_scale(),
99 | time_for_end,
100 | )
101 | )
102 | else:
103 | msg = (
104 | "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d "
105 | "Required: %1.f hours"
106 | % (
107 | speed_total,
108 | loss.avg,
109 | learning_rate,
110 | epoch,
111 | global_step,
112 | time_for_end,
113 | )
114 | )
115 | loss.reset()
116 | self.tic = time.time()
117 | else:
118 | self.init = True
119 | self.tic = time.time()
120 |
121 |
122 | class CallBackModelCheckpoint(object):
123 | def __init__(self, rank, output="./"):
124 | self.rank: int = rank
125 | self.output: str = output
126 |
127 | def __call__(
128 | self,
129 | global_step,
130 | backbone,
131 | partial_fc,
132 | ):
133 | if global_step > 100 and self.rank == 0:
134 | path_module = os.path.join(self.output, "backbone.pth")
135 | torch.save(backbone.module.state_dict(), path_module)
136 |
137 | if global_step > 100 and partial_fc is not None:
138 | partial_fc.save_params()
139 |
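
A hedged sketch of how these callbacks are typically driven from a single-GPU training loop; `AverageMeter` is assumed to expose `update()`, `.avg` and `reset()` as in the usual utils_logging implementation, and the numbers below are placeholders:

```python
# Hedged sketch: wire CallBackLogging/CallBackModelCheckpoint into a training loop.
import torch
from utils.utils_callbacks import CallBackLogging, CallBackModelCheckpoint  # paths assumed
from utils.utils_logging import AverageMeter

total_step = 10_000
loss_meter = AverageMeter()
log_cb = CallBackLogging(frequent=50, rank=0, total_step=total_step,
                         batch_size=128, world_size=1, writer=None)
ckpt_cb = CallBackModelCheckpoint(rank=0, output="./work_dir")
scaler = torch.cuda.amp.GradScaler(enabled=False)      # only read when fp16=True

for global_step in range(total_step):
    loss_value = 0.1                                    # placeholder for the real loss
    loss_meter.update(loss_value)
    log_cb(global_step, loss_meter, epoch=0, fp16=False,
           learning_rate=0.02, grad_scaler=scaler)
    # ckpt_cb(global_step, backbone, partial_fc)        # once a DDP-wrapped model exists
```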
--------------------------------------------------------------------------------
/SadTalker/src/face3d/models/arcface_torch/docs/speed_benchmark.md:
--------------------------------------------------------------------------------
1 | ## Test Training Speed
2 |
3 | - Test Commands
4 |
5 | Use the following two commands to test Partial FC training performance.
6 | The number of identities is **3 million** (synthetic data), mixed-precision training is enabled, the backbone is ResNet-50,
7 | and the batch size is 1024.
8 | ```shell
9 | # Model Parallel
10 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions
11 | # Partial FC 0.1
12 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc
13 | ```
14 |
15 | - GPU Memory
16 |
17 | ```
18 | # (Model Parallel) gpustat -i
19 | [0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB
20 | [1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB
21 | [2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB
22 | [3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB
23 | [4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB
24 | [5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB
25 | [6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB
26 | [7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB
27 |
28 | # (Partial FC 0.1) gpustat -i
29 | [0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB
30 | [1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB
31 | [2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB
32 | [3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB
33 | [4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB
34 | [5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB
35 | [6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB
36 | [7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB
37 | ```
38 |
39 | - Training Speed
40 |
41 | ```python
42 | # (Model Parallel) training.log
43 | Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100
44 | Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
45 | Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
46 | Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
47 | Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
48 |
49 | # (Partial FC 0.1) training.log
50 | Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100
51 | Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
52 | Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
53 | Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
54 | Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
55 | ```
56 |
57 | In this test case, Partial FC 0.1 uses only 1/3 of the GPU memory of model parallel,
58 | and its training speed is 2.5 times faster than model parallel.
59 |
60 |
61 | ## Speed Benchmark
62 |
63 | 1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better)
64 |
65 | | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
66 | | :--- | :--- | :--- | :--- |
67 | |125000 | 4681 | 4824 | 5004 |
68 | |250000 | 4047 | 4521 | 4976 |
69 | |500000 | 3087 | 4013 | 4900 |
70 | |1000000 | 2090 | 3449 | 4803 |
71 | |1400000 | 1672 | 3043 | 4738 |
72 | |2000000 | - | 2593 | 4626 |
73 | |4000000 | - | 1748 | 4208 |
74 | |5500000 | - | 1389 | 3975 |
75 | |8000000 | - | - | 3565 |
76 | |16000000 | - | - | 2679 |
77 | |29000000 | - | - | 1855 |
78 |
79 | 2. GPU memory cost of different parallel methods (GB per GPU), Tesla V100 32GB * 8. (Smaller is better)
80 |
81 | | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
82 | | :--- | :--- | :--- | :--- |
83 | |125000 | 7358 | 5306 | 4868 |
84 | |250000 | 9940 | 5826 | 5004 |
85 | |500000 | 14220 | 7114 | 5202 |
86 | |1000000 | 23708 | 9966 | 5620 |
87 | |1400000 | 32252 | 11178 | 6056 |
88 | |2000000 | - | 13978 | 6472 |
89 | |4000000 | - | 23238 | 8284 |
90 | |5500000 | - | 32188 | 9854 |
91 | |8000000 | - | - | 12310 |
92 | |16000000 | - | - | 19950 |
93 | |29000000 | - | - | 32324 |
94 |
--------------------------------------------------------------------------------