├── __init__.py ├── checkpoints └── README.md ├── models ├── __init__.py ├── conv.py ├── syncnet.py └── wav2lip.py ├── requirements.txt ├── README.md ├── hparams.py ├── audio.py └── Wav2Lip.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoints/README.md: -------------------------------------------------------------------------------- 1 | Place all your checkpoints (.pth files) here. -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .wav2lip import Wav2Lip, Wav2Lip_disc_qual 2 | from .syncnet import SyncNet_color -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python==4.8.0.76 2 | torch==2.0.1 3 | numpy==1.23.5 4 | tqdm==4.66.1 5 | moviepy==1.0.3 6 | librosa==0.7.0 7 | numba==0.57.1 8 | -------------------------------------------------------------------------------- /models/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | class Conv2d(nn.Module): 6 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | self.conv_block = nn.Sequential( 9 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 10 | nn.BatchNorm2d(cout) 11 | ) 12 | self.act = nn.ReLU() 13 | self.residual = residual 14 | 15 | def forward(self, x): 16 | out = self.conv_block(x) 17 | if self.residual: 18 | out += x 19 | return self.act(out) 20 | 21 | class nonorm_Conv2d(nn.Module): 22 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self.conv_block = nn.Sequential( 25 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 26 | ) 27 | self.act = nn.LeakyReLU(0.01, inplace=True) 28 | 29 | def forward(self, x): 30 | out = self.conv_block(x) 31 | return self.act(out) 32 | 33 | class Conv2dTranspose(nn.Module): 34 | def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs): 35 | super().__init__(*args, **kwargs) 36 | self.conv_block = nn.Sequential( 37 | nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding), 38 | nn.BatchNorm2d(cout) 39 | ) 40 | self.act = nn.ReLU() 41 | 42 | def forward(self, x): 43 | out = self.conv_block(x) 44 | return self.act(out) 45 | -------------------------------------------------------------------------------- /models/syncnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from .conv import Conv2d 6 | 7 | class SyncNet_color(nn.Module): 8 | def __init__(self): 9 | super(SyncNet_color, self).__init__() 10 | 11 | self.face_encoder = nn.Sequential( 12 | Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3), 13 | 14 | Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1), 15 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 16 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 17 | 18 | Conv2d(64, 128, kernel_size=3, stride=2, padding=1), 19 | Conv2d(128, 128, 
kernel_size=3, stride=1, padding=1, residual=True), 20 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 21 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 22 | 23 | Conv2d(128, 256, kernel_size=3, stride=2, padding=1), 24 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 25 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 26 | 27 | Conv2d(256, 512, kernel_size=3, stride=2, padding=1), 28 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 29 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 30 | 31 | Conv2d(512, 512, kernel_size=3, stride=2, padding=1), 32 | Conv2d(512, 512, kernel_size=3, stride=1, padding=0), 33 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 34 | 35 | self.audio_encoder = nn.Sequential( 36 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 37 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 38 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 39 | 40 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 41 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 42 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 43 | 44 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1), 45 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 46 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 47 | 48 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 49 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 50 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 51 | 52 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0), 53 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 54 | 55 | def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T) 56 | face_embedding = self.face_encoder(face_sequences) 57 | audio_embedding = self.audio_encoder(audio_sequences) 58 | 59 | audio_embedding = audio_embedding.view(audio_embedding.size(0), -1) 60 | face_embedding = face_embedding.view(face_embedding.size(0), -1) 61 | 62 | audio_embedding = F.normalize(audio_embedding, p=2, dim=1) 63 | face_embedding = F.normalize(face_embedding, p=2, dim=1) 64 | 65 | 66 | return audio_embedding, face_embedding 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Wav2Lip-Inference-on-Python3.9 2 | This project fixes the Wav2Lip project's inference code so that it can run on Python 3.9. Wav2Lip is a project that can be used to lip-sync videos to audio. The original project depended on Python 3.6 and used deprecated libraries, and a lot of people were unable to resolve the resulting issues. This project fixes those problems so that Wav2Lip can now run on Python 3.9 or higher. 3 | 4 | Original Project: https://github.com/Rudrabha/Wav2Lip 5 | 6 | This repository enables you to perform lip-syncing using the Wav2Lip model directly in Python, offering an alternative to command-line usage. It provides a `Processor` class with methods to process video and audio inputs, generate lip-synced videos, and customize various options. You can also find the original command-line options available as arguments in this Python script.
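In its simplest form, that boils down to a single call (a minimal sketch with placeholder paths; installation and the full walkthrough follow in the sections below):

```python
from Wav2Lip import Processor  # Processor lives in Wav2Lip.py of this repository

# One call: face video (or image), driving audio, output file
Processor().run("face.mp4", "speech.wav", "result.mp4")
```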
7 | 8 | ## Getting Started 9 | 10 | ### Prerequisites 11 | 12 | Before using this repository, ensure you have the following prerequisites installed: 13 | 14 | - Python 3.9 or later 15 | - Dependencies listed in `requirements.txt` 16 | - Pretrained checkpoints (.pth files) downloaded and placed in the `checkpoints` folder. The download links are available in the original repository. 17 | 18 | ### Installing 19 | 20 | To get started, clone this repository: 21 | 22 | ```bash 23 | git clone https://github.com/HassanMuhammadSannaullah/Wav2lip-Fix-For-Inference.git 24 | cd Wav2lip-Fix-For-Inference 25 | pip install -r requirements.txt 26 | ``` 27 | ## Important Note 28 | In the `decorators.py` file of the installed librosa package, make the following change to ensure compatibility: 29 | 30 | ```python 31 | # Change this import line 32 | from numba.decorators import jit as optional_jit 33 | 34 | # To this 35 | from numba import jit as optional_jit 36 | ``` 37 | 38 | ## Usage 39 | You can either run the `Wav2Lip.py` file in the project directly or import the `Processor` class from it elsewhere in your code. The following is a sample way to run inference: 40 | 41 | 1. Import the `Processor` class 42 | ```python 43 | from Wav2Lip import Processor 44 | ``` 45 | 46 | 2. Use the `run` method to perform inference 47 | ```python 48 | processor = Processor() 49 | processor.run("path_to_face_video_or_image", "path_to_audio.wav", "output_path.mp4") 50 | ``` 51 | 52 | ### Additional Options 53 | You can customize various options by passing arguments to the `Processor` constructor or to the `run` method. Here are some important options: 54 | ``` 55 | # These can be set in the constructor 56 | checkpoint_path: Path to the Wav2Lip model checkpoint. 57 | nosmooth: Disable smoothing of face boxes. 58 | static: Use a static image for face detection. 59 | 60 | # All of the below can be set in the run method of the Processor class 61 | resize_factor: Resize factor for video frames. 62 | rotate: Rotate frames (useful for portrait videos). 63 | crop: Crop the video frame [y1, y2, x1, x2]. 64 | fps: Frames per second for the output video. 65 | mel_step_size: Mel spectrogram step size. 66 | wav2lip_batch_size: Batch size for inference. 67 | ``` 68 | For detailed information on these options, refer to the code comments in the `Processor` class or to the original Wav2Lip implementation. A combined example that sets several of these options is shown at the end of this README. 69 | 70 | ## Disclaimer 71 | 72 | This project is provided for educational and entertainment purposes only. The author and contributors of this repository are not responsible for any harmful, unethical, or inappropriate use of the software or its outputs. Users are encouraged to adhere to ethical guidelines and legal regulations when using this project. 73 | 74 | Please use this project responsibly and consider the implications of your actions. If you have any concerns or questions regarding the usage of this software, feel free to reach out for guidance. 75 | 76 | By using this software, you agree to the above disclaimer. 77 | 78 | ## Acknowledgments 79 | This project is built upon the Wav2Lip repository by Rudrabha Mukhopadhyay. 80 | If you encounter any issues or have questions, feel free to open an issue. 81 | 82 | Happy lip-syncing!
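As a concrete illustration of the options listed under Usage, the sketch below combines several of them. The argument names mirror the `Processor` constructor and `run` signature in `Wav2Lip.py`; all file paths are placeholders, and the values shown are only examples, not recommended settings.

```python
from Wav2Lip import Processor

# Constructor-level options: checkpoint choice and face-box handling
processor = Processor(
    checkpoint_path="checkpoints/wav2lip_gan.pth",  # .pth file placed in checkpoints/
    nosmooth=False,  # keep temporal smoothing of detected face boxes
    static=False,    # treat the input as a video rather than a single still image
)

# run()-level options: frame preprocessing, cropping and batching
processor.run(
    "input_face.mp4",        # face video (or image if static=True)
    "input_speech.wav",      # driving audio
    output_path="result.mp4",
    resize_factor=2,         # downscale frames 2x before inference
    rotate=False,            # set True for portrait videos recorded sideways
    crop=[0, -1, 0, -1],     # [y1, y2, x1, x2]; -1 means "up to the frame edge"
    fps=25,                  # only used when the face input is a still image
    mel_step_size=16,        # width of each mel-spectrogram chunk fed to the model
    wav2lip_batch_size=128,  # batch size for the Wav2Lip forward pass
)
```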
83 | 84 | 85 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import os 3 | 4 | def get_image_list(data_root, split): 5 | filelist = [] 6 | 7 | with open('filelists/{}.txt'.format(split)) as f: 8 | for line in f: 9 | line = line.strip() 10 | if ' ' in line: line = line.split()[0] 11 | filelist.append(os.path.join(data_root, line)) 12 | 13 | return filelist 14 | 15 | class HParams: 16 | def __init__(self, **kwargs): 17 | self.data = {} 18 | 19 | for key, value in kwargs.items(): 20 | self.data[key] = value 21 | 22 | def __getattr__(self, key): 23 | if key not in self.data: 24 | raise AttributeError("'HParams' object has no attribute %s" % key) 25 | return self.data[key] 26 | 27 | def set_hparam(self, key, value): 28 | self.data[key] = value 29 | 30 | 31 | # Default hyperparameters 32 | hparams = HParams( 33 | num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality 34 | # network 35 | rescale=True, # Whether to rescale audio prior to preprocessing 36 | rescaling_max=0.9, # Rescaling value 37 | 38 | # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction 39 | # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder 40 | # Does not work if n_ffit is not multiple of hop_size!! 41 | use_lws=False, 42 | 43 | n_fft=800, # Extra window size is filled with 0 paddings to match this parameter 44 | hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) 45 | win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) 46 | sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) 47 | 48 | frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) 49 | 50 | # Mel and Linear spectrograms normalization/scaling and clipping 51 | signal_normalization=True, 52 | # Whether to normalize mel spectrograms to some predefined range (following below parameters) 53 | allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True 54 | symmetric_mels=True, 55 | # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, 56 | # faster and cleaner convergence) 57 | max_abs_value=4., 58 | # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not 59 | # be too big to avoid gradient explosion, 60 | # not too small for fast convergence) 61 | # Contribution by @begeekmyfriend 62 | # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude 63 | # levels. Also allows for better G&L phase reconstruction) 64 | preemphasize=True, # whether to apply filter 65 | preemphasis=0.97, # filter coefficient. 66 | 67 | # Limits 68 | min_level_db=-100, 69 | ref_level_db=20, 70 | fmin=55, 71 | # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To 72 | # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) 73 | fmax=7600, # To be increased/reduced depending on data. 
74 | 75 | ###################### Our training parameters ################################# 76 | img_size=96, 77 | fps=25, 78 | 79 | batch_size=16, 80 | initial_learning_rate=1e-4, 81 | nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs 82 | num_workers=16, 83 | checkpoint_interval=3000, 84 | eval_interval=3000, 85 | save_optimizer_state=True, 86 | 87 | syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. 88 | syncnet_batch_size=64, 89 | syncnet_lr=1e-4, 90 | syncnet_eval_interval=10000, 91 | syncnet_checkpoint_interval=10000, 92 | 93 | disc_wt=0.07, 94 | disc_initial_learning_rate=1e-4, 95 | ) 96 | 97 | 98 | def hparams_debug_string(): 99 | values = hparams.values() 100 | hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"] 101 | return "Hyperparameters:\n" + "\n".join(hp) 102 | -------------------------------------------------------------------------------- /audio.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import librosa.filters 3 | import numpy as np 4 | 5 | # import tensorflow as tf 6 | from scipy import signal 7 | from scipy.io import wavfile 8 | from hparams import hparams as hp 9 | 10 | 11 | def load_wav(path, sr): 12 | return librosa.core.load(path, sr=sr)[0] 13 | 14 | 15 | def save_wav(wav, path, sr): 16 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 17 | # proposed by @dsmiller 18 | wavfile.write(path, sr, wav.astype(np.int16)) 19 | 20 | 21 | def save_wavenet_wav(wav, path, sr): 22 | librosa.output.write_wav(path, wav, sr=sr) 23 | 24 | 25 | def preemphasis(wav, k, preemphasize=True): 26 | if preemphasize: 27 | return signal.lfilter([1, -k], [1], wav) 28 | return wav 29 | 30 | 31 | def inv_preemphasis(wav, k, inv_preemphasize=True): 32 | if inv_preemphasize: 33 | return signal.lfilter([1], [1, -k], wav) 34 | return wav 35 | 36 | 37 | def get_hop_size(): 38 | hop_size = hp.hop_size 39 | if hop_size is None: 40 | assert hp.frame_shift_ms is not None 41 | hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate) 42 | return hop_size 43 | 44 | 45 | def linearspectrogram(wav): 46 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 47 | S = _amp_to_db(np.abs(D)) - hp.ref_level_db 48 | 49 | if hp.signal_normalization: 50 | return _normalize(S) 51 | return S 52 | 53 | 54 | def melspectrogram(wav): 55 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 56 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db 57 | 58 | if hp.signal_normalization: 59 | return _normalize(S) 60 | return S 61 | 62 | 63 | def _lws_processor(): 64 | import lws 65 | 66 | return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech") 67 | 68 | 69 | def _stft(y): 70 | if hp.use_lws: 71 | return _lws_processor(hp).stft(y).T 72 | else: 73 | return librosa.stft( 74 | y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size 75 | ) 76 | 77 | 78 | ########################################################## 79 | # Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) 
80 | def num_frames(length, fsize, fshift): 81 | """Compute number of time frames of spectrogram""" 82 | pad = fsize - fshift 83 | if length % fshift == 0: 84 | M = (length + pad * 2 - fsize) // fshift + 1 85 | else: 86 | M = (length + pad * 2 - fsize) // fshift + 2 87 | return M 88 | 89 | 90 | def pad_lr(x, fsize, fshift): 91 | """Compute left and right padding""" 92 | M = num_frames(len(x), fsize, fshift) 93 | pad = fsize - fshift 94 | T = len(x) + 2 * pad 95 | r = (M - 1) * fshift + fsize - T 96 | return pad, pad + r 97 | 98 | 99 | ########################################################## 100 | # Librosa correct padding 101 | def librosa_pad_lr(x, fsize, fshift): 102 | return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] 103 | 104 | 105 | # Conversions 106 | _mel_basis = None 107 | 108 | 109 | def _linear_to_mel(spectogram): 110 | global _mel_basis 111 | if _mel_basis is None: 112 | _mel_basis = _build_mel_basis() 113 | return np.dot(_mel_basis, spectogram) 114 | 115 | 116 | def _build_mel_basis(): 117 | assert hp.fmax <= hp.sample_rate // 2 118 | return librosa.filters.mel( 119 | hp.sample_rate, hp.n_fft, n_mels=hp.num_mels, fmin=hp.fmin, fmax=hp.fmax 120 | ) 121 | 122 | 123 | def _amp_to_db(x): 124 | min_level = np.exp(hp.min_level_db / 20 * np.log(10)) 125 | return 20 * np.log10(np.maximum(min_level, x)) 126 | 127 | 128 | def _db_to_amp(x): 129 | return np.power(10.0, (x) * 0.05) 130 | 131 | 132 | def _normalize(S): 133 | if hp.allow_clipping_in_normalization: 134 | if hp.symmetric_mels: 135 | return np.clip( 136 | (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) 137 | - hp.max_abs_value, 138 | -hp.max_abs_value, 139 | hp.max_abs_value, 140 | ) 141 | else: 142 | return np.clip( 143 | hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 144 | 0, 145 | hp.max_abs_value, 146 | ) 147 | 148 | assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 149 | if hp.symmetric_mels: 150 | return (2 * hp.max_abs_value) * ( 151 | (S - hp.min_level_db) / (-hp.min_level_db) 152 | ) - hp.max_abs_value 153 | else: 154 | return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) 155 | 156 | 157 | def _denormalize(D): 158 | if hp.allow_clipping_in_normalization: 159 | if hp.symmetric_mels: 160 | return ( 161 | (np.clip(D, -hp.max_abs_value, hp.max_abs_value) + hp.max_abs_value) 162 | * -hp.min_level_db 163 | / (2 * hp.max_abs_value) 164 | ) + hp.min_level_db 165 | else: 166 | return ( 167 | np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value 168 | ) + hp.min_level_db 169 | 170 | if hp.symmetric_mels: 171 | return ( 172 | (D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value) 173 | ) + hp.min_level_db 174 | else: 175 | return (D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db 176 | -------------------------------------------------------------------------------- /models/wav2lip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import math 5 | 6 | from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d 7 | 8 | class Wav2Lip(nn.Module): 9 | def __init__(self): 10 | super(Wav2Lip, self).__init__() 11 | 12 | self.face_encoder_blocks = nn.ModuleList([ 13 | nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96 14 | 15 | nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48 16 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 17 | Conv2d(32, 32, 
kernel_size=3, stride=1, padding=1, residual=True)), 18 | 19 | nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24 20 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 21 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 22 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)), 23 | 24 | nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12 25 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 26 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)), 27 | 28 | nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6 29 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 30 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)), 31 | 32 | nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3 33 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), 34 | 35 | nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1 36 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),]) 37 | 38 | self.audio_encoder = nn.Sequential( 39 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 40 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 41 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 42 | 43 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 44 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 45 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 46 | 47 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1), 48 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 49 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 50 | 51 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 52 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 53 | 54 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0), 55 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 56 | 57 | self.face_decoder_blocks = nn.ModuleList([ 58 | nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),), 59 | 60 | nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3 61 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), 62 | 63 | nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1), 64 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 65 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6 66 | 67 | nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1), 68 | Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True), 69 | Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12 70 | 71 | nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1), 72 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 73 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24 74 | 75 | nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1), 76 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 77 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48 78 | 79 | nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1), 80 | Conv2d(64, 64, kernel_size=3, 
stride=1, padding=1, residual=True), 81 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96 82 | 83 | self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1), 84 | nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0), 85 | nn.Sigmoid()) 86 | 87 | def forward(self, audio_sequences, face_sequences): 88 | # audio_sequences = (B, T, 1, 80, 16) 89 | B = audio_sequences.size(0) 90 | 91 | input_dim_size = len(face_sequences.size()) 92 | if input_dim_size > 4: 93 | audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) 94 | face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0) 95 | 96 | audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1 97 | 98 | feats = [] 99 | x = face_sequences 100 | for f in self.face_encoder_blocks: 101 | x = f(x) 102 | feats.append(x) 103 | 104 | x = audio_embedding 105 | for f in self.face_decoder_blocks: 106 | x = f(x) 107 | try: 108 | x = torch.cat((x, feats[-1]), dim=1) 109 | except Exception as e: 110 | print(x.size()) 111 | print(feats[-1].size()) 112 | raise e 113 | 114 | feats.pop() 115 | 116 | x = self.output_block(x) 117 | 118 | if input_dim_size > 4: 119 | x = torch.split(x, B, dim=0) # [(B, C, H, W)] 120 | outputs = torch.stack(x, dim=2) # (B, C, T, H, W) 121 | 122 | else: 123 | outputs = x 124 | 125 | return outputs 126 | 127 | class Wav2Lip_disc_qual(nn.Module): 128 | def __init__(self): 129 | super(Wav2Lip_disc_qual, self).__init__() 130 | 131 | self.face_encoder_blocks = nn.ModuleList([ 132 | nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96 133 | 134 | nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48 135 | nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)), 136 | 137 | nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24 138 | nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)), 139 | 140 | nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12 141 | nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)), 142 | 143 | nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6 144 | nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)), 145 | 146 | nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3 147 | nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),), 148 | 149 | nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1 150 | nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),]) 151 | 152 | self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid()) 153 | self.label_noise = .0 154 | 155 | def get_lower_half(self, face_sequences): 156 | return face_sequences[:, :, face_sequences.size(2)//2:] 157 | 158 | def to_2d(self, face_sequences): 159 | B = face_sequences.size(0) 160 | face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0) 161 | return face_sequences 162 | 163 | def perceptual_forward(self, false_face_sequences): 164 | false_face_sequences = self.to_2d(false_face_sequences) 165 | false_face_sequences = self.get_lower_half(false_face_sequences) 166 | 167 | false_feats = false_face_sequences 168 | for f in self.face_encoder_blocks: 169 | false_feats = f(false_feats) 170 | 171 | false_pred_loss = 
F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1), 172 | torch.ones((len(false_feats), 1)).cuda()) 173 | 174 | return false_pred_loss 175 | 176 | def forward(self, face_sequences): 177 | face_sequences = self.to_2d(face_sequences) 178 | face_sequences = self.get_lower_half(face_sequences) 179 | 180 | x = face_sequences 181 | for f in self.face_encoder_blocks: 182 | x = f(x) 183 | 184 | return self.binary_pred(x).view(len(x), -1) 185 | -------------------------------------------------------------------------------- /Wav2Lip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import subprocess 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | from moviepy.editor import VideoFileClip, AudioFileClip 8 | from models import Wav2Lip 9 | import audio 10 | from datetime import datetime 11 | import shutil 12 | 13 | 14 | class Processor: 15 | def __init__( 16 | self, 17 | checkpoint_path=os.path.join( 18 | "checkpoints", "wav2lip_gan.pth" 19 | ), 20 | nosmooth=False, 21 | static=False, 22 | ): 23 | self.checkpoint_path = checkpoint_path 24 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 25 | self.static = static 26 | self.nosmooth = nosmooth 27 | 28 | def get_smoothened_boxes(self, boxes, T): 29 | for i in range(len(boxes)): 30 | if i + T > len(boxes): 31 | window = boxes[len(boxes) - T :] 32 | else: 33 | window = boxes[i : i + T] 34 | boxes[i] = np.mean(window, axis=0) 35 | return boxes 36 | 37 | def face_detect(self, images): 38 | print("Detecting Faces") 39 | # Load the pre-trained Haar Cascade Classifier for face detection 40 | face_cascade = cv2.CascadeClassifier( 41 | os.path.join( 42 | "checkpoints", 43 | "haarcascade_frontalface_default.xml", 44 | ) 45 | ) # cv2.data.haarcascades 46 | pads = [0, 10, 0, 0] 47 | results = [] 48 | pady1, pady2, padx1, padx2 = pads 49 | 50 | for image in images: 51 | # Convert the image to grayscale for face detection 52 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 53 | 54 | # Detect faces in the grayscale image 55 | faces = face_cascade.detectMultiScale( 56 | gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30) 57 | ) 58 | 59 | if len(faces) > 0: 60 | # Get the first detected face (you can modify this to handle multiple faces) 61 | x, y, w, h = faces[0] 62 | 63 | # Calculate the bounding box coordinates 64 | x1 = max(0, x - padx1) 65 | x2 = min(image.shape[1], x + w + padx2) 66 | y1 = max(0, y - pady1) 67 | y2 = min(image.shape[0], y + h + pady2) 68 | 69 | results.append([x1, y1, x2, y2]) 70 | else: 71 | cv2.imwrite( 72 | os.path.join("temp","faulty_frame.jpg"), image 73 | ) # Save the frame where the face was not detected. 74 | raise ValueError("Face not detected! 
Ensure the image contains a face.") 75 | 76 | boxes = np.array(results) 77 | if not self.nosmooth: 78 | boxes = self.get_smoothened_boxes(boxes, 5) 79 | results = [ 80 | [image[y1:y2, x1:x2], (y1, y2, x1, x2)] 81 | for image, (x1, y1, x2, y2) in zip(images, boxes) 82 | ] 83 | 84 | return results 85 | 86 | def datagen(self, frames, mels): 87 | img_size = 96 88 | box = [-1, -1, -1, -1] 89 | wav2lip_batch_size = 128 90 | img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] 91 | 92 | if box[0] == -1: 93 | if not self.static: 94 | face_det_results = self.face_detect( 95 | frames 96 | ) # BGR2RGB for CNN face detection 97 | else: 98 | face_det_results = self.face_detect([frames[0]]) 99 | else: 100 | print("Using the specified bounding box instead of face detection...") 101 | y1, y2, x1, x2 = box 102 | face_det_results = [[f[y1:y2, x1:x2], (y1, y2, x1, x2)] for f in frames] 103 | 104 | for i, m in enumerate(mels): 105 | idx = 0 if self.static else i % len(frames) 106 | frame_to_save = frames[idx].copy() 107 | face, coords = face_det_results[idx].copy() 108 | 109 | face = cv2.resize(face, (img_size, img_size)) 110 | img_batch.append(face) 111 | mel_batch.append(m) 112 | frame_batch.append(frame_to_save) 113 | coords_batch.append(coords) 114 | 115 | if len(img_batch) >= wav2lip_batch_size: 116 | img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) 117 | 118 | img_masked = img_batch.copy() 119 | img_masked[:, img_size // 2 :] = 0 120 | 121 | img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0 122 | mel_batch = np.reshape( 123 | mel_batch, 124 | [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1], 125 | ) 126 | 127 | yield img_batch, mel_batch, frame_batch, coords_batch 128 | img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] 129 | 130 | if len(img_batch) > 0: 131 | img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) 132 | 133 | img_masked = img_batch.copy() 134 | img_masked[:, img_size // 2 :] = 0 135 | 136 | img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0 137 | mel_batch = np.reshape( 138 | mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1] 139 | ) 140 | 141 | yield img_batch, mel_batch, frame_batch, coords_batch 142 | 143 | def _load(self, checkpoint_path): 144 | if self.device == "cuda": 145 | checkpoint = torch.load(checkpoint_path) 146 | else: 147 | checkpoint = torch.load( 148 | checkpoint_path, map_location=lambda storage, loc: storage 149 | ) 150 | return checkpoint 151 | 152 | def load_model(self, path): 153 | model = Wav2Lip() 154 | print("Load checkpoint from: {}".format(path)) 155 | checkpoint = self._load(path) 156 | s = checkpoint["state_dict"] 157 | new_s = {} 158 | for k, v in s.items(): 159 | new_s[k.replace("module.", "")] = v 160 | model.load_state_dict(new_s) 161 | 162 | model = model.to(self.device) 163 | return model.eval() 164 | 165 | def run( 166 | self, 167 | face, 168 | audio_file, 169 | output_path="output.mp4", 170 | resize_factor=4, 171 | rotate=False, 172 | crop=[0, -1, 0, -1], 173 | fps=25, 174 | mel_step_size=16, 175 | wav2lip_batch_size=128, 176 | ): 177 | if not os.path.isfile(face): 178 | raise ValueError("--face argument must be a valid path to video/image file") 179 | 180 | elif face.split(".")[1] in ["jpg", "png", "jpeg"]: 181 | full_frames = [cv2.imread(face)] 182 | fps = fps 183 | 184 | else: 185 | video_stream = cv2.VideoCapture(face) 186 | fps = video_stream.get(cv2.CAP_PROP_FPS) 187 | 188 | print("Reading video frames...") 189 | 190 | 
full_frames = [] 191 | while 1: 192 | still_reading, frame = video_stream.read() 193 | if not still_reading: 194 | video_stream.release() 195 | break 196 | if resize_factor > 1: 197 | frame = cv2.resize( 198 | frame, 199 | ( 200 | frame.shape[1] // resize_factor, 201 | frame.shape[0] // resize_factor, 202 | ), 203 | ) 204 | 205 | if rotate: 206 | frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE) 207 | 208 | y1, y2, x1, x2 = crop 209 | if x2 == -1: 210 | x2 = frame.shape[1] 211 | if y2 == -1: 212 | y2 = frame.shape[0] 213 | 214 | frame = frame[y1:y2, x1:x2] 215 | 216 | full_frames.append(frame) 217 | 218 | print("Number of frames available for inference: " + str(len(full_frames))) 219 | 220 | if not audio_file.endswith(".wav"): 221 | print("Extracting raw audio...") 222 | command = "ffmpeg -y -i {} -strict -2 {}".format( 223 | audio_file, f"{os.path.join('temp','temp.wav')}" 224 | ) 225 | 226 | subprocess.call(command, shell=True) 227 | audio_file = os.path.join("temp", "temp.wav") 228 | 229 | wav = audio.load_wav(audio_file, 16000) 230 | mel = audio.melspectrogram(wav) 231 | print(mel.shape) 232 | 233 | if np.isnan(mel.reshape(-1)).sum() > 0: 234 | raise ValueError( 235 | "Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again" 236 | ) 237 | 238 | mel_chunks = [] 239 | mel_idx_multiplier = 80.0 / fps 240 | i = 0 241 | while 1: 242 | start_idx = int(i * mel_idx_multiplier) 243 | if start_idx + mel_step_size > len(mel[0]): 244 | mel_chunks.append(mel[:, len(mel[0]) - mel_step_size :]) 245 | break 246 | mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size]) 247 | i += 1 248 | 249 | print("Length of mel chunks: {}".format(len(mel_chunks))) 250 | 251 | full_frames = full_frames[: len(mel_chunks)] 252 | 253 | print("Full Frames before gen : ", len(full_frames)) 254 | 255 | batch_size = wav2lip_batch_size 256 | gen = self.datagen(full_frames.copy(), mel_chunks) 257 | 258 | for i, (img_batch, mel_batch, frames, coords) in enumerate( 259 | tqdm(gen, total=int(np.ceil(float(len(mel_chunks)) / batch_size))) 260 | ): 261 | if i == 0: 262 | model = self.load_model(self.checkpoint_path) 263 | print("Model loaded") 264 | generated_temp_video_path = os.path.join( 265 | "temp", 266 | f"{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}_result.avi", 267 | ) 268 | frame_h, frame_w = full_frames[0].shape[:-1] 269 | out = cv2.VideoWriter( 270 | generated_temp_video_path, 271 | cv2.VideoWriter_fourcc(*"DIVX"), 272 | fps, 273 | (frame_w, frame_h), 274 | ) 275 | 276 | img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to( 277 | self.device 278 | ) 279 | mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to( 280 | self.device 281 | ) 282 | 283 | with torch.no_grad(): 284 | pred = model(mel_batch, img_batch) 285 | 286 | pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.0 287 | 288 | for p, f, c in zip(pred, frames, coords): 289 | y1, y2, x1, x2 = c 290 | p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1)) 291 | 292 | f[y1:y2, x1:x2] = p 293 | out.write(f) 294 | 295 | out.release() 296 | 297 | # Load the video and audio clips 298 | video_clip = VideoFileClip(generated_temp_video_path) 299 | audio_clip = AudioFileClip(audio_file) 300 | 301 | # Set the audio of the video clip to the loaded audio clip 302 | video_clip = video_clip.set_audio(audio_clip) 303 | 304 | # Write the combined video to a new file 305 | video_clip.write_videofile(output_path, codec="libx264", audio_codec="aac") 306 | 307 | 308 | if __name__ == "__main__": 
309 | processor = Processor() 310 | processor.run("image_path", "audio_path") 311 | --------------------------------------------------------------------------------
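One detail of `Processor.run` above that is easy to miss is how it lines up audio with video frames: with the default hparams shown earlier (16 kHz audio, `hop_size=200`) the mel spectrogram advances 80 frames per second, so each video frame at `fps` frames per second corresponds to a step of `80 / fps` mel frames, and every chunk spans `mel_step_size=16` mel frames (about 0.2 s of audio). The standalone sketch below reproduces just that chunk-indexing arithmetic under those default assumptions.

```python
# Standalone sketch of the mel-chunk indexing used in Processor.run.
# Assumes the defaults shown above: sample_rate=16000, hop_size=200
# (=> 80 mel frames per second) and mel_step_size=16 (~0.2 s of audio).
def mel_chunk_starts(num_mel_frames, fps=25.0, mel_step_size=16):
    """Return the starting mel-frame index of the chunk used for each video frame."""
    mel_idx_multiplier = 80.0 / fps  # mel frames advanced per video frame
    starts, i = [], 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > num_mel_frames:
            # Final chunk is right-aligned so it still spans mel_step_size frames
            starts.append(num_mel_frames - mel_step_size)
            break
        starts.append(start_idx)
        i += 1
    return starts

# For one second of 25 fps video (80 mel frames), consecutive chunks start
# 80 / 25 = 3.2 mel frames apart: [0, 3, 6, 9, 12, 16, ...]
print(mel_chunk_starts(80))
```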