├── insightface_func
│   ├── __init__.py
│   ├── models
│   │   └── antelope
│   │       └── scrfd_2.5g_bnkps.onnx
│   └── face_detect_crop_single.py
├── Temp
│   └── temp.txt
├── checkpoints
│   └── put wav2lip onnx models here
├── requirements.txt
├── setup.txt
├── convert2onnx
│   └── makeonnx.py
├── README.md
├── hparams.py
├── audio_orig.py
├── audio.py
└── inference_onnxModel.py

/insightface_func/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/Temp/temp.txt:
--------------------------------------------------------------------------------
required for temp files
--------------------------------------------------------------------------------
/checkpoints/put wav2lip onnx models here:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
opencv-python==4.8.0.76
numpy
tqdm
librosa
numba
insightface==0.2.1
onnxruntime
--------------------------------------------------------------------------------
/insightface_func/models/antelope/scrfd_2.5g_bnkps.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instant-high/wav2lip-onnx/HEAD/insightface_func/models/antelope/scrfd_2.5g_bnkps.onnx
--------------------------------------------------------------------------------
/setup.txt:
--------------------------------------------------------------------------------
conda create -n wav2lip_onnx python=3.7
conda activate wav2lip_onnx
cd c:\tutorial\wav2lip_onnx
pip install -r requirements.txt

For use with an Nvidia GPU:
conda install -c conda-forge cudatoolkit=11.2 cudnn=8.1.0 (version depending on your graphics card model)
pip uninstall onnxruntime
pip install onnxruntime-gpu

It may also be necessary to run:
pip install opencv-python

---------------------------
If you get an "onnx 1.9 providers" error:

Edit this file:
e.g. File "C:\Users\.conda\envs\ENVname\lib\site-packages\insightface\model_zoo\model_zoo.py"
line 23, in get_model

change:
session = onnxruntime.InferenceSession(self.onnx_file, None)

to:
session = onnxruntime.InferenceSession(self.onnx_file, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
---------------------------

Inference:
python -W ignore inference_onnxModel.py --checkpoint_path "checkpoints\wav2lip_gan.onnx" --face "D:\some.mp4" --audio "D:\some.wav" --outfile "D:\output.mp4" --nosmooth --pads 0 10 0 0 --fps 29.97
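
---------------------------
Optional check (a quick sketch, assuming onnxruntime-gpu was installed as above):
verify that the CUDA provider is actually visible, otherwise inference silently falls back to CPU.

python -c "import onnxruntime as ort; print(ort.get_device(), ort.get_available_providers())"

'GPU' and 'CUDAExecutionProvider' should appear in the output.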
File "C:\Users\.conda\envs\ENVname\lib\site-packages\insightface\model_zoo\model_zoo.py" 19 | line 23, in get_model 20 | 21 | change: 22 | session = onnxruntime.InferenceSession(self.onnx_file, None) 23 | 24 | to: 25 | session = onnxruntime.InferenceSession(self.onnx_file, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'] 26 | --------------------------- 27 | 28 | inference: 29 | python -W ignore inference_onnxModel.py --checkpoint_path "checkpoints\wav2lip_gan.onnx" --face "D:\some.mp4" --audio "D:\some.wav" --outfile "D:\output.mp4" --nosmooth --pads 0 10 0 0 --fps 29.97 30 | -------------------------------------------------------------------------------- /convert2onnx/makeonnx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.onnx 3 | from models import Wav2Lip 4 | import sys 5 | 6 | model = Wav2Lip() 7 | checkpoint_path = "checkpoints/wav2lip_gan.pth" 8 | #checkpoint_path = "checkpoints/wav2lip.pth" 9 | 10 | 11 | modelz = torch.load(checkpoint_path) 12 | 13 | device="cpu" 14 | s = modelz["state_dict"] 15 | new_s = {} 16 | for k, v in s.items(): 17 | new_s[k.replace('module.', '')] = v 18 | 19 | 20 | model.load_state_dict(new_s) 21 | model = model.to(device) 22 | 23 | model.eval() 24 | 25 | input_shape1 = (1, 1, 80, 16) 26 | input_shape2 = (1, 6, 96, 96) 27 | 28 | dynamic_axes = {'mel_spectrogram': {0: 'batch_size'}, 'video_frames': {0: 'batch_size'}} 29 | 30 | # or "checkpoints/wav2lip.onnx" 31 | torch.onnx.export(model, 32 | (torch.randn(input_shape1, device=device), torch.randn(input_shape2, device=device)), 33 | "checkpoints/wav2lip_gan.onnx", 34 | export_params=True, 35 | opset_version=10, 36 | input_names=["mel_spectrogram", "video_frames"], 37 | output_names=["predicted_frames"], 38 | dynamic_axes=dynamic_axes, 39 | verbose=False) 40 | print("Done !!") -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wav2lip-onnx 2 | This is my modified minimum wav2lip version. 3 | 4 | No torch required. 5 | 6 | Inference is quite fast running on CPU using the converted wav2lip onnx models and antelope face detection. 7 | Can be run on Nvidia GPU, tested on RTX3060 8 | Update: tested on GTX1050 9 | 10 | No additional functions like face enhancement, face alignment. 11 | Just same functions as the original repository 12 | 13 | (Modified)Face detection is taken from simswap 14 | 15 | Installation: 16 | Clone this repository and read Setup.txt 17 | 18 | Don't forget to install ffmpeg and set path variable. 19 | 20 | 21 | 22 | Face detection checkpoint already in insightface_func/models/antelope 23 | 24 | Converted wav2lip / wav2lip_gan checkpoints can be downloaded here: 25 | https://drive.google.com/file/d/1_l4QC2RJ9nXapSQRD61-Q4KbSApc53HM/view?usp=sharing. 26 | 27 | Unzip to checkpoints folder 28 | 29 | If you want to convert the wav2lip checkpoints yourself (folder convert2onnx/makeonnx.py) you have to run the script in the root of an original wav2lip installation. 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# wav2lip-onnx
This is my modified, minimal Wav2Lip version.

No torch required.

Inference is quite fast on CPU using the converted Wav2Lip ONNX models and antelope face detection.
It can also be run on an Nvidia GPU, tested on an RTX 3060.
Update: also tested on a GTX 1050.

No additional features such as face enhancement or face alignment;
just the same functionality as the original repository.

The (modified) face detection is taken from SimSwap.

Installation:
Clone this repository and read setup.txt.

Don't forget to install ffmpeg and add it to your PATH variable.


The face detection checkpoint is already included in insightface_func/models/antelope.

Converted wav2lip / wav2lip_gan checkpoints can be downloaded here:
https://drive.google.com/file/d/1_l4QC2RJ9nXapSQRD61-Q4KbSApc53HM/view?usp=sharing

Unzip them into the checkpoints folder.

If you want to convert the Wav2Lip checkpoints yourself (see convert2onnx/makeonnx.py), you have to run the script from the root of an original Wav2Lip installation.
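
For example (a sketch only, paths and copy commands are illustrative; adjust them to your setup):

```
git clone https://github.com/Rudrabha/Wav2Lip
copy convert2onnx\makeonnx.py Wav2Lip\
copy wav2lip_gan.pth Wav2Lip\checkpoints\
cd Wav2Lip
python makeonnx.py
```

The script loads checkpoints/wav2lip_gan.pth, strips the `module.` prefixes from the state dict and exports checkpoints/wav2lip_gan.onnx with the input names `mel_spectrogram` and `video_frames`.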

Original wav2lip - https://github.com/Rudrabha/Wav2Lip

SimSwap - https://github.com/neuralchen/SimSwap


Some of my older, not yet public projects can be found here:
https://www.youtube.com/playlist?list=PLvwlV1S1SYHBjPjwY49KF5d-z59a32v8C
--------------------------------------------------------------------------------
/hparams.py:
--------------------------------------------------------------------------------
from glob import glob
import os

def get_image_list(data_root, split):
    filelist = []

    with open('filelists/{}.txt'.format(split)) as f:
        for line in f:
            line = line.strip()
            if ' ' in line: line = line.split()[0]
            filelist.append(os.path.join(data_root, line))

    return filelist

class HParams:
    def __init__(self, **kwargs):
        self.data = {}

        for key, value in kwargs.items():
            self.data[key] = value

    def __getattr__(self, key):
        if key not in self.data:
            raise AttributeError("'HParams' object has no attribute %s" % key)
        return self.data[key]

    def set_hparam(self, key, value):
        self.data[key] = value


# Default hyperparameters
hparams = HParams(
    num_mels=80,  # Number of mel-spectrogram channels and local conditioning dimensionality
    # network
    rescale=True,  # Whether to rescale audio prior to preprocessing
    rescaling_max=0.9,  # Rescaling value

    # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
    # It's preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
    # Does not work if n_fft is not a multiple of hop_size!!
    use_lws=False,

    n_fft=800,  # Extra window size is filled with 0 paddings to match this parameter
    hop_size=200,  # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
    win_size=800,  # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
    sample_rate=16000,  # 16000Hz (corresponding to librispeech) (sox --i <filename>)

    frame_shift_ms=None,  # Can replace hop_size parameter. (Recommended: 12.5)

    # Mel and Linear spectrograms normalization/scaling and clipping
    signal_normalization=True,
    # Whether to normalize mel spectrograms to some predefined range (following below parameters)
    allow_clipping_in_normalization=True,  # Only relevant if mel_normalization = True
    symmetric_mels=True,
    # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
    # faster and cleaner convergence)
    max_abs_value=4.,
    # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
    # be too big to avoid gradient explosion,
    # not too small for fast convergence)
    # Contribution by @begeekmyfriend
    # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
    # levels. Also allows for better G&L phase reconstruction)
    preemphasize=True,  # whether to apply filter
    preemphasis=0.97,  # filter coefficient.

    # Limits
    min_level_db=-100,
    ref_level_db=20,
    fmin=55,
    # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
    # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
    fmax=7600,  # To be increased/reduced depending on data.

    ###################### Our training parameters #################################
    img_size=96,
    fps=25,

    batch_size=16,
    initial_learning_rate=1e-4,
    nepochs=200000000000000000,  ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
    num_workers=16,
    checkpoint_interval=3000,
    eval_interval=3000,
    save_optimizer_state=True,

    syncnet_wt=0.0,  # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
    syncnet_batch_size=64,
    syncnet_lr=1e-4,
    syncnet_eval_interval=10000,
    syncnet_checkpoint_interval=10000,

    disc_wt=0.07,
    disc_initial_learning_rate=1e-4,
)


def hparams_debug_string():
    values = hparams.data
    hp = ["  %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
    return "Hyperparameters:\n" + "\n".join(hp)
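

# A small usage sketch (not used anywhere else in the repo): hyperparameters are read as
# attributes of the `hparams` instance, e.g. hparams.sample_rate. With the defaults above,
# one mel frame covers hop_size / sample_rate = 200 / 16000 = 12.5 ms, i.e. 80 mel frames
# per second of audio, which is the 80./fps factor used in inference_onnxModel.py.
if __name__ == "__main__":
    print(hparams.num_mels, hparams.sample_rate, hparams.hop_size)            # 80 16000 200
    print("mel frames per second:", hparams.sample_rate / hparams.hop_size)   # 80.0
    print(hparams_debug_string())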
--------------------------------------------------------------------------------
/insightface_func/face_detect_crop_single.py:
--------------------------------------------------------------------------------
'''
Author: Naiyuan liu
Github: https://github.com/NNNNAI
Date: 2021-11-23 17:03:58
LastEditors: Naiyuan liu
LastEditTime: 2021-11-24 16:46:04
Description:
'''
from __future__ import division
import collections
import numpy as np
import glob
import os
import os.path as osp
import cv2
from insightface.model_zoo import model_zoo
#from insightface_func.utils import face_align_ffhqandnewarc as face_align
# NOTE: the face_align import above is commented out (the utils module is not part of this
# minimal repo), so get() cannot be used; inference_onnxModel.py only calls getBox().

__all__ = ['Face_detect_crop', 'Face']

Face = collections.namedtuple('Face', [
    'bbox', 'kps', 'det_score', 'embedding', 'gender', 'age',
    'embedding_norm', 'normed_embedding',
    'landmark'
])

Face.__new__.__defaults__ = (None, ) * len(Face._fields)


class Face_detect_crop:
    def __init__(self, name, root='~/.insightface_func/models'):
        self.models = {}
        root = os.path.expanduser(root)
        onnx_files = glob.glob(osp.join(root, name, '*.onnx'))
        onnx_files = sorted(onnx_files)
        for onnx_file in onnx_files:
            if onnx_file.find('_selfgen_') > 0:
                #print('ignore:', onnx_file)
                continue
            model = model_zoo.get_model(onnx_file)
            if model.taskname not in self.models:
                print('find model:', onnx_file, model.taskname)
                self.models[model.taskname] = model
            else:
                print('duplicated model task type, ignore:', onnx_file, model.taskname)
                del model
        assert 'detection' in self.models
        self.det_model = self.models['detection']


    def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640), mode='None'):
        self.det_thresh = det_thresh
        self.mode = mode
        assert det_size is not None
        print('set det-size:', det_size)
        self.det_size = det_size
        for taskname, model in self.models.items():
            if taskname == 'detection':
                model.prepare(ctx_id, input_size=det_size)
            else:
                model.prepare(ctx_id)

    def get(self, img, crop_size, max_num=0):
        bboxes, kpss = self.det_model.detect(img,
                                             #threshold=self.det_thresh,
                                             max_num=max_num,
                                             metric='default')
        if bboxes.shape[0] == 0:
            return None
        # ret = []
        # for i in range(bboxes.shape[0]):
        #     bbox = bboxes[i, 0:4]
        #     det_score = bboxes[i, 4]
        #     kps = None
        #     if kpss is not None:
        #         kps = kpss[i]
        #     M, _ = face_align.estimate_norm(kps, crop_size, mode='None')
        #     align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0)
        # for i in range(bboxes.shape[0]):
        #     kps = None
        #     if kpss is not None:
        #         kps = kpss[i]
        #     M, _ = face_align.estimate_norm(kps, crop_size, mode='None')
        #     align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0)

        det_score = bboxes[..., 4]

        # select the face with the highest detection score
        best_index = np.argmax(det_score)

        kps = None
        if kpss is not None:
            kps = kpss[best_index]
        M, _ = face_align.estimate_norm(kps, crop_size, mode=self.mode)
        align_img = cv2.warpAffine(img, M, (crop_size, crop_size), borderValue=0.0)

        return [align_img], [M]

    def getBox(self, img, max_num=0):
        bboxes, kpss = self.det_model.detect(img,
                                             #threshold=self.det_thresh,
                                             max_num=max_num,
                                             metric='default')
        if bboxes.shape[0] == 0:
            return None

        x1 = int(bboxes[0, 0])
        y1 = int(bboxes[0, 1])
        x2 = int(bboxes[0, 2])
        y2 = int(bboxes[0, 3])


        return (x1, y1, x2, y2)
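

# Usage sketch (hypothetical image path), mirroring how inference_onnxModel.py drives this
# class: prepare() sets detector input size and threshold, getBox() returns an
# (x1, y1, x2, y2) box for a detected face, or None if no face is found.
if __name__ == "__main__":
    detector = Face_detect_crop(name='antelope', root='./insightface_func/models')
    detector.prepare(ctx_id=0, det_thresh=0.3, det_size=(320, 320), mode='none')
    frame = cv2.imread('some_frame.jpg')  # hypothetical test image
    if frame is not None:
        print('bbox:', detector.getBox(frame))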
--------------------------------------------------------------------------------
/audio_orig.py:
--------------------------------------------------------------------------------
import librosa
import librosa.filters
import numpy as np
# import tensorflow as tf
from scipy import signal
from scipy.io import wavfile
from hparams import hparams as hp

def load_wav(path, sr):
    return librosa.core.load(path, sr=sr)[0]

def save_wav(wav, path, sr):
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    #proposed by @dsmiller
    wavfile.write(path, sr, wav.astype(np.int16))

def save_wavenet_wav(wav, path, sr):
    librosa.output.write_wav(path, wav, sr=sr)

def preemphasis(wav, k, preemphasize=True):
    if preemphasize:
        return signal.lfilter([1, -k], [1], wav)
    return wav

def inv_preemphasis(wav, k, inv_preemphasize=True):
    if inv_preemphasize:
        return signal.lfilter([1], [1, -k], wav)
    return wav

def get_hop_size():
    hop_size = hp.hop_size
    if hop_size is None:
        assert hp.frame_shift_ms is not None
        hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
    return hop_size

def linearspectrogram(wav):
    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
    S = _amp_to_db(np.abs(D)) - hp.ref_level_db

    if hp.signal_normalization:
        return _normalize(S)
    return S

def melspectrogram(wav):
    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
    S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db

    if hp.signal_normalization:
        return _normalize(S)
    return S

def _lws_processor():
    import lws
    return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")

def _stft(y):
    if hp.use_lws:
        return _lws_processor().stft(y).T
    else:
        return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)

##########################################################
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
def num_frames(length, fsize, fshift):
    """Compute number of time frames of spectrogram
    """
    pad = (fsize - fshift)
    if length % fshift == 0:
        M = (length + pad * 2 - fsize) // fshift + 1
    else:
        M = (length + pad * 2 - fsize) // fshift + 2
    return M


def pad_lr(x, fsize, fshift):
    """Compute left and right padding
    """
    M = num_frames(len(x), fsize, fshift)
    pad = (fsize - fshift)
    T = len(x) + 2 * pad
    r = (M - 1) * fshift + fsize - T
    return pad, pad + r
##########################################################
#Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
    return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]

# Conversions
_mel_basis = None

def _linear_to_mel(spectogram):
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis()
    return np.dot(_mel_basis, spectogram)

def _build_mel_basis():
    # NOTE: the positional librosa.filters.mel() call only works with librosa < 0.10;
    # audio.py contains the keyword-argument version of this function.
    assert hp.fmax <= hp.sample_rate // 2
    return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels,
                               fmin=hp.fmin, fmax=hp.fmax)

def _amp_to_db(x):
    min_level = np.exp(hp.min_level_db / 20 * np.log(10))
    return 20 * np.log10(np.maximum(min_level, x))

def _db_to_amp(x):
    return np.power(10.0, (x) * 0.05)

def _normalize(S):
    if hp.allow_clipping_in_normalization:
        if hp.symmetric_mels:
            return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
                           -hp.max_abs_value, hp.max_abs_value)
        else:
            return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)

    assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
    if hp.symmetric_mels:
        return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
    else:
        return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))

def _denormalize(D):
    if hp.allow_clipping_in_normalization:
        if hp.symmetric_mels:
            return (((np.clip(D, -hp.max_abs_value,
                              hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
                    + hp.min_level_db)
        else:
            return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)

    if hp.symmetric_mels:
        return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
    else:
        return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
--------------------------------------------------------------------------------
/audio.py:
--------------------------------------------------------------------------------
import librosa
import librosa.filters
import numpy as np
# import tensorflow as tf
from scipy import signal
from scipy.io import wavfile
from hparams import hparams as hp

def load_wav(path, sr):
    return librosa.core.load(path, sr=sr)[0]

def save_wav(wav, path, sr):
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    #proposed by @dsmiller
    wavfile.write(path, sr, wav.astype(np.int16))

def save_wavenet_wav(wav, path, sr):
    librosa.output.write_wav(path, wav, sr=sr)

def preemphasis(wav, k, preemphasize=True):
    if preemphasize:
        return signal.lfilter([1, -k], [1], wav)
    return wav

def inv_preemphasis(wav, k, inv_preemphasize=True):
    if inv_preemphasize:
        return signal.lfilter([1], [1, -k], wav)
    return wav

def get_hop_size():
    hop_size = hp.hop_size
    if hop_size is None:
        assert hp.frame_shift_ms is not None
        hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
    return hop_size

def linearspectrogram(wav):
    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
    S = _amp_to_db(np.abs(D)) - hp.ref_level_db

    if hp.signal_normalization:
        return _normalize(S)
    return S

def melspectrogram(wav):
    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
    S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db

    if hp.signal_normalization:
        return _normalize(S)
    return S

def _lws_processor():
    import lws
    return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")

def _stft(y):
    if hp.use_lws:
        return _lws_processor().stft(y).T
    else:
        return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)

##########################################################
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
def num_frames(length, fsize, fshift):
    """Compute number of time frames of spectrogram
    """
    pad = (fsize - fshift)
    if length % fshift == 0:
        M = (length + pad * 2 - fsize) // fshift + 1
    else:
        M = (length + pad * 2 - fsize) // fshift + 2
    return M


def pad_lr(x, fsize, fshift):
    """Compute left and right padding
    """
    M = num_frames(len(x), fsize, fshift)
    pad = (fsize - fshift)
    T = len(x) + 2 * pad
    r = (M - 1) * fshift + fsize - T
    return pad, pad + r
##########################################################
#Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
    return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]

# Conversions
_mel_basis = None

def _linear_to_mel(spectogram):
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis()
    return np.dot(_mel_basis, spectogram)

def _build_mel_basis():
    assert hp.fmax <= hp.sample_rate // 2
    return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
                               fmin=hp.fmin, fmax=hp.fmax)

def _amp_to_db(x):
    min_level = np.exp(hp.min_level_db / 20 * np.log(10))
    return 20 * np.log10(np.maximum(min_level, x))

def _db_to_amp(x):
    return np.power(10.0, (x) * 0.05)

def _normalize(S):
    if hp.allow_clipping_in_normalization:
        if hp.symmetric_mels:
            return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
                           -hp.max_abs_value, hp.max_abs_value)
        else:
            return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)

    assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
    if hp.symmetric_mels:
        return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
    else:
        return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))

def _denormalize(D):
    if hp.allow_clipping_in_normalization:
        if hp.symmetric_mels:
            return (((np.clip(D, -hp.max_abs_value,
                              hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
                    + hp.min_level_db)
        else:
            return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)

    if hp.symmetric_mels:
        return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
    else:
        return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
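

# Usage sketch (hypothetical file name), mirroring what inference_onnxModel.py does:
# load audio at 16 kHz and compute the normalized mel spectrogram. The result has shape
# (num_mels, T) = (80, T), with 80 mel frames per second of audio.
if __name__ == "__main__":
    wav = load_wav('some_audio.wav', hp.sample_rate)  # hypothetical path
    mel = melspectrogram(wav)
    print(mel.shape, '->', mel.shape[1], 'mel frames for', len(wav) / hp.sample_rate, 'seconds')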
--------------------------------------------------------------------------------
/inference_onnxModel.py:
--------------------------------------------------------------------------------
from os import listdir, path
import numpy as np
import scipy, cv2, os, sys, argparse, audio
import json, subprocess, random, string
from tqdm import tqdm
#from glob import glob
#import torch, face_detection
#import face_detection
#from models import Wav2Lip
import platform
from PIL import Image

import onnxruntime
onnxruntime.set_default_logger_severity(3)
from insightface_func.face_detect_crop_single import Face_detect_crop

parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')

parser.add_argument('--checkpoint_path', type=str, help='Name of saved checkpoint to load weights from', required=True)

parser.add_argument('--face', type=str, help='Filepath of video/image that contains faces to use', required=True)

parser.add_argument('--audio', type=str, help='Filepath of video/audio file to use as raw audio source', required=True)

parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.', default='results/result_voice.mp4')

parser.add_argument('--static', default=False, action='store_true', help='If set, use only the first video frame for inference')

parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)', default=25., required=False)

parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0], help='Padding (top, bottom, left, right). Please adjust to include chin at least')

parser.add_argument('--face_det_batch_size', type=int, help='Batch size for face detection', default=16)

parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=1)

parser.add_argument('--resize_factor', default=1, type=int, help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')

parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1], help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
                    'Useful if multiple faces are present. -1 implies the value will be auto-inferred based on height, width')

parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1], help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected. '
                    'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')

parser.add_argument('--rotate', default=False, action='store_true', help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg. '
                    'Use if you get a flipped result, despite feeding a normal looking video')

parser.add_argument('--nosmooth', default=False, action='store_true', help='Prevent smoothing face detections over a short temporal window')

parser.add_argument('--preview', default=False, action='store_true', help='Preview during inference')

args = parser.parse_args()
args.img_size = 96

if os.path.isfile(args.face) and args.face.split('.')[-1].lower() in ['jpg', 'png', 'jpeg']:
    args.static = True


def load_model(device):
    model_path = args.checkpoint_path
    session_options = onnxruntime.SessionOptions()
    session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    providers = ["CPUExecutionProvider"]
    if device == 'cuda':
        providers = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}), "CPUExecutionProvider"]
    session = onnxruntime.InferenceSession(model_path, sess_options=session_options, providers=providers)

    return session
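
# NOTE: the ONNX session returned here is driven with the input names 'mel_spectrogram'
# and 'video_frames' (see model.run(...) in main()); these must match the input_names
# used when the model was exported in convert2onnx/makeonnx.py.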

def get_smoothened_boxes(boxes, T):
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i : i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes

def face_detect(images):

    detector = Face_detect_crop(name='antelope', root='./insightface_func/models')
    detector.prepare(ctx_id=0, det_thresh=0.3, det_size=(320, 320), mode='none')

    predictions = []
    for i in tqdm(range(0, len(images))):
        bbox = detector.getBox(images[i])
        predictions.append(bbox)

    results = []
    pady1, pady2, padx1, padx2 = args.pads
    for rect, image in zip(predictions, images):
        if rect is None:
            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector
    return results
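
# The detected box is expanded by --pads (top, bottom, left, right) before cropping,
# e.g. the default "0 10 0 0" adds 10 pixels below the chin; unless --nosmooth is given,
# the boxes are then averaged over a 5-frame window to reduce jitter.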

def datagen(frames, mels):

    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if args.box[0] == -1:
        if not args.static:
            face_det_results = face_detect(frames)  # BGR2RGB for CNN face detection
        else:
            face_det_results = face_detect([frames[0]])
    else:
        print('Using the specified bounding box instead of face detection...')
        y1, y2, x1, x2 = args.box
        face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]

    for i, m in enumerate(mels):
        idx = 0 if args.static else i % len(frames)
        frame_to_save = frames[idx].copy()
        face, coords = face_det_results[idx].copy()

        face = cv2.resize(face, (args.img_size, args.img_size))

        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= args.wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, args.img_size//2:] = 0

            #input(img_masked.shape)

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, args.img_size//2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        yield img_batch, mel_batch, frame_batch, coords_batch
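
# Each yielded img_batch has shape (B, 96, 96, 6): the first 3 channels are the face crop
# with the lower half zeroed out (the region the model must in-paint), the last 3 channels
# are the unmasked reference crop. mel_batch has shape (B, 80, 16, 1). Both are transposed
# to NCHW in main() before being fed to the ONNX model.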

mel_step_size = 16

#device = 'cuda' if torch.cuda.is_available() else 'cpu' only if torch is installed
device = 'cpu'
if onnxruntime.get_device() == 'GPU': device = 'cuda'



def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
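
# Timing of the mel chunks built in main(): the spectrogram has 80 mel frames per second
# (16000 Hz / hop_size 200), so mel_idx_multiplier = 80 / fps maps a video frame index to
# its position in the spectrogram, and each chunk of mel_step_size = 16 frames covers
# 16 / 80 = 0.2 s of audio around that frame.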

def main():
    print("Running on " + onnxruntime.get_device())

    im = cv2.imread(args.face)

    if not os.path.isfile(args.face):
        raise ValueError('--face argument must be a valid path to video/image file')

    elif args.face.split('.')[-1].lower() in ['jpg', 'png', 'jpeg']:
        full_frames = [cv2.imread(args.face)]
        fps = args.fps

    else:
        video_stream = cv2.VideoCapture(args.face)
        fps = video_stream.get(cv2.CAP_PROP_FPS)

        print('Reading video frames...')

        full_frames = []
        while 1:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break
            if args.resize_factor > 1:
                frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor))

            if args.rotate:
                frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)

            y1, y2, x1, x2 = args.crop
            if x2 == -1: x2 = frame.shape[1]
            if y2 == -1: y2 = frame.shape[0]

            frame = frame[y1:y2, x1:x2]

            full_frames.append(frame)

    print("Number of frames available for inference: " + str(len(full_frames)))

    if not args.audio.endswith('.wav'):
        print('Extracting raw audio...')
        command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav')

        subprocess.call(command, shell=True)
        args.audio = 'temp/temp.wav'

    wav = audio.load_wav(args.audio, 16000)
    mel = audio.melspectrogram(wav)
    #print(mel.shape)

    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

    mel_chunks = []
    mel_idx_multiplier = 80./fps
    i = 0
    while 1:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > len(mel[0]):
            mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size])
        i += 1

    print("Length of mel chunks: {}".format(len(mel_chunks)))

    full_frames = full_frames[:len(mel_chunks)]

    batch_size = 1
    gen = datagen(full_frames.copy(), mel_chunks)

    for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen,
                                                                    total=int(np.ceil(float(len(mel_chunks))/batch_size)))):
        if i == 0:

            model = load_model(device)  # load the wav2lip .onnx model

            frame_h, frame_w = full_frames[0].shape[:-1]
            out = cv2.VideoWriter('temp/result.avi', cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h))

        img_batch = img_batch.transpose((0, 3, 1, 2)).astype(np.float32)
        mel_batch = mel_batch.transpose((0, 3, 1, 2)).astype(np.float32)

        pred = model.run(None, {'mel_spectrogram': mel_batch, 'video_frames': img_batch})[0][0]
        pred = pred.transpose(1, 2, 0)*255
        pred = pred.astype(np.uint8)
        pred = pred.reshape((1, 96, 96, 3))

        for p, f, c in zip(pred, frames, coords):

            y1, y2, x1, x2 = c
            p = cv2.resize(p, (x2 - x1, y2 - y1))
            f[y1:y2, x1:x2] = p
            out.write(f)
            if args.preview:
                cv2.imshow("Result", f)
                cv2.waitKey(1)

    out.release()

    command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile)
    subprocess.call(command, shell=platform.system() != 'Windows')
    os.remove('temp/result.avi')

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------