├── .gitignore ├── .idea ├── misc.xml ├── wav2lip_vq.iml └── workspace.xml ├── README.md ├── audio.py ├── color_syncnet_train.py ├── color_syncnet_train_vq.py ├── data └── vqgan-project.yaml ├── evaluation ├── README.md ├── gen_videos_from_filelist.py ├── real_videos_inference.py ├── scores_LSE │ ├── SyncNetInstance_calc_scores.py │ ├── calculate_scores_LRS.py │ ├── calculate_scores_real_videos.py │ └── calculate_scores_real_videos.sh └── test_filelists │ ├── README.md │ └── ReSyncED │ ├── random_pairs.txt │ └── tts_pairs.txt ├── face_detection ├── README.md ├── __init__.py ├── api.py ├── detection │ ├── __init__.py │ ├── core.py │ └── sfd │ │ ├── __init__.py │ │ ├── bbox.py │ │ ├── detect.py │ │ ├── net_s3fd.py │ │ └── sfd_detector.py ├── models.py └── utils.py ├── filelists └── README.md ├── get_filelist.py ├── hparams.py ├── hq_wav2lip_train.py ├── inference.py ├── models ├── __init__.py ├── conv.py ├── encoder_vq.py ├── quantize_vq.py ├── syncnet.py ├── syncnet_vq.py ├── vqgan.py └── wav2lip.py ├── preprocess.py ├── requirements.txt └── wav2lip_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *.pkl 3 | *.jpg 4 | *.mp4 5 | *.pth 6 | *.pyc 7 | __pycache__ 8 | *.h5 9 | *.avi 10 | *.wav 11 | filelists/*.txt 12 | evaluation/test_filelists/lr*.txt 13 | *.pyc 14 | *.mkv 15 | *.gif 16 | *.webm 17 | *.mp3 18 | checkpoints/ 19 | results/ 20 | temp/ 21 | data/min_pre/ 22 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- [IDE project settings; the XML content was not captured by this export and is omitted] -------------------------------------------------------------------------------- /.idea/wav2lip_vq.iml: -------------------------------------------------------------------------------- [IDE module file; the XML content was not captured by this export and is omitted] -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- [IDE workspace state; the XML content was not captured by this export and is omitted] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Wav2lip in a compact Vector Quantized (VQ) space 2 | 3 | - VQGAN 4 | - https://github.com/CompVis/taming-transformers 5 | - debugging custom models #107 6 | - fine-tune based on [vqgan_imagenet_f16_1024] 7 | - https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/ 8 | - image_size = 256 9 | - syncnet_vq.py 10 | - face_encoder: (B, T x 256, 16, 16) -> (B, 512, 1, 1) 11 | - audio_encoder: (B, 1, 80, 16) -> (B, 512, 1, 1) 12 | - color_syncnet_train_vq.py 13 | - vqgan config / ckpt
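For illustration only, the block below is a minimal, shape-only sketch of the contract listed in the README above; it is not the repository's models/syncnet_vq.py (whose source is not included in this dump), and every layer choice in it is a placeholder. It assumes the T = 5 face frames have already been passed through the frozen f16 VQGAN encoder described by data/vqgan-project.yaml (256 x 256 input, z_channels = 256, giving a 16 x 16 code map per frame) and that, as the README's (B, T x 256, 16, 16) shape suggests, the T code maps are stacked along the channel axis.

```python
import torch
from torch import nn

syncnet_T = 5      # frames per training window (see color_syncnet_train_vq.py)
vq_channels = 256  # z_channels of the f16 VQGAN (see data/vqgan-project.yaml)

class ToyVQSyncNet(nn.Module):
    """Shape-only stand-in for the VQ-space SyncNet; layer sizes are illustrative."""

    def __init__(self):
        super().__init__()
        # face branch: (B, T*256, 16, 16) -> (B, 512, 1, 1)
        self.face_encoder = nn.Sequential(
            nn.Conv2d(syncnet_T * vq_channels, 512, 3, stride=2, padding=1),  # 16x16 -> 8x8
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, stride=2, padding=1),                      # 8x8 -> 4x4
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, stride=2, padding=1),                      # 4x4 -> 2x2
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 2, stride=1, padding=0),                      # 2x2 -> 1x1
        )
        # audio branch: (B, 1, 80, 16) -> (B, 512, 1, 1)
        self.audio_encoder = nn.Sequential(
            nn.Conv2d(1, 128, 3, stride=2, padding=1),    # 80x16 -> 40x8
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),  # 40x8 -> 20x4
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 512, 3, stride=2, padding=1),  # 20x4 -> 10x2
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),                      # -> 1x1
        )

    def forward(self, mel, face_codes):
        a = self.audio_encoder(mel).flatten(1)        # (B, 512)
        v = self.face_encoder(face_codes).flatten(1)  # (B, 512)
        return nn.functional.normalize(a, dim=1), nn.functional.normalize(v, dim=1)

if __name__ == "__main__":
    model = ToyVQSyncNet()
    face_codes = torch.randn(2, syncnet_T * vq_channels, 16, 16)  # T VQGAN code maps, channel-stacked
    mel = torch.randn(2, 1, 80, 16)                               # cropped mel window
    a, v = model(mel, face_codes)
    print(a.shape, v.shape)  # torch.Size([2, 512]) torch.Size([2, 512])
```

The audio branch consumes the same (B, 1, 80, 16) mel window that color_syncnet_train_vq.py crops, so the cosine-similarity loss used in that script applies to these embeddings unchanged.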
-------------------------------------------------------------------------------- /audio.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import librosa.filters 3 | import numpy as np 4 | # import tensorflow as tf 5 | from scipy import signal 6 | from scipy.io import wavfile 7 | from hparams import hparams as hp 8 | 9 | def load_wav(path, sr): 10 | return librosa.core.load(path, sr=sr)[0] 11 | 12 | def save_wav(wav, path, sr): 13 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 14 | #proposed by @dsmiller 15 | wavfile.write(path, sr, wav.astype(np.int16)) 16 | 17 | def save_wavenet_wav(wav, path, sr): 18 | librosa.output.write_wav(path, wav, sr=sr) 19 | 20 | def preemphasis(wav, k, preemphasize=True): 21 | if preemphasize: 22 | return signal.lfilter([1, -k], [1], wav) 23 | return wav 24 | 25 | def inv_preemphasis(wav, k, inv_preemphasize=True): 26 | if inv_preemphasize: 27 | return signal.lfilter([1], [1, -k], wav) 28 | return wav 29 | 30 | def get_hop_size(): 31 | hop_size = hp.hop_size 32 | if hop_size is None: 33 | assert hp.frame_shift_ms is not None 34 | hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate) 35 | return hop_size 36 | 37 | def linearspectrogram(wav): 38 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 39 | S = _amp_to_db(np.abs(D)) - hp.ref_level_db 40 | 41 | if hp.signal_normalization: 42 | return _normalize(S) 43 | return S 44 | 45 | def melspectrogram(wav): 46 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 47 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db 48 | 49 | if hp.signal_normalization: 50 | return _normalize(S) 51 | return S 52 | 53 | def _lws_processor(): 54 | import lws 55 | return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech") 56 | 57 | def _stft(y): 58 | if hp.use_lws: 59 | return _lws_processor().stft(y).T 60 | else: 61 | return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size) 62 | 63 | ########################################################## 64 | #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) 65 | def num_frames(length, fsize, fshift): 66 | """Compute number of time frames of spectrogram 67 | """ 68 | pad = (fsize - fshift) 69 | if length % fshift == 0: 70 | M = (length + pad * 2 - fsize) // fshift + 1 71 | else: 72 | M = (length + pad * 2 - fsize) // fshift + 2 73 | return M 74 | 75 | 76 | def pad_lr(x, fsize, fshift): 77 | """Compute left and right padding 78 | """ 79 | M = num_frames(len(x), fsize, fshift) 80 | pad = (fsize - fshift) 81 | T = len(x) + 2 * pad 82 | r = (M - 1) * fshift + fsize - T 83 | return pad, pad + r 84 | ########################################################## 85 | #Librosa correct padding 86 | def librosa_pad_lr(x, fsize, fshift): 87 | return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] 88 | 89 | # Conversions 90 | _mel_basis = None 91 | 92 | def _linear_to_mel(spectrogram): 93 | global _mel_basis 94 | if _mel_basis is None: 95 | _mel_basis = _build_mel_basis() 96 | return np.dot(_mel_basis, spectrogram) 97 | 98 | def _build_mel_basis(): 99 | assert hp.fmax <= hp.sample_rate // 2 100 | return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels, 101 | fmin=hp.fmin, fmax=hp.fmax) 102 | 103 | def _amp_to_db(x): 104 | min_level = np.exp(hp.min_level_db / 20 * np.log(10)) 105 | return 20 * np.log10(np.maximum(min_level, x)) 106 | 107 | def _db_to_amp(x): 108 | return np.power(10.0, (x) * 0.05) 109 | 110 | def _normalize(S): 111 | if hp.allow_clipping_in_normalization: 112 | if hp.symmetric_mels: 113 | return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value, 114 | -hp.max_abs_value, hp.max_abs_value) 115 | else: 116 | return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value) 117 | 118 | assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 119 | if hp.symmetric_mels: 120 | return (2 * hp.max_abs_value) * 
((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value 121 | else: 122 | return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) 123 | 124 | def _denormalize(D): 125 | if hp.allow_clipping_in_normalization: 126 | if hp.symmetric_mels: 127 | return (((np.clip(D, -hp.max_abs_value, 128 | hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) 129 | + hp.min_level_db) 130 | else: 131 | return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 132 | 133 | if hp.symmetric_mels: 134 | return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) 135 | else: 136 | return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 137 | -------------------------------------------------------------------------------- /color_syncnet_train.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, join, basename, isfile 2 | from tqdm import tqdm 3 | 4 | from models import SyncNet_color as SyncNet 5 | import audio 6 | 7 | import torch 8 | from torch import nn 9 | from torch import optim 10 | import torch.backends.cudnn as cudnn 11 | from torch.utils import data as data_utils 12 | import numpy as np 13 | 14 | from glob import glob 15 | 16 | import os, random, cv2, argparse 17 | from hparams import hparams, get_image_list 18 | 19 | parser = argparse.ArgumentParser(description='Code to train the expert lip-sync discriminator') 20 | 21 | parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True) 22 | 23 | parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str) 24 | parser.add_argument('--checkpoint_path', help='Resumed from this checkpoint', default=None, type=str) 25 | 26 | args = parser.parse_args() 27 | 28 | 29 | global_step = 0 30 | global_epoch = 0 31 | use_cuda = torch.cuda.is_available() 32 | print('use_cuda: {}'.format(use_cuda)) 33 | 34 | syncnet_T = 5 35 | syncnet_mel_step_size = 16 36 | 37 | class Dataset(object): 38 | def __init__(self, split): 39 | self.all_videos = get_image_list(args.data_root, split) 40 | 41 | def get_frame_id(self, frame): 42 | return int(basename(frame).split('.')[0]) 43 | 44 | def get_window(self, start_frame): 45 | start_id = self.get_frame_id(start_frame) 46 | vidname = dirname(start_frame) 47 | 48 | window_fnames = [] 49 | for frame_id in range(start_id, start_id + syncnet_T): 50 | frame = join(vidname, '{}.jpg'.format(frame_id)) 51 | if not isfile(frame): 52 | return None 53 | window_fnames.append(frame) 54 | return window_fnames 55 | 56 | def crop_audio_window(self, spec, start_frame): 57 | # num_frames = (T x hop_size * fps) / sample_rate 58 | start_frame_num = self.get_frame_id(start_frame) 59 | start_idx = int(80. 
* (start_frame_num / float(hparams.fps))) 60 | 61 | end_idx = start_idx + syncnet_mel_step_size 62 | 63 | return spec[start_idx : end_idx, :] 64 | 65 | 66 | def __len__(self): 67 | return len(self.all_videos) 68 | 69 | def __getitem__(self, idx): 70 | while 1: 71 | idx = random.randint(0, len(self.all_videos) - 1) 72 | vidname = self.all_videos[idx] 73 | 74 | img_names = list(glob(join(vidname, '*.jpg'))) 75 | if len(img_names) <= 3 * syncnet_T: 76 | continue 77 | img_name = random.choice(img_names) 78 | wrong_img_name = random.choice(img_names) 79 | while wrong_img_name == img_name: 80 | wrong_img_name = random.choice(img_names) 81 | 82 | if random.choice([True, False]): 83 | y = torch.ones(1).float() 84 | chosen = img_name 85 | else: 86 | y = torch.zeros(1).float() 87 | chosen = wrong_img_name 88 | 89 | window_fnames = self.get_window(chosen) 90 | if window_fnames is None: 91 | continue 92 | 93 | window = [] 94 | all_read = True 95 | for fname in window_fnames: 96 | img = cv2.imread(fname) 97 | if img is None: 98 | all_read = False 99 | break 100 | try: 101 | img = cv2.resize(img, (hparams.img_size, hparams.img_size)) 102 | except Exception as e: 103 | all_read = False 104 | break 105 | 106 | window.append(img) 107 | 108 | if not all_read: continue 109 | 110 | try: 111 | wavpath = join(vidname, "audio.wav") 112 | wav = audio.load_wav(wavpath, hparams.sample_rate) 113 | 114 | orig_mel = audio.melspectrogram(wav).T 115 | except Exception as e: 116 | continue 117 | 118 | mel = self.crop_audio_window(orig_mel.copy(), img_name) 119 | 120 | if (mel.shape[0] != syncnet_mel_step_size): 121 | continue 122 | 123 | # H x W x 3 * T 124 | x = np.concatenate(window, axis=2) / 255. 125 | x = x.transpose(2, 0, 1) 126 | x = x[:, x.shape[1]//2:] 127 | 128 | x = torch.FloatTensor(x) 129 | mel = torch.FloatTensor(mel.T).unsqueeze(0) 130 | 131 | return x, mel, y 132 | 133 | logloss = nn.BCELoss() 134 | def cosine_loss(a, v, y): 135 | d = nn.functional.cosine_similarity(a, v) 136 | loss = logloss(d.unsqueeze(1), y) 137 | 138 | return loss 139 | 140 | def train(device, model, train_data_loader, test_data_loader, optimizer, 141 | checkpoint_dir=None, checkpoint_interval=None, nepochs=None): 142 | 143 | global global_step, global_epoch 144 | resumed_step = global_step 145 | 146 | while global_epoch < nepochs: 147 | running_loss = 0. 
148 | prog_bar = tqdm(enumerate(train_data_loader)) 149 | for step, (x, mel, y) in prog_bar: 150 | model.train() 151 | optimizer.zero_grad() 152 | 153 | # Transform data to CUDA device 154 | x = x.to(device) 155 | 156 | mel = mel.to(device) 157 | 158 | a, v = model(mel, x) 159 | y = y.to(device) 160 | 161 | loss = cosine_loss(a, v, y) 162 | loss.backward() 163 | optimizer.step() 164 | 165 | global_step += 1 166 | cur_session_steps = global_step - resumed_step 167 | running_loss += loss.item() 168 | 169 | if global_step == 1 or global_step % checkpoint_interval == 0: 170 | save_checkpoint( 171 | model, optimizer, global_step, checkpoint_dir, global_epoch) 172 | 173 | if global_step % hparams.syncnet_eval_interval == 0: 174 | with torch.no_grad(): 175 | eval_model(test_data_loader, global_step, device, model, checkpoint_dir) 176 | 177 | prog_bar.set_description('Loss: {}'.format(running_loss / (step + 1))) 178 | 179 | global_epoch += 1 180 | 181 | def eval_model(test_data_loader, global_step, device, model, checkpoint_dir): 182 | eval_steps = 1400 183 | print('Evaluating for {} steps'.format(eval_steps)) 184 | losses = [] 185 | while 1: 186 | for step, (x, mel, y) in enumerate(test_data_loader): 187 | 188 | model.eval() 189 | 190 | # Transform data to CUDA device 191 | x = x.to(device) 192 | 193 | mel = mel.to(device) 194 | 195 | a, v = model(mel, x) 196 | y = y.to(device) 197 | 198 | loss = cosine_loss(a, v, y) 199 | losses.append(loss.item()) 200 | 201 | if step > eval_steps: break 202 | 203 | averaged_loss = sum(losses) / len(losses) 204 | print(averaged_loss) 205 | 206 | return 207 | 208 | def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch): 209 | 210 | checkpoint_path = join( 211 | checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step)) 212 | optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None 213 | torch.save({ 214 | "state_dict": model.state_dict(), 215 | "optimizer": optimizer_state, 216 | "global_step": step, 217 | "global_epoch": epoch, 218 | }, checkpoint_path) 219 | print("Saved checkpoint:", checkpoint_path) 220 | 221 | def _load(checkpoint_path): 222 | if use_cuda: 223 | checkpoint = torch.load(checkpoint_path) 224 | else: 225 | checkpoint = torch.load(checkpoint_path, 226 | map_location=lambda storage, loc: storage) 227 | return checkpoint 228 | 229 | def load_checkpoint(path, model, optimizer, reset_optimizer=False): 230 | global global_step 231 | global global_epoch 232 | 233 | print("Load checkpoint from: {}".format(path)) 234 | checkpoint = _load(path) 235 | model.load_state_dict(checkpoint["state_dict"]) 236 | if not reset_optimizer: 237 | optimizer_state = checkpoint["optimizer"] 238 | if optimizer_state is not None: 239 | print("Load optimizer state from {}".format(path)) 240 | optimizer.load_state_dict(checkpoint["optimizer"]) 241 | global_step = checkpoint["global_step"] 242 | global_epoch = checkpoint["global_epoch"] 243 | 244 | return model 245 | 246 | if __name__ == "__main__": 247 | checkpoint_dir = args.checkpoint_dir 248 | checkpoint_path = args.checkpoint_path 249 | 250 | if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir) 251 | 252 | # Dataset and Dataloader setup 253 | train_dataset = Dataset('train') 254 | test_dataset = Dataset('val') 255 | 256 | train_data_loader = data_utils.DataLoader( 257 | train_dataset, batch_size=hparams.syncnet_batch_size, shuffle=True, 258 | num_workers=hparams.num_workers) 259 | 260 | test_data_loader = data_utils.DataLoader( 261 | test_dataset, 
batch_size=hparams.syncnet_batch_size, 262 | num_workers=8) 263 | 264 | device = torch.device("cuda" if use_cuda else "cpu") 265 | 266 | # Model 267 | model = SyncNet().to(device) 268 | print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))) 269 | 270 | optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], 271 | lr=hparams.syncnet_lr) 272 | 273 | if checkpoint_path is not None: 274 | load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False) 275 | 276 | train(device, model, train_data_loader, test_data_loader, optimizer, 277 | checkpoint_dir=checkpoint_dir, 278 | checkpoint_interval=hparams.syncnet_checkpoint_interval, 279 | nepochs=hparams.nepochs) 280 | -------------------------------------------------------------------------------- /color_syncnet_train_vq.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, join, basename, isfile 2 | 3 | from tqdm import tqdm 4 | 5 | from models.syncnet_vq import SyncNet_color as SyncNet 6 | import audio 7 | 8 | import torch 9 | from torch import nn 10 | from torch import optim 11 | from torch.utils import data as data_utils 12 | import numpy as np 13 | 14 | from glob import glob 15 | 16 | import os, random, cv2, argparse 17 | from hparams import hparams, get_image_list 18 | 19 | parser = argparse.ArgumentParser(description='Code to train the expert lip-sync discriminator') 20 | 21 | parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", default='data/min_pre/') 22 | parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', default='checkpoints', type=str) 23 | parser.add_argument('--checkpoint_path', help='Resumed from this checkpoint', default=None, type=str) 24 | 25 | args = parser.parse_args() 26 | 27 | global_step = 0 28 | global_epoch = 0 29 | use_cuda = torch.cuda.is_available() 30 | print('use_cuda: {}'.format(use_cuda)) 31 | 32 | syncnet_T = 5 33 | syncnet_mel_step_size = 16 34 | img_size = 256 35 | device = torch.device("cuda" if use_cuda else "cpu") 36 | config_path = './data/vqgan-project.yaml' 37 | # ckpt_path = 'xxx/taming-transformers/logs/xxx/checkpoints/last.ckpt' 38 | 39 | 40 | class Dataset(object): 41 | def __init__(self, split): 42 | self.all_videos = get_image_list(args.data_root, split) 43 | 44 | def get_frame_id(self, frame): 45 | return int(basename(frame).split('.')[0]) 46 | 47 | def get_window(self, start_frame): 48 | start_id = self.get_frame_id(start_frame) 49 | vidname = dirname(start_frame) 50 | 51 | window_fnames = [] 52 | for frame_id in range(start_id, start_id + syncnet_T): 53 | frame = join(vidname, '{}.jpg'.format(frame_id)) 54 | if not isfile(frame): 55 | return None 56 | window_fnames.append(frame) 57 | return window_fnames 58 | 59 | def crop_audio_window(self, spec, start_frame): 60 | # num_frames = (T x hop_size * fps) / sample_rate 61 | start_frame_num = self.get_frame_id(start_frame) 62 | start_idx = int(80. 
* (start_frame_num / float(hparams.fps))) 63 | 64 | end_idx = start_idx + syncnet_mel_step_size 65 | 66 | return spec[start_idx: end_idx, :] 67 | 68 | def __len__(self): 69 | return len(self.all_videos) 70 | 71 | def __getitem__(self, idx): 72 | while 1: 73 | idx = random.randint(0, len(self.all_videos) - 1) 74 | vidname = self.all_videos[idx] 75 | 76 | img_names = list(glob(join(vidname, '*.jpg'))) 77 | if len(img_names) <= 3 * syncnet_T: 78 | continue 79 | img_name = random.choice(img_names) 80 | wrong_img_name = random.choice(img_names) 81 | while wrong_img_name == img_name: 82 | wrong_img_name = random.choice(img_names) 83 | 84 | if random.choice([True, False]): 85 | y = torch.ones(1).float() 86 | chosen = img_name 87 | else: 88 | y = torch.zeros(1).float() 89 | chosen = wrong_img_name 90 | 91 | window_fnames = self.get_window(chosen) 92 | if window_fnames is None: 93 | continue 94 | 95 | window = [] 96 | all_read = True 97 | for fname in window_fnames: 98 | img = cv2.imread(fname) 99 | if img is None: 100 | all_read = False 101 | break 102 | try: 103 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 104 | img = cv2.resize(img, (img_size, img_size)) 105 | except Exception as e: 106 | all_read = False 107 | break 108 | 109 | window.append(img) 110 | 111 | if not all_read: continue 112 | 113 | try: 114 | wavpath = join(vidname, "audio.wav") 115 | wav = audio.load_wav(wavpath, hparams.sample_rate) 116 | 117 | orig_mel = audio.melspectrogram(wav).T 118 | except Exception as e: 119 | continue 120 | 121 | mel = self.crop_audio_window(orig_mel.copy(), img_name) 122 | 123 | if (mel.shape[0] != syncnet_mel_step_size): 124 | continue 125 | 126 | # H x W x 3 * T 127 | x = np.array(window) 128 | x = (x / 127.5 - 1.0).astype(np.float32) 129 | x = x.transpose(0, 3, 1, 2) # T, 3, H, W 130 | 131 | x = torch.FloatTensor(x) 132 | mel = torch.FloatTensor(mel.T).unsqueeze(0) 133 | 134 | return x, mel, y 135 | 136 | 137 | logloss = nn.BCELoss() 138 | 139 | 140 | def cosine_loss(a, v, y): 141 | d = nn.functional.cosine_similarity(a, v) 142 | loss = logloss(d.unsqueeze(1), y) 143 | 144 | return loss 145 | 146 | 147 | def train(device, model, train_data_loader, test_data_loader, optimizer, 148 | checkpoint_dir=None, checkpoint_interval=None, nepochs=None): 149 | global global_step, global_epoch 150 | 151 | while global_epoch < nepochs: 152 | running_loss = 0. 
153 | prog_bar = tqdm(enumerate(train_data_loader)) 154 | for step, (x, mel, y) in prog_bar: 155 | model.train() 156 | optimizer.zero_grad() 157 | 158 | x = x.to(device) 159 | mel = mel.to(device) 160 | 161 | a, v = model(mel, x) 162 | y = y.to(device) 163 | 164 | loss = cosine_loss(a, v, y) 165 | loss.backward() 166 | optimizer.step() 167 | 168 | global_step += 1 169 | running_loss += loss.item() 170 | 171 | if global_step == 1 or global_step % checkpoint_interval == 0: 172 | save_checkpoint( 173 | model, optimizer, global_step, checkpoint_dir, global_epoch) 174 | 175 | if global_step % hparams.syncnet_eval_interval == 0: 176 | with torch.no_grad(): 177 | eval_model(test_data_loader, global_step, device, model, checkpoint_dir) 178 | 179 | prog_bar.set_description('Loss: {}'.format(running_loss / (step + 1))) 180 | 181 | global_epoch += 1 182 | 183 | 184 | def eval_model(test_data_loader, global_step, device, model, checkpoint_dir): 185 | eval_steps = 100 if torch.cuda.is_available() else 2 186 | print('Evaluating for {} steps'.format(eval_steps)) 187 | losses = [] 188 | model.eval() 189 | 190 | while 1: 191 | for step, (x, mel, y) in enumerate(test_data_loader): 192 | x = x.to(device) 193 | mel = mel.to(device) 194 | 195 | a, v = model(mel, x) 196 | y = y.to(device) 197 | 198 | loss = cosine_loss(a, v, y) 199 | losses.append(loss.item()) 200 | 201 | if step > eval_steps: break 202 | 203 | averaged_loss = sum(losses) / len(losses) 204 | print(averaged_loss) 205 | 206 | return 207 | 208 | 209 | def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, checkpoint_path=None): 210 | if checkpoint_path is None: 211 | checkpoint_path = join(checkpoint_dir, "lipSync_step{:09d}.pth".format(global_step)) 212 | optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None 213 | torch.save({ 214 | "state_dict": model.state_dict(), 215 | "optimizer": optimizer_state, 216 | "global_step": step, 217 | "global_epoch": epoch, 218 | }, checkpoint_path) 219 | print("Saved checkpoint:", checkpoint_path) 220 | 221 | 222 | def _load(checkpoint_path): 223 | if use_cuda: 224 | checkpoint = torch.load(checkpoint_path) 225 | else: 226 | checkpoint = torch.load(checkpoint_path, 227 | map_location=lambda storage, loc: storage) 228 | return checkpoint 229 | 230 | 231 | def load_checkpoint(path, model, optimizer, reset_optimizer=False): 232 | global global_step 233 | global global_epoch 234 | 235 | print("Load checkpoint from: {}".format(path)) 236 | checkpoint = _load(path) 237 | model.load_state_dict(checkpoint["state_dict"]) 238 | 239 | if not reset_optimizer: 240 | optimizer_state = checkpoint["optimizer"] 241 | if optimizer_state is not None: 242 | print("Load optimizer state from {}".format(path)) 243 | optimizer.load_state_dict(checkpoint["optimizer"]) 244 | global_step = checkpoint["global_step"] 245 | global_epoch = checkpoint["global_epoch"] 246 | 247 | return model 248 | 249 | 250 | if __name__ == "__main__": 251 | checkpoint_dir = args.checkpoint_dir 252 | checkpoint_path = args.checkpoint_path 253 | 254 | if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir) 255 | 256 | # Dataset and Dataloader setup 257 | train_dataset = Dataset('train') 258 | test_dataset = Dataset('val') 259 | 260 | train_data_loader = data_utils.DataLoader( 261 | train_dataset, batch_size=hparams.syncnet_batch_size, shuffle=True, 262 | num_workers=hparams.num_workers) 263 | 264 | test_data_loader = data_utils.DataLoader( 265 | test_dataset, batch_size=hparams.syncnet_batch_size, 266 | 
num_workers=hparams.num_workers) 267 | 268 | # Model 269 | model = SyncNet(config_path) 270 | model.to(device) 271 | print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))) 272 | 273 | optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], 274 | lr=hparams.syncnet_lr) 275 | 276 | if checkpoint_path is not None: 277 | load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False) 278 | 279 | train(device, model, train_data_loader, test_data_loader, optimizer, 280 | checkpoint_dir=checkpoint_dir, 281 | checkpoint_interval=hparams.syncnet_checkpoint_interval, 282 | nepochs=hparams.nepochs) 283 | 284 | save_checkpoint(model, optimizer, global_step, checkpoint_dir, global_epoch, 285 | checkpoint_path=join(checkpoint_dir, "lipSync_last.pth")) 286 | -------------------------------------------------------------------------------- /data/vqgan-project.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 1024 7 | ddconfig: 8 | double_z: false 9 | z_channels: 256 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 16 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 10000 30 | disc_weight: 0.8 31 | codebook_weight: 1.0 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 2 36 | num_workers: 8 37 | train: 38 | target: taming.data.custom.CustomTrain 39 | params: 40 | training_images_list_file: some/train.txt 41 | size: 256 42 | validation: 43 | target: taming.data.custom.CustomTest 44 | params: 45 | test_images_list_file: some/val.txt 46 | size: 256 47 | -------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- 1 | # Novel Evaluation Framework, new filelists, and using the LSE-D and LSE-C metric. 2 | 3 | Our paper also proposes a novel evaluation framework (Section 4). To evaluate on LRS2, LRS3, and LRW, the filelists are present in the `test_filelists` folder. Please use `gen_videos_from_filelist.py` script to generate the videos. After that, you can calculate the LSE-D and LSE-C scores using the instructions below. Please see [this thread](https://github.com/Rudrabha/Wav2Lip/issues/22#issuecomment-712825380) on how to calculate the FID scores. 4 | 5 | The videos of the ReSyncED benchmark for real-world evaluation will be released soon. 6 | 7 | ### Steps to set-up the evaluation repository for LSE-D and LSE-C metric: 8 | We use the pre-trained syncnet model available in this [repository](https://github.com/joonson/syncnet_python). 9 | 10 | * Clone the SyncNet repository. 11 | ``` 12 | git clone https://github.com/joonson/syncnet_python.git 13 | ``` 14 | * Follow the procedure given in the above linked [repository](https://github.com/joonson/syncnet_python) to download the pretrained models and set up the dependencies. 15 | * **Note: Please install a separate virtual environment for the evaluation scripts. The versions used by Wav2Lip and the publicly released code of SyncNet is different and can cause version mis-match issues. 
To avoid this, we suggest installing a separate virtual environment for the evaluation scripts.** 16 | ``` 17 | cd syncnet_python 18 | pip install -r requirements.txt 19 | sh download_model.sh 20 | ``` 21 | * The above step should ensure that all the dependencies required by the repository are installed and the pre-trained models are downloaded. 22 | 23 | ### Running the evaluation scripts: 24 | * Copy our evaluation scripts given in this folder to the cloned repository. 25 | ``` 26 | cd Wav2Lip/evaluation/scores_LSE/ 27 | cp *.py syncnet_python/ 28 | cp *.sh syncnet_python/ 29 | ``` 30 | **Note: We will release the test filelists for LRW, LRS2 and LRS3 shortly, once we receive permission from the dataset creators. We will also release the Real World Dataset we have collected shortly.** 31 | 32 | * Our evaluation technique does not require ground truth of any sort. Given the lip-synced videos, we can calculate the scores directly from the generated videos alone. Please store the generated videos (from our test sets or your own) in the following folder structure. 33 | ``` 34 | video data root (Folder containing all videos) 35 | ├── All .mp4 files 36 | ``` 37 | * Change the folder back to the cloned repository. 38 | ``` 39 | cd syncnet_python 40 | ``` 41 | * To run evaluation on the LRW, LRS2 and LRS3 test files, please run the following command: 42 | ``` 43 | python calculate_scores_LRS.py --data_root /path/to/video/data/root --tmp_dir tmp_dir/ 44 | ``` 45 | 46 | * To run evaluation on the ReSyncED dataset or your own generated videos, please run the following command: 47 | ``` 48 | sh calculate_scores_real_videos.sh /path/to/video/data/root 49 | ``` 50 | * The generated scores will be written to all_scores.txt in the ```syncnet_python/``` folder. 51 | 52 | # Evaluation of image quality using the FID metric. 53 | We use the [pytorch-fid](https://github.com/mseitzer/pytorch-fid) repository for calculating the FID metric. We dump all the frames of both the ground-truth and the generated videos and calculate the FID score between the two sets of frames. 54 | 55 | 56 | # Opening issues related to evaluation scripts 57 | * Please open issues with the "Evaluation" label if you face any problems with the evaluation scripts. 58 | 59 | # Acknowledgements 60 | Our evaluation pipeline is based on two existing repositories. The LSE metrics are based on the [syncnet_python](https://github.com/joonson/syncnet_python) repository and the FID score is based on the [pytorch-fid](https://github.com/mseitzer/pytorch-fid) repository. We thank the authors of both repositories for releasing their wonderful code. 
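As a reference for the FID step described above, here is a minimal sketch, not one of the released evaluation scripts: it dumps every frame of the ground-truth and generated videos into two folders with OpenCV and then invokes the pytorch-fid command-line entry point on those folders. The folder names are placeholders, and it assumes opencv-python and pytorch-fid are installed.

```python
import os
import subprocess
import sys
from glob import glob

import cv2


def dump_frames(video_dir, frame_dir):
    """Write every frame of every .mp4 in video_dir into frame_dir as JPEGs."""
    os.makedirs(frame_dir, exist_ok=True)
    for vid_path in glob(os.path.join(video_dir, '*.mp4')):
        name = os.path.splitext(os.path.basename(vid_path))[0]
        cap = cv2.VideoCapture(vid_path)
        idx = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            cv2.imwrite(os.path.join(frame_dir, '{}_{:06d}.jpg'.format(name, idx)), frame)
            idx += 1
        cap.release()


dump_frames('gt_videos', 'gt_frames')           # ground-truth videos (placeholder folder)
dump_frames('generated_videos', 'gen_frames')   # lip-synced outputs (placeholder folder)

# pytorch-fid exposes a module-level CLI: python -m pytorch_fid <dir1> <dir2>
subprocess.check_call([sys.executable, '-m', 'pytorch_fid', 'gt_frames', 'gen_frames'])
```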
61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /evaluation/gen_videos_from_filelist.py: -------------------------------------------------------------------------------- 1 | from os import listdir, path 2 | import numpy as np 3 | import scipy, cv2, os, sys, argparse 4 | import dlib, json, subprocess 5 | from tqdm import tqdm 6 | from glob import glob 7 | import torch 8 | 9 | sys.path.append('../') 10 | import audio 11 | import face_detection 12 | from models import Wav2Lip 13 | 14 | parser = argparse.ArgumentParser(description='Code to generate results for test filelists') 15 | 16 | parser.add_argument('--filelist', type=str, 17 | help='Filepath of filelist file to read', required=True) 18 | parser.add_argument('--results_dir', type=str, help='Folder to save all results into', 19 | required=True) 20 | parser.add_argument('--data_root', type=str, required=True) 21 | parser.add_argument('--checkpoint_path', type=str, 22 | help='Name of saved checkpoint to load weights from', required=True) 23 | 24 | parser.add_argument('--pads', nargs='+', type=int, default=[0, 0, 0, 0], 25 | help='Padding (top, bottom, left, right)') 26 | parser.add_argument('--face_det_batch_size', type=int, 27 | help='Single GPU batch size for face detection', default=64) 28 | parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip', default=128) 29 | 30 | # parser.add_argument('--resize_factor', default=1, type=int) 31 | 32 | args = parser.parse_args() 33 | args.img_size = 96 34 | 35 | def get_smoothened_boxes(boxes, T): 36 | for i in range(len(boxes)): 37 | if i + T > len(boxes): 38 | window = boxes[len(boxes) - T:] 39 | else: 40 | window = boxes[i : i + T] 41 | boxes[i] = np.mean(window, axis=0) 42 | return boxes 43 | 44 | def face_detect(images): 45 | batch_size = args.face_det_batch_size 46 | 47 | while 1: 48 | predictions = [] 49 | try: 50 | for i in range(0, len(images), batch_size): 51 | predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size]))) 52 | except RuntimeError: 53 | if batch_size == 1: 54 | raise RuntimeError('Image too big to run face detection on GPU') 55 | batch_size //= 2 56 | args.face_det_batch_size = batch_size 57 | print('Recovering from OOM error; New batch size: {}'.format(batch_size)) 58 | continue 59 | break 60 | 61 | results = [] 62 | pady1, pady2, padx1, padx2 = args.pads 63 | for rect, image in zip(predictions, images): 64 | if rect is None: 65 | raise ValueError('Face not detected!') 66 | 67 | y1 = max(0, rect[1] - pady1) 68 | y2 = min(image.shape[0], rect[3] + pady2) 69 | x1 = max(0, rect[0] - padx1) 70 | x2 = min(image.shape[1], rect[2] + padx2) 71 | 72 | results.append([x1, y1, x2, y2]) 73 | 74 | boxes = get_smoothened_boxes(np.array(results), T=5) 75 | results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)] 76 | 77 | return results 78 | 79 | def datagen(frames, face_det_results, mels): 80 | img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] 81 | 82 | for i, m in enumerate(mels): 83 | if i >= len(frames): raise ValueError('Equal or less lengths only') 84 | 85 | frame_to_save = frames[i].copy() 86 | face, coords, valid_frame = face_det_results[i].copy() 87 | if not valid_frame: 88 | continue 89 | 90 | face = cv2.resize(face, (args.img_size, args.img_size)) 91 | 92 | img_batch.append(face) 93 | mel_batch.append(m) 94 | frame_batch.append(frame_to_save) 95 | coords_batch.append(coords) 96 | 97 | if len(img_batch) 
>= args.wav2lip_batch_size: 98 | img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) 99 | 100 | img_masked = img_batch.copy() 101 | img_masked[:, args.img_size//2:] = 0 102 | 103 | img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 104 | mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) 105 | 106 | yield img_batch, mel_batch, frame_batch, coords_batch 107 | img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] 108 | 109 | if len(img_batch) > 0: 110 | img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) 111 | 112 | img_masked = img_batch.copy() 113 | img_masked[:, args.img_size//2:] = 0 114 | 115 | img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 116 | mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) 117 | 118 | yield img_batch, mel_batch, frame_batch, coords_batch 119 | 120 | fps = 25 121 | mel_step_size = 16 122 | mel_idx_multiplier = 80./fps 123 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 124 | print('Using {} for inference.'.format(device)) 125 | 126 | detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, 127 | flip_input=False, device=device) 128 | 129 | def _load(checkpoint_path): 130 | if device == 'cuda': 131 | checkpoint = torch.load(checkpoint_path) 132 | else: 133 | checkpoint = torch.load(checkpoint_path, 134 | map_location=lambda storage, loc: storage) 135 | return checkpoint 136 | 137 | def load_model(path): 138 | model = Wav2Lip() 139 | print("Load checkpoint from: {}".format(path)) 140 | checkpoint = _load(path) 141 | s = checkpoint["state_dict"] 142 | new_s = {} 143 | for k, v in s.items(): 144 | new_s[k.replace('module.', '')] = v 145 | model.load_state_dict(new_s) 146 | 147 | model = model.to(device) 148 | return model.eval() 149 | 150 | model = load_model(args.checkpoint_path) 151 | 152 | def main(): 153 | assert args.data_root is not None 154 | data_root = args.data_root 155 | 156 | if not os.path.isdir(args.results_dir): os.makedirs(args.results_dir) 157 | 158 | with open(args.filelist, 'r') as filelist: 159 | lines = filelist.readlines() 160 | 161 | for idx, line in enumerate(tqdm(lines)): 162 | audio_src, video = line.strip().split() 163 | 164 | audio_src = os.path.join(data_root, audio_src) + '.mp4' 165 | video = os.path.join(data_root, video) + '.mp4' 166 | 167 | command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, '../temp/temp.wav') 168 | subprocess.call(command, shell=True) 169 | temp_audio = '../temp/temp.wav' 170 | 171 | wav = audio.load_wav(temp_audio, 16000) 172 | mel = audio.melspectrogram(wav) 173 | if np.isnan(mel.reshape(-1)).sum() > 0: 174 | continue 175 | 176 | mel_chunks = [] 177 | i = 0 178 | while 1: 179 | start_idx = int(i * mel_idx_multiplier) 180 | if start_idx + mel_step_size > len(mel[0]): 181 | break 182 | mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size]) 183 | i += 1 184 | 185 | video_stream = cv2.VideoCapture(video) 186 | 187 | full_frames = [] 188 | while 1: 189 | still_reading, frame = video_stream.read() 190 | if not still_reading or len(full_frames) > len(mel_chunks): 191 | video_stream.release() 192 | break 193 | full_frames.append(frame) 194 | 195 | if len(full_frames) < len(mel_chunks): 196 | continue 197 | 198 | full_frames = full_frames[:len(mel_chunks)] 199 | 200 | try: 201 | face_det_results = face_detect(full_frames.copy()) 202 | except ValueError as e: 203 | continue 204 | 205 | batch_size = 
args.wav2lip_batch_size 206 | gen = datagen(full_frames.copy(), face_det_results, mel_chunks) 207 | 208 | for i, (img_batch, mel_batch, frames, coords) in enumerate(gen): 209 | if i == 0: 210 | frame_h, frame_w = full_frames[0].shape[:-1] 211 | out = cv2.VideoWriter('../temp/result.avi', 212 | cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h)) 213 | 214 | img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device) 215 | mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device) 216 | 217 | with torch.no_grad(): 218 | pred = model(mel_batch, img_batch) 219 | 220 | 221 | pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255. 222 | 223 | for pl, f, c in zip(pred, frames, coords): 224 | y1, y2, x1, x2 = c 225 | pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1)) 226 | f[y1:y2, x1:x2] = pl 227 | out.write(f) 228 | 229 | out.release() 230 | 231 | vid = os.path.join(args.results_dir, '{}.mp4'.format(idx)) 232 | 233 | command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format(temp_audio, 234 | '../temp/result.avi', vid) 235 | subprocess.call(command, shell=True) 236 | 237 | if __name__ == '__main__': 238 | main() 239 | -------------------------------------------------------------------------------- /evaluation/real_videos_inference.py: -------------------------------------------------------------------------------- 1 | from os import listdir, path 2 | import numpy as np 3 | import scipy, cv2, os, sys, argparse 4 | import dlib, json, subprocess 5 | from tqdm import tqdm 6 | from glob import glob 7 | import torch 8 | 9 | sys.path.append('../') 10 | import audio 11 | import face_detection 12 | from models import Wav2Lip 13 | 14 | parser = argparse.ArgumentParser(description='Code to generate results on ReSyncED evaluation set') 15 | 16 | parser.add_argument('--mode', type=str, 17 | help='random | dubbed | tts', required=True) 18 | 19 | parser.add_argument('--filelist', type=str, 20 | help='Filepath of filelist file to read', default=None) 21 | 22 | parser.add_argument('--results_dir', type=str, help='Folder to save all results into', 23 | required=True) 24 | parser.add_argument('--data_root', type=str, required=True) 25 | parser.add_argument('--checkpoint_path', type=str, 26 | help='Name of saved checkpoint to load weights from', required=True) 27 | parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0], 28 | help='Padding (top, bottom, left, right)') 29 | 30 | parser.add_argument('--face_det_batch_size', type=int, 31 | help='Single GPU batch size for face detection', default=16) 32 | 33 | parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip', default=128) 34 | parser.add_argument('--face_res', help='Approximate resolution of the face at which to test', default=180) 35 | parser.add_argument('--min_frame_res', help='Do not downsample further below this frame resolution', default=480) 36 | parser.add_argument('--max_frame_res', help='Downsample to at least this frame resolution', default=720) 37 | # parser.add_argument('--resize_factor', default=1, type=int) 38 | 39 | args = parser.parse_args() 40 | args.img_size = 96 41 | 42 | def get_smoothened_boxes(boxes, T): 43 | for i in range(len(boxes)): 44 | if i + T > len(boxes): 45 | window = boxes[len(boxes) - T:] 46 | else: 47 | window = boxes[i : i + T] 48 | boxes[i] = np.mean(window, axis=0) 49 | return boxes 50 | 51 | def rescale_frames(images): 52 | rect = detector.get_detections_for_batch(np.array([images[0]]))[0] 53 | if rect is 
None: 54 | raise ValueError('Face not detected!') 55 | h, w = images[0].shape[:-1] 56 | 57 | x1, y1, x2, y2 = rect 58 | 59 | face_size = max(np.abs(y1 - y2), np.abs(x1 - x2)) 60 | 61 | diff = np.abs(face_size - args.face_res) 62 | for factor in range(2, 16): 63 | downsampled_res = face_size // factor 64 | if min(h//factor, w//factor) < args.min_frame_res: break 65 | if np.abs(downsampled_res - args.face_res) >= diff: break 66 | 67 | factor -= 1 68 | if factor == 1: return images 69 | 70 | return [cv2.resize(im, (im.shape[1]//(factor), im.shape[0]//(factor))) for im in images] 71 | 72 | 73 | def face_detect(images): 74 | batch_size = args.face_det_batch_size 75 | images = rescale_frames(images) 76 | 77 | while 1: 78 | predictions = [] 79 | try: 80 | for i in range(0, len(images), batch_size): 81 | predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size]))) 82 | except RuntimeError: 83 | if batch_size == 1: 84 | raise RuntimeError('Image too big to run face detection on GPU') 85 | batch_size //= 2 86 | print('Recovering from OOM error; New batch size: {}'.format(batch_size)) 87 | continue 88 | break 89 | 90 | results = [] 91 | pady1, pady2, padx1, padx2 = args.pads 92 | for rect, image in zip(predictions, images): 93 | if rect is None: 94 | raise ValueError('Face not detected!') 95 | 96 | y1 = max(0, rect[1] - pady1) 97 | y2 = min(image.shape[0], rect[3] + pady2) 98 | x1 = max(0, rect[0] - padx1) 99 | x2 = min(image.shape[1], rect[2] + padx2) 100 | 101 | results.append([x1, y1, x2, y2]) 102 | 103 | boxes = get_smoothened_boxes(np.array(results), T=5) 104 | results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2), True] for image, (x1, y1, x2, y2) in zip(images, boxes)] 105 | 106 | return results, images 107 | 108 | def datagen(frames, face_det_results, mels): 109 | img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] 110 | 111 | for i, m in enumerate(mels): 112 | if i >= len(frames): raise ValueError('Equal or less lengths only') 113 | 114 | frame_to_save = frames[i].copy() 115 | face, coords, valid_frame = face_det_results[i].copy() 116 | if not valid_frame: 117 | continue 118 | 119 | face = cv2.resize(face, (args.img_size, args.img_size)) 120 | 121 | img_batch.append(face) 122 | mel_batch.append(m) 123 | frame_batch.append(frame_to_save) 124 | coords_batch.append(coords) 125 | 126 | if len(img_batch) >= args.wav2lip_batch_size: 127 | img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) 128 | 129 | img_masked = img_batch.copy() 130 | img_masked[:, args.img_size//2:] = 0 131 | 132 | img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 133 | mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) 134 | 135 | yield img_batch, mel_batch, frame_batch, coords_batch 136 | img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] 137 | 138 | if len(img_batch) > 0: 139 | img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) 140 | 141 | img_masked = img_batch.copy() 142 | img_masked[:, args.img_size//2:] = 0 143 | 144 | img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 145 | mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) 146 | 147 | yield img_batch, mel_batch, frame_batch, coords_batch 148 | 149 | def increase_frames(frames, l): 150 | ## evenly duplicating frames to increase length of video 151 | while len(frames) < l: 152 | dup_every = float(l) / len(frames) 153 | 154 | final_frames = [] 155 | next_duplicate = 0. 
156 | 157 | for i, f in enumerate(frames): 158 | final_frames.append(f) 159 | 160 | if int(np.ceil(next_duplicate)) == i: 161 | final_frames.append(f) 162 | 163 | next_duplicate += dup_every 164 | 165 | frames = final_frames 166 | 167 | return frames[:l] 168 | 169 | mel_step_size = 16 170 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 171 | print('Using {} for inference.'.format(device)) 172 | 173 | detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, 174 | flip_input=False, device=device) 175 | 176 | def _load(checkpoint_path): 177 | if device == 'cuda': 178 | checkpoint = torch.load(checkpoint_path) 179 | else: 180 | checkpoint = torch.load(checkpoint_path, 181 | map_location=lambda storage, loc: storage) 182 | return checkpoint 183 | 184 | def load_model(path): 185 | model = Wav2Lip() 186 | print("Load checkpoint from: {}".format(path)) 187 | checkpoint = _load(path) 188 | s = checkpoint["state_dict"] 189 | new_s = {} 190 | for k, v in s.items(): 191 | new_s[k.replace('module.', '')] = v 192 | model.load_state_dict(new_s) 193 | 194 | model = model.to(device) 195 | return model.eval() 196 | 197 | model = load_model(args.checkpoint_path) 198 | 199 | def main(): 200 | if not os.path.isdir(args.results_dir): os.makedirs(args.results_dir) 201 | 202 | if args.mode == 'dubbed': 203 | files = listdir(args.data_root) 204 | lines = ['{} {}'.format(f, f) for f in files] 205 | 206 | else: 207 | assert args.filelist is not None 208 | with open(args.filelist, 'r') as filelist: 209 | lines = filelist.readlines() 210 | 211 | for idx, line in enumerate(tqdm(lines)): 212 | video, audio_src = line.strip().split() 213 | 214 | audio_src = os.path.join(args.data_root, audio_src) 215 | video = os.path.join(args.data_root, video) 216 | 217 | command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, '../temp/temp.wav') 218 | subprocess.call(command, shell=True) 219 | temp_audio = '../temp/temp.wav' 220 | 221 | wav = audio.load_wav(temp_audio, 16000) 222 | mel = audio.melspectrogram(wav) 223 | 224 | if np.isnan(mel.reshape(-1)).sum() > 0: 225 | raise ValueError('Mel contains nan!') 226 | 227 | video_stream = cv2.VideoCapture(video) 228 | 229 | fps = video_stream.get(cv2.CAP_PROP_FPS) 230 | mel_idx_multiplier = 80./fps 231 | 232 | full_frames = [] 233 | while 1: 234 | still_reading, frame = video_stream.read() 235 | if not still_reading: 236 | video_stream.release() 237 | break 238 | 239 | if min(frame.shape[:-1]) > args.max_frame_res: 240 | h, w = frame.shape[:-1] 241 | scale_factor = min(h, w) / float(args.max_frame_res) 242 | h = int(h/scale_factor) 243 | w = int(w/scale_factor) 244 | 245 | frame = cv2.resize(frame, (w, h)) 246 | full_frames.append(frame) 247 | 248 | mel_chunks = [] 249 | i = 0 250 | while 1: 251 | start_idx = int(i * mel_idx_multiplier) 252 | if start_idx + mel_step_size > len(mel[0]): 253 | break 254 | mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size]) 255 | i += 1 256 | 257 | if len(full_frames) < len(mel_chunks): 258 | if args.mode == 'tts': 259 | full_frames = increase_frames(full_frames, len(mel_chunks)) 260 | else: 261 | raise ValueError('#Frames, audio length mismatch') 262 | 263 | else: 264 | full_frames = full_frames[:len(mel_chunks)] 265 | 266 | try: 267 | face_det_results, full_frames = face_detect(full_frames.copy()) 268 | except ValueError as e: 269 | continue 270 | 271 | batch_size = args.wav2lip_batch_size 272 | gen = datagen(full_frames.copy(), face_det_results, mel_chunks) 273 | 274 | for i, (img_batch, 
mel_batch, frames, coords) in enumerate(gen): 275 | if i == 0: 276 | frame_h, frame_w = full_frames[0].shape[:-1] 277 | 278 | out = cv2.VideoWriter('../temp/result.avi', 279 | cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h)) 280 | 281 | img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device) 282 | mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device) 283 | 284 | with torch.no_grad(): 285 | pred = model(mel_batch, img_batch) 286 | 287 | 288 | pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255. 289 | 290 | for pl, f, c in zip(pred, frames, coords): 291 | y1, y2, x1, x2 = c 292 | pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1)) 293 | f[y1:y2, x1:x2] = pl 294 | out.write(f) 295 | 296 | out.release() 297 | 298 | vid = os.path.join(args.results_dir, '{}.mp4'.format(idx)) 299 | command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format('../temp/temp.wav', 300 | '../temp/result.avi', vid) 301 | subprocess.call(command, shell=True) 302 | 303 | 304 | if __name__ == '__main__': 305 | main() 306 | -------------------------------------------------------------------------------- /evaluation/scores_LSE/SyncNetInstance_calc_scores.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | # Video 25 FPS, Audio 16000HZ 4 | 5 | import torch 6 | import numpy 7 | import time, pdb, argparse, subprocess, os, math, glob 8 | import cv2 9 | import python_speech_features 10 | 11 | from scipy import signal 12 | from scipy.io import wavfile 13 | from SyncNetModel import * 14 | from shutil import rmtree 15 | 16 | 17 | # ==================== Get OFFSET ==================== 18 | 19 | def calc_pdist(feat1, feat2, vshift=10): 20 | 21 | win_size = vshift*2+1 22 | 23 | feat2p = torch.nn.functional.pad(feat2,(0,0,vshift,vshift)) 24 | 25 | dists = [] 26 | 27 | for i in range(0,len(feat1)): 28 | 29 | dists.append(torch.nn.functional.pairwise_distance(feat1[[i],:].repeat(win_size, 1), feat2p[i:i+win_size,:])) 30 | 31 | return dists 32 | 33 | # ==================== MAIN DEF ==================== 34 | 35 | class SyncNetInstance(torch.nn.Module): 36 | 37 | def __init__(self, dropout = 0, num_layers_in_fc_layers = 1024): 38 | super(SyncNetInstance, self).__init__(); 39 | 40 | self.__S__ = S(num_layers_in_fc_layers = num_layers_in_fc_layers).cuda(); 41 | 42 | def evaluate(self, opt, videofile): 43 | 44 | self.__S__.eval(); 45 | 46 | # ========== ========== 47 | # Convert files 48 | # ========== ========== 49 | 50 | if os.path.exists(os.path.join(opt.tmp_dir,opt.reference)): 51 | rmtree(os.path.join(opt.tmp_dir,opt.reference)) 52 | 53 | os.makedirs(os.path.join(opt.tmp_dir,opt.reference)) 54 | 55 | command = ("ffmpeg -loglevel error -y -i %s -threads 1 -f image2 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'%06d.jpg'))) 56 | output = subprocess.call(command, shell=True, stdout=None) 57 | 58 | command = ("ffmpeg -loglevel error -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'audio.wav'))) 59 | output = subprocess.call(command, shell=True, stdout=None) 60 | 61 | # ========== ========== 62 | # Load video 63 | # ========== ========== 64 | 65 | images = [] 66 | 67 | flist = glob.glob(os.path.join(opt.tmp_dir,opt.reference,'*.jpg')) 68 | flist.sort() 69 | 70 | for fname in flist: 71 | img_input = cv2.imread(fname) 72 | img_input = cv2.resize(img_input, (224,224)) #HARD CODED, CHANGE BEFORE RELEASE 73 
| images.append(img_input) 74 | 75 | im = numpy.stack(images,axis=3) 76 | im = numpy.expand_dims(im,axis=0) 77 | im = numpy.transpose(im,(0,3,4,1,2)) 78 | 79 | imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float()) 80 | 81 | # ========== ========== 82 | # Load audio 83 | # ========== ========== 84 | 85 | sample_rate, audio = wavfile.read(os.path.join(opt.tmp_dir,opt.reference,'audio.wav')) 86 | mfcc = zip(*python_speech_features.mfcc(audio,sample_rate)) 87 | mfcc = numpy.stack([numpy.array(i) for i in mfcc]) 88 | 89 | cc = numpy.expand_dims(numpy.expand_dims(mfcc,axis=0),axis=0) 90 | cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float()) 91 | 92 | # ========== ========== 93 | # Check audio and video input length 94 | # ========== ========== 95 | 96 | #if (float(len(audio))/16000) != (float(len(images))/25) : 97 | # print("WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."%(float(len(audio))/16000,float(len(images))/25)) 98 | 99 | min_length = min(len(images),math.floor(len(audio)/640)) 100 | 101 | # ========== ========== 102 | # Generate video and audio feats 103 | # ========== ========== 104 | 105 | lastframe = min_length-5 106 | im_feat = [] 107 | cc_feat = [] 108 | 109 | tS = time.time() 110 | for i in range(0,lastframe,opt.batch_size): 111 | 112 | im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ] 113 | im_in = torch.cat(im_batch,0) 114 | im_out = self.__S__.forward_lip(im_in.cuda()); 115 | im_feat.append(im_out.data.cpu()) 116 | 117 | cc_batch = [ cct[:,:,:,vframe*4:vframe*4+20] for vframe in range(i,min(lastframe,i+opt.batch_size)) ] 118 | cc_in = torch.cat(cc_batch,0) 119 | cc_out = self.__S__.forward_aud(cc_in.cuda()) 120 | cc_feat.append(cc_out.data.cpu()) 121 | 122 | im_feat = torch.cat(im_feat,0) 123 | cc_feat = torch.cat(cc_feat,0) 124 | 125 | # ========== ========== 126 | # Compute offset 127 | # ========== ========== 128 | 129 | #print('Compute time %.3f sec.' 
% (time.time()-tS)) 130 | 131 | dists = calc_pdist(im_feat,cc_feat,vshift=opt.vshift) 132 | mdist = torch.mean(torch.stack(dists,1),1) 133 | 134 | minval, minidx = torch.min(mdist,0) 135 | 136 | offset = opt.vshift-minidx 137 | conf = torch.median(mdist) - minval 138 | 139 | fdist = numpy.stack([dist[minidx].numpy() for dist in dists]) 140 | # fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15) 141 | fconf = torch.median(mdist).numpy() - fdist 142 | fconfm = signal.medfilt(fconf,kernel_size=9) 143 | 144 | numpy.set_printoptions(formatter={'float': '{: 0.3f}'.format}) 145 | #print('Framewise conf: ') 146 | #print(fconfm) 147 | #print('AV offset: \t%d \nMin dist: \t%.3f\nConfidence: \t%.3f' % (offset,minval,conf)) 148 | 149 | dists_npy = numpy.array([ dist.numpy() for dist in dists ]) 150 | return offset.numpy(), conf.numpy(), minval.numpy() 151 | 152 | def extract_feature(self, opt, videofile): 153 | 154 | self.__S__.eval(); 155 | 156 | # ========== ========== 157 | # Load video 158 | # ========== ========== 159 | cap = cv2.VideoCapture(videofile) 160 | 161 | frame_num = 1; 162 | images = [] 163 | while frame_num: 164 | frame_num += 1 165 | ret, image = cap.read() 166 | if ret == 0: 167 | break 168 | 169 | images.append(image) 170 | 171 | im = numpy.stack(images,axis=3) 172 | im = numpy.expand_dims(im,axis=0) 173 | im = numpy.transpose(im,(0,3,4,1,2)) 174 | 175 | imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float()) 176 | 177 | # ========== ========== 178 | # Generate video feats 179 | # ========== ========== 180 | 181 | lastframe = len(images)-4 182 | im_feat = [] 183 | 184 | tS = time.time() 185 | for i in range(0,lastframe,opt.batch_size): 186 | 187 | im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ] 188 | im_in = torch.cat(im_batch,0) 189 | im_out = self.__S__.forward_lipfeat(im_in.cuda()); 190 | im_feat.append(im_out.data.cpu()) 191 | 192 | im_feat = torch.cat(im_feat,0) 193 | 194 | # ========== ========== 195 | # Compute offset 196 | # ========== ========== 197 | 198 | print('Compute time %.3f sec.' 
% (time.time()-tS)) 199 | 200 | return im_feat 201 | 202 | 203 | def loadParameters(self, path): 204 | loaded_state = torch.load(path, map_location=lambda storage, loc: storage); 205 | 206 | self_state = self.__S__.state_dict(); 207 | 208 | for name, param in loaded_state.items(): 209 | 210 | self_state[name].copy_(param); 211 | -------------------------------------------------------------------------------- /evaluation/scores_LSE/calculate_scores_LRS.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | import time, pdb, argparse, subprocess 5 | import glob 6 | import os 7 | from tqdm import tqdm 8 | 9 | from SyncNetInstance_calc_scores import * 10 | 11 | # ==================== LOAD PARAMS ==================== 12 | 13 | 14 | parser = argparse.ArgumentParser(description = "SyncNet"); 15 | 16 | parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help=''); 17 | parser.add_argument('--batch_size', type=int, default='20', help=''); 18 | parser.add_argument('--vshift', type=int, default='15', help=''); 19 | parser.add_argument('--data_root', type=str, required=True, help=''); 20 | parser.add_argument('--tmp_dir', type=str, default="data/work/pytmp", help=''); 21 | parser.add_argument('--reference', type=str, default="demo", help=''); 22 | 23 | opt = parser.parse_args(); 24 | 25 | 26 | # ==================== RUN EVALUATION ==================== 27 | 28 | s = SyncNetInstance(); 29 | 30 | s.loadParameters(opt.initial_model); 31 | #print("Model %s loaded."%opt.initial_model); 32 | path = os.path.join(opt.data_root, "*.mp4") 33 | 34 | all_videos = glob.glob(path) 35 | 36 | prog_bar = tqdm(range(len(all_videos))) 37 | avg_confidence = 0. 38 | avg_min_distance = 0. 
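# The loop below runs SyncNet on each video and accumulates its confidence and
# minimum audio-visual distance; the averages printed at the end correspond to
# the LSE-C (higher is better) and LSE-D (lower is better) metrics referred to
# in evaluation/README.md.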
39 | 40 | 41 | for videofile_idx in prog_bar: 42 | videofile = all_videos[videofile_idx] 43 | offset, confidence, min_distance = s.evaluate(opt, videofile=videofile) 44 | avg_confidence += confidence 45 | avg_min_distance += min_distance 46 | prog_bar.set_description('Avg Confidence: {}, Avg Minimum Dist: {}'.format(round(avg_confidence / (videofile_idx + 1), 3), round(avg_min_distance / (videofile_idx + 1), 3))) 47 | prog_bar.refresh() 48 | 49 | print ('Average Confidence: {}'.format(avg_confidence/len(all_videos))) 50 | print ('Average Minimum Distance: {}'.format(avg_min_distance/len(all_videos))) 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /evaluation/scores_LSE/calculate_scores_real_videos.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | import time, pdb, argparse, subprocess, pickle, os, gzip, glob 5 | 6 | from SyncNetInstance_calc_scores import * 7 | 8 | # ==================== PARSE ARGUMENT ==================== 9 | 10 | parser = argparse.ArgumentParser(description = "SyncNet"); 11 | parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help=''); 12 | parser.add_argument('--batch_size', type=int, default='20', help=''); 13 | parser.add_argument('--vshift', type=int, default='15', help=''); 14 | parser.add_argument('--data_dir', type=str, default='data/work', help=''); 15 | parser.add_argument('--videofile', type=str, default='', help=''); 16 | parser.add_argument('--reference', type=str, default='', help=''); 17 | opt = parser.parse_args(); 18 | 19 | setattr(opt,'avi_dir',os.path.join(opt.data_dir,'pyavi')) 20 | setattr(opt,'tmp_dir',os.path.join(opt.data_dir,'pytmp')) 21 | setattr(opt,'work_dir',os.path.join(opt.data_dir,'pywork')) 22 | setattr(opt,'crop_dir',os.path.join(opt.data_dir,'pycrop')) 23 | 24 | 25 | # ==================== LOAD MODEL AND FILE LIST ==================== 26 | 27 | s = SyncNetInstance(); 28 | 29 | s.loadParameters(opt.initial_model); 30 | #print("Model %s loaded."%opt.initial_model); 31 | 32 | flist = glob.glob(os.path.join(opt.crop_dir,opt.reference,'0*.avi')) 33 | flist.sort() 34 | 35 | # ==================== GET OFFSETS ==================== 36 | 37 | dists = [] 38 | for idx, fname in enumerate(flist): 39 | offset, conf, dist = s.evaluate(opt,videofile=fname) 40 | print (str(dist)+" "+str(conf)) 41 | 42 | # ==================== PRINT RESULTS TO FILE ==================== 43 | 44 | #with open(os.path.join(opt.work_dir,opt.reference,'activesd.pckl'), 'wb') as fil: 45 | # pickle.dump(dists, fil) 46 | -------------------------------------------------------------------------------- /evaluation/scores_LSE/calculate_scores_real_videos.sh: -------------------------------------------------------------------------------- 1 | rm all_scores.txt 2 | yourfilenames=`ls $1` 3 | 4 | for eachfile in $yourfilenames 5 | do 6 | python run_pipeline.py --videofile $1/$eachfile --reference wav2lip --data_dir tmp_dir 7 | python calculate_scores_real_videos.py --videofile $1/$eachfile --reference wav2lip --data_dir tmp_dir >> all_scores.txt 8 | done 9 | -------------------------------------------------------------------------------- /evaluation/test_filelists/README.md: -------------------------------------------------------------------------------- 1 | This folder contains the filelists for the new evaluation framework proposed in the paper. 2 | 3 | ## Test filelists for LRS2, LRS3, and LRW. 
4 | 5 | This folder contains three filelists, each containing a list of names of audio-video pairs from the test sets of LRS2, LRS3, and LRW. The LRS2 and LRW filelists are strictly "Copyright BBC" and can only be used for “non-commercial research by applicants who have an agreement with the BBC to access the Lip Reading in the Wild and/or Lip Reading Sentences in the Wild datasets”. Please follow this link for more details: [https://www.bbc.co.uk/rd/projects/lip-reading-datasets](https://www.bbc.co.uk/rd/projects/lip-reading-datasets). 6 | 7 | 8 | ## ReSynCED benchmark 9 | 10 | The sub-folder `ReSynCED` contains filelists for our own Real-world lip-Sync Evaluation Dataset (ReSyncED). 11 | 12 | 13 | #### Instructions on how to use the above two filelists are available in the README of the parent folder. 14 | -------------------------------------------------------------------------------- /evaluation/test_filelists/ReSyncED/random_pairs.txt: -------------------------------------------------------------------------------- 1 | sachin.mp4 emma_cropped.mp4 2 | sachin.mp4 mourinho.mp4 3 | sachin.mp4 elon.mp4 4 | sachin.mp4 messi2.mp4 5 | sachin.mp4 cr1.mp4 6 | sachin.mp4 sachin.mp4 7 | sachin.mp4 sg.mp4 8 | sachin.mp4 fergi.mp4 9 | sachin.mp4 spanish_lec1.mp4 10 | sachin.mp4 bush_small.mp4 11 | sachin.mp4 macca_cut.mp4 12 | sachin.mp4 ca_cropped.mp4 13 | sachin.mp4 lecun.mp4 14 | sachin.mp4 spanish_lec0.mp4 15 | srk.mp4 emma_cropped.mp4 16 | srk.mp4 mourinho.mp4 17 | srk.mp4 elon.mp4 18 | srk.mp4 messi2.mp4 19 | srk.mp4 cr1.mp4 20 | srk.mp4 srk.mp4 21 | srk.mp4 sachin.mp4 22 | srk.mp4 sg.mp4 23 | srk.mp4 fergi.mp4 24 | srk.mp4 spanish_lec1.mp4 25 | srk.mp4 bush_small.mp4 26 | srk.mp4 macca_cut.mp4 27 | srk.mp4 ca_cropped.mp4 28 | srk.mp4 guardiola.mp4 29 | srk.mp4 lecun.mp4 30 | srk.mp4 spanish_lec0.mp4 31 | cr1.mp4 emma_cropped.mp4 32 | cr1.mp4 elon.mp4 33 | cr1.mp4 messi2.mp4 34 | cr1.mp4 cr1.mp4 35 | cr1.mp4 spanish_lec1.mp4 36 | cr1.mp4 bush_small.mp4 37 | cr1.mp4 macca_cut.mp4 38 | cr1.mp4 ca_cropped.mp4 39 | cr1.mp4 lecun.mp4 40 | cr1.mp4 spanish_lec0.mp4 41 | macca_cut.mp4 emma_cropped.mp4 42 | macca_cut.mp4 elon.mp4 43 | macca_cut.mp4 messi2.mp4 44 | macca_cut.mp4 spanish_lec1.mp4 45 | macca_cut.mp4 macca_cut.mp4 46 | macca_cut.mp4 ca_cropped.mp4 47 | macca_cut.mp4 spanish_lec0.mp4 48 | lecun.mp4 emma_cropped.mp4 49 | lecun.mp4 elon.mp4 50 | lecun.mp4 messi2.mp4 51 | lecun.mp4 spanish_lec1.mp4 52 | lecun.mp4 macca_cut.mp4 53 | lecun.mp4 ca_cropped.mp4 54 | lecun.mp4 lecun.mp4 55 | lecun.mp4 spanish_lec0.mp4 56 | messi2.mp4 emma_cropped.mp4 57 | messi2.mp4 elon.mp4 58 | messi2.mp4 messi2.mp4 59 | messi2.mp4 spanish_lec1.mp4 60 | messi2.mp4 macca_cut.mp4 61 | messi2.mp4 ca_cropped.mp4 62 | messi2.mp4 spanish_lec0.mp4 63 | ca_cropped.mp4 emma_cropped.mp4 64 | ca_cropped.mp4 elon.mp4 65 | ca_cropped.mp4 spanish_lec1.mp4 66 | ca_cropped.mp4 ca_cropped.mp4 67 | ca_cropped.mp4 spanish_lec0.mp4 68 | spanish_lec1.mp4 spanish_lec1.mp4 69 | spanish_lec1.mp4 spanish_lec0.mp4 70 | elon.mp4 elon.mp4 71 | elon.mp4 spanish_lec1.mp4 72 | elon.mp4 spanish_lec0.mp4 73 | guardiola.mp4 emma_cropped.mp4 74 | guardiola.mp4 mourinho.mp4 75 | guardiola.mp4 elon.mp4 76 | guardiola.mp4 messi2.mp4 77 | guardiola.mp4 cr1.mp4 78 | guardiola.mp4 sachin.mp4 79 | guardiola.mp4 sg.mp4 80 | guardiola.mp4 fergi.mp4 81 | guardiola.mp4 spanish_lec1.mp4 82 | guardiola.mp4 bush_small.mp4 83 | guardiola.mp4 macca_cut.mp4 84 | guardiola.mp4 ca_cropped.mp4 85 | guardiola.mp4 guardiola.mp4 86 | guardiola.mp4 lecun.mp4 87 | 
guardiola.mp4 spanish_lec0.mp4 88 | fergi.mp4 emma_cropped.mp4 89 | fergi.mp4 mourinho.mp4 90 | fergi.mp4 elon.mp4 91 | fergi.mp4 messi2.mp4 92 | fergi.mp4 cr1.mp4 93 | fergi.mp4 sachin.mp4 94 | fergi.mp4 sg.mp4 95 | fergi.mp4 fergi.mp4 96 | fergi.mp4 spanish_lec1.mp4 97 | fergi.mp4 bush_small.mp4 98 | fergi.mp4 macca_cut.mp4 99 | fergi.mp4 ca_cropped.mp4 100 | fergi.mp4 lecun.mp4 101 | fergi.mp4 spanish_lec0.mp4 102 | spanish.mp4 emma_cropped.mp4 103 | spanish.mp4 spanish.mp4 104 | spanish.mp4 mourinho.mp4 105 | spanish.mp4 elon.mp4 106 | spanish.mp4 messi2.mp4 107 | spanish.mp4 cr1.mp4 108 | spanish.mp4 srk.mp4 109 | spanish.mp4 sachin.mp4 110 | spanish.mp4 sg.mp4 111 | spanish.mp4 fergi.mp4 112 | spanish.mp4 spanish_lec1.mp4 113 | spanish.mp4 bush_small.mp4 114 | spanish.mp4 macca_cut.mp4 115 | spanish.mp4 ca_cropped.mp4 116 | spanish.mp4 guardiola.mp4 117 | spanish.mp4 lecun.mp4 118 | spanish.mp4 spanish_lec0.mp4 119 | bush_small.mp4 emma_cropped.mp4 120 | bush_small.mp4 elon.mp4 121 | bush_small.mp4 messi2.mp4 122 | bush_small.mp4 spanish_lec1.mp4 123 | bush_small.mp4 bush_small.mp4 124 | bush_small.mp4 macca_cut.mp4 125 | bush_small.mp4 ca_cropped.mp4 126 | bush_small.mp4 lecun.mp4 127 | bush_small.mp4 spanish_lec0.mp4 128 | emma_cropped.mp4 emma_cropped.mp4 129 | emma_cropped.mp4 elon.mp4 130 | emma_cropped.mp4 spanish_lec1.mp4 131 | emma_cropped.mp4 spanish_lec0.mp4 132 | sg.mp4 emma_cropped.mp4 133 | sg.mp4 mourinho.mp4 134 | sg.mp4 elon.mp4 135 | sg.mp4 messi2.mp4 136 | sg.mp4 cr1.mp4 137 | sg.mp4 sachin.mp4 138 | sg.mp4 sg.mp4 139 | sg.mp4 fergi.mp4 140 | sg.mp4 spanish_lec1.mp4 141 | sg.mp4 bush_small.mp4 142 | sg.mp4 macca_cut.mp4 143 | sg.mp4 ca_cropped.mp4 144 | sg.mp4 lecun.mp4 145 | sg.mp4 spanish_lec0.mp4 146 | spanish_lec0.mp4 spanish_lec0.mp4 147 | mourinho.mp4 emma_cropped.mp4 148 | mourinho.mp4 mourinho.mp4 149 | mourinho.mp4 elon.mp4 150 | mourinho.mp4 messi2.mp4 151 | mourinho.mp4 cr1.mp4 152 | mourinho.mp4 sachin.mp4 153 | mourinho.mp4 sg.mp4 154 | mourinho.mp4 fergi.mp4 155 | mourinho.mp4 spanish_lec1.mp4 156 | mourinho.mp4 bush_small.mp4 157 | mourinho.mp4 macca_cut.mp4 158 | mourinho.mp4 ca_cropped.mp4 159 | mourinho.mp4 lecun.mp4 160 | mourinho.mp4 spanish_lec0.mp4 161 | -------------------------------------------------------------------------------- /evaluation/test_filelists/ReSyncED/tts_pairs.txt: -------------------------------------------------------------------------------- 1 | adam_1.mp4 andreng_optimization.wav 2 | agad_2.mp4 agad_2.wav 3 | agad_1.mp4 agad_1.wav 4 | agad_3.mp4 agad_3.wav 5 | rms_prop_1.mp4 rms_prop_tts.wav 6 | tf_1.mp4 tf_1.wav 7 | tf_2.mp4 tf_2.wav 8 | andrew_ng_ai_business.mp4 andrewng_business_tts.wav 9 | covid_autopsy_1.mp4 autopsy_tts.wav 10 | news_1.mp4 news_tts.wav 11 | andrew_ng_fund_1.mp4 andrewng_ai_fund.wav 12 | covid_treatments_1.mp4 covid_tts.wav 13 | pytorch_v_tf.mp4 pytorch_vs_tf_eng.wav 14 | pytorch_1.mp4 pytorch.wav 15 | pkb_1.mp4 pkb_1.wav 16 | ss_1.mp4 ss_1.wav 17 | carlsen_1.mp4 carlsen_eng.wav 18 | french.mp4 french.wav -------------------------------------------------------------------------------- /face_detection/README.md: -------------------------------------------------------------------------------- 1 | The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time. 
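
For reference, a minimal usage sketch of the batched interface exposed in `api.py` (illustrative only; the frame paths and device string are placeholders, and all frames in a batch are assumed to share the same resolution):

```python
import cv2
import numpy as np
import face_detection

# FaceAlignment, LandmarksType and get_detections_for_batch come from face_detection/api.py
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                        flip_input=False, device='cuda')

frames = [cv2.imread(p) for p in ['frame_000.jpg', 'frame_001.jpg']]  # BGR frames (placeholder paths)
batch = np.asarray(frames)        # shape (N, H, W, 3); frames must be equally sized
boxes = detector.get_detections_for_batch(batch)
# boxes: one (x1, y1, x2, y2) tuple per frame, or None where no face was detected
```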
-------------------------------------------------------------------------------- /face_detection/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | __author__ = """Adrian Bulat""" 4 | __email__ = 'adrian.bulat@nottingham.ac.uk' 5 | __version__ = '1.0.1' 6 | 7 | from .api import FaceAlignment, LandmarksType, NetworkSize 8 | -------------------------------------------------------------------------------- /face_detection/api.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import torch 4 | from torch.utils.model_zoo import load_url 5 | from enum import Enum 6 | import numpy as np 7 | import cv2 8 | try: 9 | import urllib.request as request_file 10 | except BaseException: 11 | import urllib as request_file 12 | 13 | from .models import FAN, ResNetDepth 14 | from .utils import * 15 | 16 | 17 | class LandmarksType(Enum): 18 | """Enum class defining the type of landmarks to detect. 19 | 20 | ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face 21 | ``_2halfD`` - this points represent the projection of the 3D points into 3D 22 | ``_3D`` - detect the points ``(x,y,z)``` in a 3D space 23 | 24 | """ 25 | _2D = 1 26 | _2halfD = 2 27 | _3D = 3 28 | 29 | 30 | class NetworkSize(Enum): 31 | # TINY = 1 32 | # SMALL = 2 33 | # MEDIUM = 3 34 | LARGE = 4 35 | 36 | def __new__(cls, value): 37 | member = object.__new__(cls) 38 | member._value_ = value 39 | return member 40 | 41 | def __int__(self): 42 | return self.value 43 | 44 | ROOT = os.path.dirname(os.path.abspath(__file__)) 45 | 46 | class FaceAlignment: 47 | def __init__(self, landmarks_type, network_size=NetworkSize.LARGE, 48 | device='cuda', flip_input=False, face_detector='sfd', verbose=False): 49 | self.device = device 50 | self.flip_input = flip_input 51 | self.landmarks_type = landmarks_type 52 | self.verbose = verbose 53 | 54 | network_size = int(network_size) 55 | 56 | if 'cuda' in device: 57 | torch.backends.cudnn.benchmark = True 58 | 59 | # Get the face detector 60 | face_detector_module = __import__('face_detection.detection.' + face_detector, 61 | globals(), locals(), [face_detector], 0) 62 | self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose) 63 | 64 | def get_detections_for_batch(self, images): 65 | images = images[..., ::-1] 66 | detected_faces = self.face_detector.detect_from_batch(images.copy()) 67 | results = [] 68 | 69 | for i, d in enumerate(detected_faces): 70 | if len(d) == 0: 71 | results.append(None) 72 | continue 73 | d = d[0] 74 | d = np.clip(d, 0, None) 75 | 76 | x1, y1, x2, y2 = map(int, d[:-1]) 77 | results.append((x1, y1, x2, y2)) 78 | 79 | return results -------------------------------------------------------------------------------- /face_detection/detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import FaceDetector -------------------------------------------------------------------------------- /face_detection/detection/core.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import glob 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | import cv2 7 | 8 | 9 | class FaceDetector(object): 10 | """An abstract class representing a face detector. 11 | 12 | Any other face detection implementation must subclass it. 
All subclasses 13 | must implement ``detect_from_image``, that return a list of detected 14 | bounding boxes. Optionally, for speed considerations detect from path is 15 | recommended. 16 | """ 17 | 18 | def __init__(self, device, verbose): 19 | self.device = device 20 | self.verbose = verbose 21 | 22 | if verbose: 23 | if 'cpu' in device: 24 | logger = logging.getLogger(__name__) 25 | logger.warning("Detection running on CPU, this may be potentially slow.") 26 | 27 | if 'cpu' not in device and 'cuda' not in device: 28 | if verbose: 29 | logger.error("Expected values for device are: {cpu, cuda} but got: %s", device) 30 | raise ValueError 31 | 32 | def detect_from_image(self, tensor_or_path): 33 | """Detects faces in a given image. 34 | 35 | This function detects the faces present in a provided BGR(usually) 36 | image. The input can be either the image itself or the path to it. 37 | 38 | Arguments: 39 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path 40 | to an image or the image itself. 41 | 42 | Example:: 43 | 44 | >>> path_to_image = 'data/image_01.jpg' 45 | ... detected_faces = detect_from_image(path_to_image) 46 | [A list of bounding boxes (x1, y1, x2, y2)] 47 | >>> image = cv2.imread(path_to_image) 48 | ... detected_faces = detect_from_image(image) 49 | [A list of bounding boxes (x1, y1, x2, y2)] 50 | 51 | """ 52 | raise NotImplementedError 53 | 54 | def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True): 55 | """Detects faces from all the images present in a given directory. 56 | 57 | Arguments: 58 | path {string} -- a string containing a path that points to the folder containing the images 59 | 60 | Keyword Arguments: 61 | extensions {list} -- list of string containing the extensions to be 62 | consider in the following format: ``.extension_name`` (default: 63 | {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the 64 | folder recursively (default: {False}) show_progress_bar {bool} -- 65 | display a progressbar (default: {True}) 66 | 67 | Example: 68 | >>> directory = 'data' 69 | ... detected_faces = detect_from_directory(directory) 70 | {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]} 71 | 72 | """ 73 | if self.verbose: 74 | logger = logging.getLogger(__name__) 75 | 76 | if len(extensions) == 0: 77 | if self.verbose: 78 | logger.error("Expected at list one extension, but none was received.") 79 | raise ValueError 80 | 81 | if self.verbose: 82 | logger.info("Constructing the list of images.") 83 | additional_pattern = '/**/*' if recursive else '/*' 84 | files = [] 85 | for extension in extensions: 86 | files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive)) 87 | 88 | if self.verbose: 89 | logger.info("Finished searching for images. 
%s images found", len(files)) 90 | logger.info("Preparing to run the detection.") 91 | 92 | predictions = {} 93 | for image_path in tqdm(files, disable=not show_progress_bar): 94 | if self.verbose: 95 | logger.info("Running the face detector on image: %s", image_path) 96 | predictions[image_path] = self.detect_from_image(image_path) 97 | 98 | if self.verbose: 99 | logger.info("The detector was successfully run on all %s images", len(files)) 100 | 101 | return predictions 102 | 103 | @property 104 | def reference_scale(self): 105 | raise NotImplementedError 106 | 107 | @property 108 | def reference_x_shift(self): 109 | raise NotImplementedError 110 | 111 | @property 112 | def reference_y_shift(self): 113 | raise NotImplementedError 114 | 115 | @staticmethod 116 | def tensor_or_path_to_ndarray(tensor_or_path, rgb=True): 117 | """Convert path (represented as a string) or torch.tensor to a numpy.ndarray 118 | 119 | Arguments: 120 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself 121 | """ 122 | if isinstance(tensor_or_path, str): 123 | return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1] 124 | elif torch.is_tensor(tensor_or_path): 125 | # Call cpu in case its coming from cuda 126 | return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy() 127 | elif isinstance(tensor_or_path, np.ndarray): 128 | return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path 129 | else: 130 | raise TypeError 131 | -------------------------------------------------------------------------------- /face_detection/detection/sfd/__init__.py: -------------------------------------------------------------------------------- 1 | from .sfd_detector import SFDDetector as FaceDetector -------------------------------------------------------------------------------- /face_detection/detection/sfd/bbox.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import cv2 5 | import random 6 | import datetime 7 | import time 8 | import math 9 | import argparse 10 | import numpy as np 11 | import torch 12 | 13 | try: 14 | from iou import IOU 15 | except BaseException: 16 | # IOU cython speedup 10x 17 | def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2): 18 | sa = abs((ax2 - ax1) * (ay2 - ay1)) 19 | sb = abs((bx2 - bx1) * (by2 - by1)) 20 | x1, y1 = max(ax1, bx1), max(ay1, by1) 21 | x2, y2 = min(ax2, bx2), min(ay2, by2) 22 | w = x2 - x1 23 | h = y2 - y1 24 | if w < 0 or h < 0: 25 | return 0.0 26 | else: 27 | return 1.0 * w * h / (sa + sb - w * h) 28 | 29 | 30 | def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh): 31 | xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1 32 | dx, dy = (xc - axc) / aww, (yc - ayc) / ahh 33 | dw, dh = math.log(ww / aww), math.log(hh / ahh) 34 | return dx, dy, dw, dh 35 | 36 | 37 | def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh): 38 | xc, yc = dx * aww + axc, dy * ahh + ayc 39 | ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh 40 | x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2 41 | return x1, y1, x2, y2 42 | 43 | 44 | def nms(dets, thresh): 45 | if 0 == len(dets): 46 | return [] 47 | x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4] 48 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 49 | order = scores.argsort()[::-1] 50 | 51 | keep = [] 52 | while order.size > 0: 53 | i = order[0] 54 | keep.append(i) 55 | xx1, yy1 = 
np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]]) 56 | xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]]) 57 | 58 | w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1) 59 | ovr = w * h / (areas[i] + areas[order[1:]] - w * h) 60 | 61 | inds = np.where(ovr <= thresh)[0] 62 | order = order[inds + 1] 63 | 64 | return keep 65 | 66 | 67 | def encode(matched, priors, variances): 68 | """Encode the variances from the priorbox layers into the ground truth boxes 69 | we have matched (based on jaccard overlap) with the prior boxes. 70 | Args: 71 | matched: (tensor) Coords of ground truth for each prior in point-form 72 | Shape: [num_priors, 4]. 73 | priors: (tensor) Prior boxes in center-offset form 74 | Shape: [num_priors,4]. 75 | variances: (list[float]) Variances of priorboxes 76 | Return: 77 | encoded boxes (tensor), Shape: [num_priors, 4] 78 | """ 79 | 80 | # dist b/t match center and prior's center 81 | g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] 82 | # encode variance 83 | g_cxcy /= (variances[0] * priors[:, 2:]) 84 | # match wh / prior wh 85 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] 86 | g_wh = torch.log(g_wh) / variances[1] 87 | # return target for smooth_l1_loss 88 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] 89 | 90 | 91 | def decode(loc, priors, variances): 92 | """Decode locations from predictions using priors to undo 93 | the encoding we did for offset regression at train time. 94 | Args: 95 | loc (tensor): location predictions for loc layers, 96 | Shape: [num_priors,4] 97 | priors (tensor): Prior boxes in center-offset form. 98 | Shape: [num_priors,4]. 99 | variances: (list[float]) Variances of priorboxes 100 | Return: 101 | decoded bounding box predictions 102 | """ 103 | 104 | boxes = torch.cat(( 105 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], 106 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) 107 | boxes[:, :2] -= boxes[:, 2:] / 2 108 | boxes[:, 2:] += boxes[:, :2] 109 | return boxes 110 | 111 | def batch_decode(loc, priors, variances): 112 | """Decode locations from predictions using priors to undo 113 | the encoding we did for offset regression at train time. 114 | Args: 115 | loc (tensor): location predictions for loc layers, 116 | Shape: [num_priors,4] 117 | priors (tensor): Prior boxes in center-offset form. 118 | Shape: [num_priors,4]. 
119 | variances: (list[float]) Variances of priorboxes 120 | Return: 121 | decoded bounding box predictions 122 | """ 123 | 124 | boxes = torch.cat(( 125 | priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:], 126 | priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2) 127 | boxes[:, :, :2] -= boxes[:, :, 2:] / 2 128 | boxes[:, :, 2:] += boxes[:, :, :2] 129 | return boxes 130 | -------------------------------------------------------------------------------- /face_detection/detection/sfd/detect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import os 5 | import sys 6 | import cv2 7 | import random 8 | import datetime 9 | import math 10 | import argparse 11 | import numpy as np 12 | 13 | import scipy.io as sio 14 | import zipfile 15 | from .net_s3fd import s3fd 16 | from .bbox import * 17 | 18 | 19 | def detect(net, img, device): 20 | img = img - np.array([104, 117, 123]) 21 | img = img.transpose(2, 0, 1) 22 | img = img.reshape((1,) + img.shape) 23 | 24 | if 'cuda' in device: 25 | torch.backends.cudnn.benchmark = True 26 | 27 | img = torch.from_numpy(img).float().to(device) 28 | BB, CC, HH, WW = img.size() 29 | with torch.no_grad(): 30 | olist = net(img) 31 | 32 | bboxlist = [] 33 | for i in range(len(olist) // 2): 34 | olist[i * 2] = F.softmax(olist[i * 2], dim=1) 35 | olist = [oelem.data.cpu() for oelem in olist] 36 | for i in range(len(olist) // 2): 37 | ocls, oreg = olist[i * 2], olist[i * 2 + 1] 38 | FB, FC, FH, FW = ocls.size() # feature map size 39 | stride = 2**(i + 2) # 4,8,16,32,64,128 40 | anchor = stride * 4 41 | poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) 42 | for Iindex, hindex, windex in poss: 43 | axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride 44 | score = ocls[0, 1, hindex, windex] 45 | loc = oreg[0, :, hindex, windex].contiguous().view(1, 4) 46 | priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]) 47 | variances = [0.1, 0.2] 48 | box = decode(loc, priors, variances) 49 | x1, y1, x2, y2 = box[0] * 1.0 50 | # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1) 51 | bboxlist.append([x1, y1, x2, y2, score]) 52 | bboxlist = np.array(bboxlist) 53 | if 0 == len(bboxlist): 54 | bboxlist = np.zeros((1, 5)) 55 | 56 | return bboxlist 57 | 58 | def batch_detect(net, imgs, device): 59 | imgs = imgs - np.array([104, 117, 123]) 60 | imgs = imgs.transpose(0, 3, 1, 2) 61 | 62 | if 'cuda' in device: 63 | torch.backends.cudnn.benchmark = True 64 | 65 | imgs = torch.from_numpy(imgs).float().to(device) 66 | BB, CC, HH, WW = imgs.size() 67 | with torch.no_grad(): 68 | olist = net(imgs) 69 | 70 | bboxlist = [] 71 | for i in range(len(olist) // 2): 72 | olist[i * 2] = F.softmax(olist[i * 2], dim=1) 73 | olist = [oelem.data.cpu() for oelem in olist] 74 | for i in range(len(olist) // 2): 75 | ocls, oreg = olist[i * 2], olist[i * 2 + 1] 76 | FB, FC, FH, FW = ocls.size() # feature map size 77 | stride = 2**(i + 2) # 4,8,16,32,64,128 78 | anchor = stride * 4 79 | poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) 80 | for Iindex, hindex, windex in poss: 81 | axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride 82 | score = ocls[:, 1, hindex, windex] 83 | loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4) 84 | priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4) 85 | variances = [0.1, 0.2] 86 | box = batch_decode(loc, priors, variances) 87 
| box = box[:, 0] * 1.0 88 | # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1) 89 | bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy()) 90 | bboxlist = np.array(bboxlist) 91 | if 0 == len(bboxlist): 92 | bboxlist = np.zeros((1, BB, 5)) 93 | 94 | return bboxlist 95 | 96 | def flip_detect(net, img, device): 97 | img = cv2.flip(img, 1) 98 | b = detect(net, img, device) 99 | 100 | bboxlist = np.zeros(b.shape) 101 | bboxlist[:, 0] = img.shape[1] - b[:, 2] 102 | bboxlist[:, 1] = b[:, 1] 103 | bboxlist[:, 2] = img.shape[1] - b[:, 0] 104 | bboxlist[:, 3] = b[:, 3] 105 | bboxlist[:, 4] = b[:, 4] 106 | return bboxlist 107 | 108 | 109 | def pts_to_bb(pts): 110 | min_x, min_y = np.min(pts, axis=0) 111 | max_x, max_y = np.max(pts, axis=0) 112 | return np.array([min_x, min_y, max_x, max_y]) 113 | -------------------------------------------------------------------------------- /face_detection/detection/sfd/net_s3fd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class L2Norm(nn.Module): 7 | def __init__(self, n_channels, scale=1.0): 8 | super(L2Norm, self).__init__() 9 | self.n_channels = n_channels 10 | self.scale = scale 11 | self.eps = 1e-10 12 | self.weight = nn.Parameter(torch.Tensor(self.n_channels)) 13 | self.weight.data *= 0.0 14 | self.weight.data += self.scale 15 | 16 | def forward(self, x): 17 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps 18 | x = x / norm * self.weight.view(1, -1, 1, 1) 19 | return x 20 | 21 | 22 | class s3fd(nn.Module): 23 | def __init__(self): 24 | super(s3fd, self).__init__() 25 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 26 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 27 | 28 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) 29 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 30 | 31 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) 32 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 33 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 34 | 35 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) 36 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 37 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 38 | 39 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 40 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 41 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 42 | 43 | self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3) 44 | self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0) 45 | 46 | self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) 47 | self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) 48 | 49 | self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0) 50 | self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) 51 | 52 | self.conv3_3_norm = L2Norm(256, scale=10) 53 | self.conv4_3_norm = L2Norm(512, scale=8) 54 | self.conv5_3_norm = L2Norm(512, scale=5) 55 | 56 | self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 57 | self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 58 | self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, 
kernel_size=3, stride=1, padding=1) 59 | self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 60 | self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 61 | self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 62 | 63 | self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1) 64 | self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1) 65 | self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 66 | self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 67 | self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) 68 | self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 69 | 70 | def forward(self, x): 71 | h = F.relu(self.conv1_1(x)) 72 | h = F.relu(self.conv1_2(h)) 73 | h = F.max_pool2d(h, 2, 2) 74 | 75 | h = F.relu(self.conv2_1(h)) 76 | h = F.relu(self.conv2_2(h)) 77 | h = F.max_pool2d(h, 2, 2) 78 | 79 | h = F.relu(self.conv3_1(h)) 80 | h = F.relu(self.conv3_2(h)) 81 | h = F.relu(self.conv3_3(h)) 82 | f3_3 = h 83 | h = F.max_pool2d(h, 2, 2) 84 | 85 | h = F.relu(self.conv4_1(h)) 86 | h = F.relu(self.conv4_2(h)) 87 | h = F.relu(self.conv4_3(h)) 88 | f4_3 = h 89 | h = F.max_pool2d(h, 2, 2) 90 | 91 | h = F.relu(self.conv5_1(h)) 92 | h = F.relu(self.conv5_2(h)) 93 | h = F.relu(self.conv5_3(h)) 94 | f5_3 = h 95 | h = F.max_pool2d(h, 2, 2) 96 | 97 | h = F.relu(self.fc6(h)) 98 | h = F.relu(self.fc7(h)) 99 | ffc7 = h 100 | h = F.relu(self.conv6_1(h)) 101 | h = F.relu(self.conv6_2(h)) 102 | f6_2 = h 103 | h = F.relu(self.conv7_1(h)) 104 | h = F.relu(self.conv7_2(h)) 105 | f7_2 = h 106 | 107 | f3_3 = self.conv3_3_norm(f3_3) 108 | f4_3 = self.conv4_3_norm(f4_3) 109 | f5_3 = self.conv5_3_norm(f5_3) 110 | 111 | cls1 = self.conv3_3_norm_mbox_conf(f3_3) 112 | reg1 = self.conv3_3_norm_mbox_loc(f3_3) 113 | cls2 = self.conv4_3_norm_mbox_conf(f4_3) 114 | reg2 = self.conv4_3_norm_mbox_loc(f4_3) 115 | cls3 = self.conv5_3_norm_mbox_conf(f5_3) 116 | reg3 = self.conv5_3_norm_mbox_loc(f5_3) 117 | cls4 = self.fc7_mbox_conf(ffc7) 118 | reg4 = self.fc7_mbox_loc(ffc7) 119 | cls5 = self.conv6_2_mbox_conf(f6_2) 120 | reg5 = self.conv6_2_mbox_loc(f6_2) 121 | cls6 = self.conv7_2_mbox_conf(f7_2) 122 | reg6 = self.conv7_2_mbox_loc(f7_2) 123 | 124 | # max-out background label 125 | chunk = torch.chunk(cls1, 4, 1) 126 | bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2]) 127 | cls1 = torch.cat([bmax, chunk[3]], dim=1) 128 | 129 | return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6] 130 | -------------------------------------------------------------------------------- /face_detection/detection/sfd/sfd_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from torch.utils.model_zoo import load_url 4 | 5 | from ..core import FaceDetector 6 | 7 | from .net_s3fd import s3fd 8 | from .bbox import * 9 | from .detect import * 10 | 11 | models_urls = { 12 | 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth', 13 | } 14 | 15 | 16 | class SFDDetector(FaceDetector): 17 | def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False): 18 | super(SFDDetector, self).__init__(device, verbose) 19 | 20 | # Initialise the face detector 21 | if not os.path.isfile(path_to_detector): 22 | model_weights = 
load_url(models_urls['s3fd']) 23 | else: 24 | model_weights = torch.load(path_to_detector) 25 | 26 | self.face_detector = s3fd() 27 | self.face_detector.load_state_dict(model_weights) 28 | self.face_detector.to(device) 29 | self.face_detector.eval() 30 | 31 | def detect_from_image(self, tensor_or_path): 32 | image = self.tensor_or_path_to_ndarray(tensor_or_path) 33 | 34 | bboxlist = detect(self.face_detector, image, device=self.device) 35 | keep = nms(bboxlist, 0.3) 36 | bboxlist = bboxlist[keep, :] 37 | bboxlist = [x for x in bboxlist if x[-1] > 0.5] 38 | 39 | return bboxlist 40 | 41 | def detect_from_batch(self, images): 42 | bboxlists = batch_detect(self.face_detector, images, device=self.device) 43 | keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])] 44 | bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)] 45 | bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists] 46 | 47 | return bboxlists 48 | 49 | @property 50 | def reference_scale(self): 51 | return 195 52 | 53 | @property 54 | def reference_x_shift(self): 55 | return 0 56 | 57 | @property 58 | def reference_y_shift(self): 59 | return 0 60 | -------------------------------------------------------------------------------- /face_detection/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, 10 | stride=strd, padding=padding, bias=bias) 11 | 12 | 13 | class ConvBlock(nn.Module): 14 | def __init__(self, in_planes, out_planes): 15 | super(ConvBlock, self).__init__() 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv1 = conv3x3(in_planes, int(out_planes / 2)) 18 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) 19 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) 20 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) 21 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4)) 22 | 23 | if in_planes != out_planes: 24 | self.downsample = nn.Sequential( 25 | nn.BatchNorm2d(in_planes), 26 | nn.ReLU(True), 27 | nn.Conv2d(in_planes, out_planes, 28 | kernel_size=1, stride=1, bias=False), 29 | ) 30 | else: 31 | self.downsample = None 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out1 = self.bn1(x) 37 | out1 = F.relu(out1, True) 38 | out1 = self.conv1(out1) 39 | 40 | out2 = self.bn2(out1) 41 | out2 = F.relu(out2, True) 42 | out2 = self.conv2(out2) 43 | 44 | out3 = self.bn3(out2) 45 | out3 = F.relu(out3, True) 46 | out3 = self.conv3(out3) 47 | 48 | out3 = torch.cat((out1, out2, out3), 1) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(residual) 52 | 53 | out3 += residual 54 | 55 | return out3 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 67 | padding=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(planes) 69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 70 | self.bn3 = nn.BatchNorm2d(planes * 4) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | 
self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class HourGlass(nn.Module): 99 | def __init__(self, num_modules, depth, num_features): 100 | super(HourGlass, self).__init__() 101 | self.num_modules = num_modules 102 | self.depth = depth 103 | self.features = num_features 104 | 105 | self._generate_network(self.depth) 106 | 107 | def _generate_network(self, level): 108 | self.add_module('b1_' + str(level), ConvBlock(self.features, self.features)) 109 | 110 | self.add_module('b2_' + str(level), ConvBlock(self.features, self.features)) 111 | 112 | if level > 1: 113 | self._generate_network(level - 1) 114 | else: 115 | self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features)) 116 | 117 | self.add_module('b3_' + str(level), ConvBlock(self.features, self.features)) 118 | 119 | def _forward(self, level, inp): 120 | # Upper branch 121 | up1 = inp 122 | up1 = self._modules['b1_' + str(level)](up1) 123 | 124 | # Lower branch 125 | low1 = F.avg_pool2d(inp, 2, stride=2) 126 | low1 = self._modules['b2_' + str(level)](low1) 127 | 128 | if level > 1: 129 | low2 = self._forward(level - 1, low1) 130 | else: 131 | low2 = low1 132 | low2 = self._modules['b2_plus_' + str(level)](low2) 133 | 134 | low3 = low2 135 | low3 = self._modules['b3_' + str(level)](low3) 136 | 137 | up2 = F.interpolate(low3, scale_factor=2, mode='nearest') 138 | 139 | return up1 + up2 140 | 141 | def forward(self, x): 142 | return self._forward(self.depth, x) 143 | 144 | 145 | class FAN(nn.Module): 146 | 147 | def __init__(self, num_modules=1): 148 | super(FAN, self).__init__() 149 | self.num_modules = num_modules 150 | 151 | # Base part 152 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 153 | self.bn1 = nn.BatchNorm2d(64) 154 | self.conv2 = ConvBlock(64, 128) 155 | self.conv3 = ConvBlock(128, 128) 156 | self.conv4 = ConvBlock(128, 256) 157 | 158 | # Stacking part 159 | for hg_module in range(self.num_modules): 160 | self.add_module('m' + str(hg_module), HourGlass(1, 4, 256)) 161 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) 162 | self.add_module('conv_last' + str(hg_module), 163 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 164 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) 165 | self.add_module('l' + str(hg_module), nn.Conv2d(256, 166 | 68, kernel_size=1, stride=1, padding=0)) 167 | 168 | if hg_module < self.num_modules - 1: 169 | self.add_module( 170 | 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 171 | self.add_module('al' + str(hg_module), nn.Conv2d(68, 172 | 256, kernel_size=1, stride=1, padding=0)) 173 | 174 | def forward(self, x): 175 | x = F.relu(self.bn1(self.conv1(x)), True) 176 | x = F.avg_pool2d(self.conv2(x), 2, stride=2) 177 | x = self.conv3(x) 178 | x = self.conv4(x) 179 | 180 | previous = x 181 | 182 | outputs = [] 183 | for i in range(self.num_modules): 184 | hg = self._modules['m' + str(i)](previous) 185 | 186 | ll = hg 187 | ll = self._modules['top_m_' + str(i)](ll) 188 | 189 | ll = F.relu(self._modules['bn_end' + str(i)] 190 | (self._modules['conv_last' + str(i)](ll)), True) 191 | 192 | 
# Predict heatmaps 193 | tmp_out = self._modules['l' + str(i)](ll) 194 | outputs.append(tmp_out) 195 | 196 | if i < self.num_modules - 1: 197 | ll = self._modules['bl' + str(i)](ll) 198 | tmp_out_ = self._modules['al' + str(i)](tmp_out) 199 | previous = previous + ll + tmp_out_ 200 | 201 | return outputs 202 | 203 | 204 | class ResNetDepth(nn.Module): 205 | 206 | def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68): 207 | self.inplanes = 64 208 | super(ResNetDepth, self).__init__() 209 | self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3, 210 | bias=False) 211 | self.bn1 = nn.BatchNorm2d(64) 212 | self.relu = nn.ReLU(inplace=True) 213 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 214 | self.layer1 = self._make_layer(block, 64, layers[0]) 215 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 216 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 217 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 218 | self.avgpool = nn.AvgPool2d(7) 219 | self.fc = nn.Linear(512 * block.expansion, num_classes) 220 | 221 | for m in self.modules(): 222 | if isinstance(m, nn.Conv2d): 223 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 224 | m.weight.data.normal_(0, math.sqrt(2. / n)) 225 | elif isinstance(m, nn.BatchNorm2d): 226 | m.weight.data.fill_(1) 227 | m.bias.data.zero_() 228 | 229 | def _make_layer(self, block, planes, blocks, stride=1): 230 | downsample = None 231 | if stride != 1 or self.inplanes != planes * block.expansion: 232 | downsample = nn.Sequential( 233 | nn.Conv2d(self.inplanes, planes * block.expansion, 234 | kernel_size=1, stride=stride, bias=False), 235 | nn.BatchNorm2d(planes * block.expansion), 236 | ) 237 | 238 | layers = [] 239 | layers.append(block(self.inplanes, planes, stride, downsample)) 240 | self.inplanes = planes * block.expansion 241 | for i in range(1, blocks): 242 | layers.append(block(self.inplanes, planes)) 243 | 244 | return nn.Sequential(*layers) 245 | 246 | def forward(self, x): 247 | x = self.conv1(x) 248 | x = self.bn1(x) 249 | x = self.relu(x) 250 | x = self.maxpool(x) 251 | 252 | x = self.layer1(x) 253 | x = self.layer2(x) 254 | x = self.layer3(x) 255 | x = self.layer4(x) 256 | 257 | x = self.avgpool(x) 258 | x = x.view(x.size(0), -1) 259 | x = self.fc(x) 260 | 261 | return x 262 | -------------------------------------------------------------------------------- /face_detection/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import time 5 | import torch 6 | import math 7 | import numpy as np 8 | import cv2 9 | 10 | 11 | def _gaussian( 12 | size=3, sigma=0.25, amplitude=1, normalize=False, width=None, 13 | height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5, 14 | mean_vert=0.5): 15 | # handle some defaults 16 | if width is None: 17 | width = size 18 | if height is None: 19 | height = size 20 | if sigma_horz is None: 21 | sigma_horz = sigma 22 | if sigma_vert is None: 23 | sigma_vert = sigma 24 | center_x = mean_horz * width + 0.5 25 | center_y = mean_vert * height + 0.5 26 | gauss = np.empty((height, width), dtype=np.float32) 27 | # generate kernel 28 | for i in range(height): 29 | for j in range(width): 30 | gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / ( 31 | sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0)) 32 | if normalize: 33 | gauss = gauss / 
np.sum(gauss) 34 | return gauss 35 | 36 | 37 | def draw_gaussian(image, point, sigma): 38 | # Check if the gaussian is inside 39 | ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)] 40 | br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)] 41 | if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1): 42 | return image 43 | size = 6 * sigma + 1 44 | g = _gaussian(size) 45 | g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))] 46 | g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))] 47 | img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))] 48 | img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))] 49 | assert (g_x[0] > 0 and g_y[1] > 0) 50 | image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1] 51 | ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]] 52 | image[image > 1] = 1 53 | return image 54 | 55 | 56 | def transform(point, center, scale, resolution, invert=False): 57 | """Generate and affine transformation matrix. 58 | 59 | Given a set of points, a center, a scale and a targer resolution, the 60 | function generates and affine transformation matrix. If invert is ``True`` 61 | it will produce the inverse transformation. 62 | 63 | Arguments: 64 | point {torch.tensor} -- the input 2D point 65 | center {torch.tensor or numpy.array} -- the center around which to perform the transformations 66 | scale {float} -- the scale of the face/object 67 | resolution {float} -- the output resolution 68 | 69 | Keyword Arguments: 70 | invert {bool} -- define wherever the function should produce the direct or the 71 | inverse transformation matrix (default: {False}) 72 | """ 73 | _pt = torch.ones(3) 74 | _pt[0] = point[0] 75 | _pt[1] = point[1] 76 | 77 | h = 200.0 * scale 78 | t = torch.eye(3) 79 | t[0, 0] = resolution / h 80 | t[1, 1] = resolution / h 81 | t[0, 2] = resolution * (-center[0] / h + 0.5) 82 | t[1, 2] = resolution * (-center[1] / h + 0.5) 83 | 84 | if invert: 85 | t = torch.inverse(t) 86 | 87 | new_point = (torch.matmul(t, _pt))[0:2] 88 | 89 | return new_point.int() 90 | 91 | 92 | def crop(image, center, scale, resolution=256.0): 93 | """Center crops an image or set of heatmaps 94 | 95 | Arguments: 96 | image {numpy.array} -- an rgb image 97 | center {numpy.array} -- the center of the object, usually the same as of the bounding box 98 | scale {float} -- scale of the face 99 | 100 | Keyword Arguments: 101 | resolution {float} -- the size of the output cropped image (default: {256.0}) 102 | 103 | Returns: 104 | [type] -- [description] 105 | """ # Crop around the center point 106 | """ Crops the image around the center. 
Input is expected to be an np.ndarray """ 107 | ul = transform([1, 1], center, scale, resolution, True) 108 | br = transform([resolution, resolution], center, scale, resolution, True) 109 | # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0) 110 | if image.ndim > 2: 111 | newDim = np.array([br[1] - ul[1], br[0] - ul[0], 112 | image.shape[2]], dtype=np.int32) 113 | newImg = np.zeros(newDim, dtype=np.uint8) 114 | else: 115 | newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int) 116 | newImg = np.zeros(newDim, dtype=np.uint8) 117 | ht = image.shape[0] 118 | wd = image.shape[1] 119 | newX = np.array( 120 | [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32) 121 | newY = np.array( 122 | [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32) 123 | oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32) 124 | oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32) 125 | newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1] 126 | ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :] 127 | newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)), 128 | interpolation=cv2.INTER_LINEAR) 129 | return newImg 130 | 131 | 132 | def get_preds_fromhm(hm, center=None, scale=None): 133 | """Obtain (x,y) coordinates given a set of N heatmaps. If the center 134 | and the scale is provided the function will return the points also in 135 | the original coordinate frame. 136 | 137 | Arguments: 138 | hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H] 139 | 140 | Keyword Arguments: 141 | center {torch.tensor} -- the center of the bounding box (default: {None}) 142 | scale {float} -- face scale (default: {None}) 143 | """ 144 | max, idx = torch.max( 145 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) 146 | idx += 1 147 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() 148 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) 149 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) 150 | 151 | for i in range(preds.size(0)): 152 | for j in range(preds.size(1)): 153 | hm_ = hm[i, j, :] 154 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 155 | if pX > 0 and pX < 63 and pY > 0 and pY < 63: 156 | diff = torch.FloatTensor( 157 | [hm_[pY, pX + 1] - hm_[pY, pX - 1], 158 | hm_[pY + 1, pX] - hm_[pY - 1, pX]]) 159 | preds[i, j].add_(diff.sign_().mul_(.25)) 160 | 161 | preds.add_(-.5) 162 | 163 | preds_orig = torch.zeros(preds.size()) 164 | if center is not None and scale is not None: 165 | for i in range(hm.size(0)): 166 | for j in range(hm.size(1)): 167 | preds_orig[i, j] = transform( 168 | preds[i, j], center, scale, hm.size(2), True) 169 | 170 | return preds, preds_orig 171 | 172 | def get_preds_fromhm_batch(hm, centers=None, scales=None): 173 | """Obtain (x,y) coordinates given a set of N heatmaps. If the centers 174 | and the scales is provided the function will return the points also in 175 | the original coordinate frame. 
176 | 177 | Arguments: 178 | hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H] 179 | 180 | Keyword Arguments: 181 | centers {torch.tensor} -- the centers of the bounding box (default: {None}) 182 | scales {float} -- face scales (default: {None}) 183 | """ 184 | max, idx = torch.max( 185 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) 186 | idx += 1 187 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() 188 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) 189 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) 190 | 191 | for i in range(preds.size(0)): 192 | for j in range(preds.size(1)): 193 | hm_ = hm[i, j, :] 194 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 195 | if pX > 0 and pX < 63 and pY > 0 and pY < 63: 196 | diff = torch.FloatTensor( 197 | [hm_[pY, pX + 1] - hm_[pY, pX - 1], 198 | hm_[pY + 1, pX] - hm_[pY - 1, pX]]) 199 | preds[i, j].add_(diff.sign_().mul_(.25)) 200 | 201 | preds.add_(-.5) 202 | 203 | preds_orig = torch.zeros(preds.size()) 204 | if centers is not None and scales is not None: 205 | for i in range(hm.size(0)): 206 | for j in range(hm.size(1)): 207 | preds_orig[i, j] = transform( 208 | preds[i, j], centers[i], scales[i], hm.size(2), True) 209 | 210 | return preds, preds_orig 211 | 212 | def shuffle_lr(parts, pairs=None): 213 | """Shuffle the points left-right according to the axis of symmetry 214 | of the object. 215 | 216 | Arguments: 217 | parts {torch.tensor} -- a 3D or 4D object containing the 218 | heatmaps. 219 | 220 | Keyword Arguments: 221 | pairs {list of integers} -- [order of the flipped points] (default: {None}) 222 | """ 223 | if pairs is None: 224 | pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 225 | 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35, 226 | 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41, 227 | 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63, 228 | 62, 61, 60, 67, 66, 65] 229 | if parts.ndimension() == 3: 230 | parts = parts[pairs, ...] 231 | else: 232 | parts = parts[:, pairs, ...] 233 | 234 | return parts 235 | 236 | 237 | def flip(tensor, is_label=False): 238 | """Flip an image or a set of heatmaps left-right 239 | 240 | Arguments: 241 | tensor {numpy.array or torch.tensor} -- [the input image or heatmaps] 242 | 243 | Keyword Arguments: 244 | is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False}) 245 | """ 246 | if not torch.is_tensor(tensor): 247 | tensor = torch.from_numpy(tensor) 248 | 249 | if is_label: 250 | tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1) 251 | else: 252 | tensor = tensor.flip(tensor.ndimension() - 1) 253 | 254 | return tensor 255 | 256 | # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py) 257 | 258 | 259 | def appdata_dir(appname=None, roaming=False): 260 | """ appdata_dir(appname=None, roaming=False) 261 | 262 | Get the path to the application directory, where applications are allowed 263 | to write user specific files (e.g. configurations). For non-user specific 264 | data, consider using common_appdata_dir(). 265 | If appname is given, a subdir is appended (and created if necessary). 266 | If roaming is True, will prefer a roaming directory (Windows Vista/7). 
267 | """ 268 | 269 | # Define default user directory 270 | userDir = os.getenv('FACEALIGNMENT_USERDIR', None) 271 | if userDir is None: 272 | userDir = os.path.expanduser('~') 273 | if not os.path.isdir(userDir): # pragma: no cover 274 | userDir = '/var/tmp' # issue #54 275 | 276 | # Get system app data dir 277 | path = None 278 | if sys.platform.startswith('win'): 279 | path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA') 280 | path = (path2 or path1) if roaming else (path1 or path2) 281 | elif sys.platform.startswith('darwin'): 282 | path = os.path.join(userDir, 'Library', 'Application Support') 283 | # On Linux and as fallback 284 | if not (path and os.path.isdir(path)): 285 | path = userDir 286 | 287 | # Maybe we should store things local to the executable (in case of a 288 | # portable distro or a frozen application that wants to be portable) 289 | prefix = sys.prefix 290 | if getattr(sys, 'frozen', None): 291 | prefix = os.path.abspath(os.path.dirname(sys.executable)) 292 | for reldir in ('settings', '../settings'): 293 | localpath = os.path.abspath(os.path.join(prefix, reldir)) 294 | if os.path.isdir(localpath): # pragma: no cover 295 | try: 296 | open(os.path.join(localpath, 'test.write'), 'wb').close() 297 | os.remove(os.path.join(localpath, 'test.write')) 298 | except IOError: 299 | pass # We cannot write in this directory 300 | else: 301 | path = localpath 302 | break 303 | 304 | # Get path specific for this app 305 | if appname: 306 | if path == userDir: 307 | appname = '.' + appname.lstrip('.') # Make it a hidden directory 308 | path = os.path.join(path, appname) 309 | if not os.path.isdir(path): # pragma: no cover 310 | os.mkdir(path) 311 | 312 | # Done 313 | return path 314 | -------------------------------------------------------------------------------- /filelists/README.md: -------------------------------------------------------------------------------- 1 | Place LRS2 (and any other) filelists here for training. 
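
Each filelist (`train.txt`, `val.txt`) is expected to contain one relative clip directory per line; `get_image_list` in `hparams.py` joins each line onto the `--data_root` passed to the training scripts (any extra space-separated fields on a line are ignored), and `get_filelist.py` writes the lists in exactly this `speaker_dir/clip_dir` form. Hypothetical example entries:

```
6330311066473698535/00011
6330311066473698535/00035
5535415699068794046/00017
```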
-------------------------------------------------------------------------------- /get_filelist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from tqdm import tqdm 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--root_dir", help="Root folder of the dataset", 8 | default='data/min_pre/') 9 | args = parser.parse_args() 10 | 11 | root_dir = args.root_dir 12 | 13 | lines_train = [] 14 | lines_val = [] 15 | dir_list = os.listdir(root_dir) 16 | for i, dir in tqdm(enumerate(dir_list), total=len(dir_list)): 17 | if dir.startswith('.'): 18 | continue 19 | 20 | sub_dir_list = os.listdir(os.path.join(root_dir, dir)) 21 | 22 | for sub_dir in sub_dir_list: 23 | if sub_dir.startswith('.'): 24 | continue 25 | 26 | line = dir + '/' + sub_dir 27 | wav_path = os.path.join(root_dir, line, 'audio.wav') 28 | if not os.path.exists(wav_path): 29 | continue 30 | 31 | if i % 10 == 0: 32 | lines_val.append(line) 33 | else: 34 | lines_train.append(line) 35 | 36 | print(len(lines_train)) 37 | print(len(lines_val)) 38 | 39 | with open('filelists/train.txt', 'w') as f: 40 | for line in lines_train: 41 | f.writelines(line + '\n') 42 | 43 | with open('filelists/val.txt', 'w') as f: 44 | for line in lines_val: 45 | f.writelines(line + '\n') 46 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import os 3 | 4 | def get_image_list(data_root, split): 5 | filelist = [] 6 | 7 | with open('filelists/{}.txt'.format(split)) as f: 8 | for line in f: 9 | line = line.strip() 10 | if ' ' in line: line = line.split()[0] 11 | filelist.append(os.path.join(data_root, line)) 12 | 13 | return filelist 14 | 15 | class HParams: 16 | def __init__(self, **kwargs): 17 | self.data = {} 18 | 19 | for key, value in kwargs.items(): 20 | self.data[key] = value 21 | 22 | def __getattr__(self, key): 23 | if key not in self.data: 24 | raise AttributeError("'HParams' object has no attribute %s" % key) 25 | return self.data[key] 26 | 27 | def set_hparam(self, key, value): 28 | self.data[key] = value 29 | 30 | 31 | # Default hyperparameters 32 | hparams = HParams( 33 | num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality 34 | # network 35 | rescale=True, # Whether to rescale audio prior to preprocessing 36 | rescaling_max=0.9, # Rescaling value 37 | 38 | # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction 39 | # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder 40 | # Does not work if n_ffit is not multiple of hop_size!! 41 | use_lws=False, 42 | 43 | n_fft=800, # Extra window size is filled with 0 paddings to match this parameter 44 | hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) 45 | win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) 46 | sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) 47 | 48 | frame_shift_ms=None, # Can replace hop_size parameter. 
(Recommended: 12.5) 49 | 50 | # Mel and Linear spectrograms normalization/scaling and clipping 51 | signal_normalization=True, 52 | # Whether to normalize mel spectrograms to some predefined range (following below parameters) 53 | allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True 54 | symmetric_mels=True, 55 | # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, 56 | # faster and cleaner convergence) 57 | max_abs_value=4., 58 | # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not 59 | # be too big to avoid gradient explosion, 60 | # not too small for fast convergence) 61 | # Contribution by @begeekmyfriend 62 | # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude 63 | # levels. Also allows for better G&L phase reconstruction) 64 | preemphasize=True, # whether to apply filter 65 | preemphasis=0.97, # filter coefficient. 66 | 67 | # Limits 68 | min_level_db=-100, 69 | ref_level_db=20, 70 | fmin=55, 71 | # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To 72 | # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) 73 | fmax=7600, # To be increased/reduced depending on data. 74 | 75 | ###################### Our training parameters ################################# 76 | img_size=96, 77 | fps=25, 78 | 79 | batch_size=16, 80 | initial_learning_rate=1e-4, 81 | nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs 82 | num_workers=16, 83 | checkpoint_interval=3000, 84 | eval_interval=3000, 85 | save_optimizer_state=True, 86 | 87 | syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. 
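# (hq_wav2lip_train.py raises this to 0.03 once the average eval sync loss drops below 0.75)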
88 | syncnet_batch_size=64, 89 | syncnet_lr=1e-4, 90 | syncnet_eval_interval=10000, 91 | syncnet_checkpoint_interval=10000, 92 | 93 | disc_wt=0.07, 94 | disc_initial_learning_rate=1e-4, 95 | ) 96 | 97 | 98 | def hparams_debug_string(): 99 | values = hparams.values() 100 | hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"] 101 | return "Hyperparameters:\n" + "\n".join(hp) 102 | -------------------------------------------------------------------------------- /hq_wav2lip_train.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, join, basename, isfile 2 | from tqdm import tqdm 3 | 4 | from models import SyncNet_color as SyncNet 5 | from models import Wav2Lip, Wav2Lip_disc_qual 6 | import audio 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from torch import optim 12 | import torch.backends.cudnn as cudnn 13 | from torch.utils import data as data_utils 14 | import numpy as np 15 | 16 | from glob import glob 17 | 18 | import os, random, cv2, argparse 19 | from hparams import hparams, get_image_list 20 | 21 | parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model WITH the visual quality discriminator') 22 | 23 | parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str) 24 | 25 | parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str) 26 | parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str) 27 | 28 | parser.add_argument('--checkpoint_path', help='Resume generator from this checkpoint', default=None, type=str) 29 | parser.add_argument('--disc_checkpoint_path', help='Resume quality disc from this checkpoint', default=None, type=str) 30 | 31 | args = parser.parse_args() 32 | 33 | 34 | global_step = 0 35 | global_epoch = 0 36 | use_cuda = torch.cuda.is_available() 37 | print('use_cuda: {}'.format(use_cuda)) 38 | 39 | syncnet_T = 5 40 | syncnet_mel_step_size = 16 41 | 42 | class Dataset(object): 43 | def __init__(self, split): 44 | self.all_videos = get_image_list(args.data_root, split) 45 | 46 | def get_frame_id(self, frame): 47 | return int(basename(frame).split('.')[0]) 48 | 49 | def get_window(self, start_frame): 50 | start_id = self.get_frame_id(start_frame) 51 | vidname = dirname(start_frame) 52 | 53 | window_fnames = [] 54 | for frame_id in range(start_id, start_id + syncnet_T): 55 | frame = join(vidname, '{}.jpg'.format(frame_id)) 56 | if not isfile(frame): 57 | return None 58 | window_fnames.append(frame) 59 | return window_fnames 60 | 61 | def read_window(self, window_fnames): 62 | if window_fnames is None: return None 63 | window = [] 64 | for fname in window_fnames: 65 | img = cv2.imread(fname) 66 | if img is None: 67 | return None 68 | try: 69 | img = cv2.resize(img, (hparams.img_size, hparams.img_size)) 70 | except Exception as e: 71 | return None 72 | 73 | window.append(img) 74 | 75 | return window 76 | 77 | def crop_audio_window(self, spec, start_frame): 78 | if type(start_frame) == int: 79 | start_frame_num = start_frame 80 | else: 81 | start_frame_num = self.get_frame_id(start_frame) 82 | start_idx = int(80. 
* (start_frame_num / float(hparams.fps))) 83 | 84 | end_idx = start_idx + syncnet_mel_step_size 85 | 86 | return spec[start_idx : end_idx, :] 87 | 88 | def get_segmented_mels(self, spec, start_frame): 89 | mels = [] 90 | assert syncnet_T == 5 91 | start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing 92 | if start_frame_num - 2 < 0: return None 93 | for i in range(start_frame_num, start_frame_num + syncnet_T): 94 | m = self.crop_audio_window(spec, i - 2) 95 | if m.shape[0] != syncnet_mel_step_size: 96 | return None 97 | mels.append(m.T) 98 | 99 | mels = np.asarray(mels) 100 | 101 | return mels 102 | 103 | def prepare_window(self, window): 104 | # 3 x T x H x W 105 | x = np.asarray(window) / 255. 106 | x = np.transpose(x, (3, 0, 1, 2)) 107 | 108 | return x 109 | 110 | def __len__(self): 111 | return len(self.all_videos) 112 | 113 | def __getitem__(self, idx): 114 | while 1: 115 | idx = random.randint(0, len(self.all_videos) - 1) 116 | vidname = self.all_videos[idx] 117 | img_names = list(glob(join(vidname, '*.jpg'))) 118 | if len(img_names) <= 3 * syncnet_T: 119 | continue 120 | 121 | img_name = random.choice(img_names) 122 | wrong_img_name = random.choice(img_names) 123 | while wrong_img_name == img_name: 124 | wrong_img_name = random.choice(img_names) 125 | 126 | window_fnames = self.get_window(img_name) 127 | wrong_window_fnames = self.get_window(wrong_img_name) 128 | if window_fnames is None or wrong_window_fnames is None: 129 | continue 130 | 131 | window = self.read_window(window_fnames) 132 | if window is None: 133 | continue 134 | 135 | wrong_window = self.read_window(wrong_window_fnames) 136 | if wrong_window is None: 137 | continue 138 | 139 | try: 140 | wavpath = join(vidname, "audio.wav") 141 | wav = audio.load_wav(wavpath, hparams.sample_rate) 142 | 143 | orig_mel = audio.melspectrogram(wav).T 144 | except Exception as e: 145 | continue 146 | 147 | mel = self.crop_audio_window(orig_mel.copy(), img_name) 148 | 149 | if (mel.shape[0] != syncnet_mel_step_size): 150 | continue 151 | 152 | indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name) 153 | if indiv_mels is None: continue 154 | 155 | window = self.prepare_window(window) 156 | y = window.copy() 157 | window[:, :, window.shape[2]//2:] = 0. 
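# The lower half of every target frame is zeroed above, so the generator never sees
# the ground-truth mouth; it has to synthesize it from the audio, using the randomly
# chosen 'wrong' window (concatenated next) only as a pose/identity reference.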
158 | 159 | wrong_window = self.prepare_window(wrong_window) 160 | x = np.concatenate([window, wrong_window], axis=0) 161 | 162 | x = torch.FloatTensor(x) 163 | mel = torch.FloatTensor(mel.T).unsqueeze(0) 164 | indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1) 165 | y = torch.FloatTensor(y) 166 | return x, indiv_mels, mel, y 167 | 168 | def save_sample_images(x, g, gt, global_step, checkpoint_dir): 169 | x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8) 170 | g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8) 171 | gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8) 172 | 173 | refs, inps = x[..., 3:], x[..., :3] 174 | folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step)) 175 | if not os.path.exists(folder): os.mkdir(folder) 176 | collage = np.concatenate((refs, inps, g, gt), axis=-2) 177 | for batch_idx, c in enumerate(collage): 178 | for t in range(len(c)): 179 | cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t]) 180 | 181 | logloss = nn.BCELoss() 182 | def cosine_loss(a, v, y): 183 | d = nn.functional.cosine_similarity(a, v) 184 | loss = logloss(d.unsqueeze(1), y) 185 | 186 | return loss 187 | 188 | device = torch.device("cuda" if use_cuda else "cpu") 189 | syncnet = SyncNet().to(device) 190 | for p in syncnet.parameters(): 191 | p.requires_grad = False 192 | 193 | recon_loss = nn.L1Loss() 194 | def get_sync_loss(mel, g): 195 | g = g[:, :, :, g.size(3)//2:] 196 | g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1) 197 | # B, 3 * T, H//2, W 198 | a, v = syncnet(mel, g) 199 | y = torch.ones(g.size(0), 1).float().to(device) 200 | return cosine_loss(a, v, y) 201 | 202 | def train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer, 203 | checkpoint_dir=None, checkpoint_interval=None, nepochs=None): 204 | global global_step, global_epoch 205 | resumed_step = global_step 206 | 207 | while global_epoch < nepochs: 208 | print('Starting Epoch: {}'.format(global_epoch)) 209 | running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = 0., 0., 0., 0. 210 | running_disc_real_loss, running_disc_fake_loss = 0., 0. 211 | prog_bar = tqdm(enumerate(train_data_loader)) 212 | for step, (x, indiv_mels, mel, gt) in prog_bar: 213 | disc.train() 214 | model.train() 215 | 216 | x = x.to(device) 217 | mel = mel.to(device) 218 | indiv_mels = indiv_mels.to(device) 219 | gt = gt.to(device) 220 | 221 | ### Train generator now. Remove ALL grads. 222 | optimizer.zero_grad() 223 | disc_optimizer.zero_grad() 224 | 225 | g = model(indiv_mels, x) 226 | 227 | if hparams.syncnet_wt > 0.: 228 | sync_loss = get_sync_loss(mel, g) 229 | else: 230 | sync_loss = 0. 231 | 232 | if hparams.disc_wt > 0.: 233 | perceptual_loss = disc.perceptual_forward(g) 234 | else: 235 | perceptual_loss = 0. 236 | 237 | l1loss = recon_loss(g, gt) 238 | 239 | loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \ 240 | (1. 
- hparams.syncnet_wt - hparams.disc_wt) * l1loss 241 | 242 | loss.backward() 243 | optimizer.step() 244 | 245 | ### Remove all gradients before Training disc 246 | disc_optimizer.zero_grad() 247 | 248 | pred = disc(gt) 249 | disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device)) 250 | disc_real_loss.backward() 251 | 252 | pred = disc(g.detach()) 253 | disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device)) 254 | disc_fake_loss.backward() 255 | 256 | disc_optimizer.step() 257 | 258 | running_disc_real_loss += disc_real_loss.item() 259 | running_disc_fake_loss += disc_fake_loss.item() 260 | 261 | if global_step % checkpoint_interval == 0: 262 | save_sample_images(x, g, gt, global_step, checkpoint_dir) 263 | 264 | # Logs 265 | global_step += 1 266 | cur_session_steps = global_step - resumed_step 267 | 268 | running_l1_loss += l1loss.item() 269 | if hparams.syncnet_wt > 0.: 270 | running_sync_loss += sync_loss.item() 271 | else: 272 | running_sync_loss += 0. 273 | 274 | if hparams.disc_wt > 0.: 275 | running_perceptual_loss += perceptual_loss.item() 276 | else: 277 | running_perceptual_loss += 0. 278 | 279 | if global_step == 1 or global_step % checkpoint_interval == 0: 280 | save_checkpoint( 281 | model, optimizer, global_step, checkpoint_dir, global_epoch) 282 | save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir, global_epoch, prefix='disc_') 283 | 284 | 285 | if global_step % hparams.eval_interval == 0: 286 | with torch.no_grad(): 287 | average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc) 288 | 289 | if average_sync_loss < .75: 290 | hparams.set_hparam('syncnet_wt', 0.03) 291 | 292 | prog_bar.set_description('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(running_l1_loss / (step + 1), 293 | running_sync_loss / (step + 1), 294 | running_perceptual_loss / (step + 1), 295 | running_disc_fake_loss / (step + 1), 296 | running_disc_real_loss / (step + 1))) 297 | 298 | global_epoch += 1 299 | 300 | def eval_model(test_data_loader, global_step, device, model, disc): 301 | eval_steps = 300 302 | print('Evaluating for {} steps'.format(eval_steps)) 303 | running_sync_loss, running_l1_loss, running_disc_real_loss, running_disc_fake_loss, running_perceptual_loss = [], [], [], [], [] 304 | while 1: 305 | for step, (x, indiv_mels, mel, gt) in enumerate((test_data_loader)): 306 | model.eval() 307 | disc.eval() 308 | 309 | x = x.to(device) 310 | mel = mel.to(device) 311 | indiv_mels = indiv_mels.to(device) 312 | gt = gt.to(device) 313 | 314 | pred = disc(gt) 315 | disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device)) 316 | 317 | g = model(indiv_mels, x) 318 | pred = disc(g) 319 | disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device)) 320 | 321 | running_disc_real_loss.append(disc_real_loss.item()) 322 | running_disc_fake_loss.append(disc_fake_loss.item()) 323 | 324 | sync_loss = get_sync_loss(mel, g) 325 | 326 | if hparams.disc_wt > 0.: 327 | perceptual_loss = disc.perceptual_forward(g) 328 | else: 329 | perceptual_loss = 0. 330 | 331 | l1loss = recon_loss(g, gt) 332 | 333 | loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \ 334 | (1. 
- hparams.syncnet_wt - hparams.disc_wt) * l1loss 335 | 336 | running_l1_loss.append(l1loss.item()) 337 | running_sync_loss.append(sync_loss.item()) 338 | 339 | if hparams.disc_wt > 0.: 340 | running_perceptual_loss.append(perceptual_loss.item()) 341 | else: 342 | running_perceptual_loss.append(0.) 343 | 344 | if step > eval_steps: break 345 | 346 | print('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(sum(running_l1_loss) / len(running_l1_loss), 347 | sum(running_sync_loss) / len(running_sync_loss), 348 | sum(running_perceptual_loss) / len(running_perceptual_loss), 349 | sum(running_disc_fake_loss) / len(running_disc_fake_loss), 350 | sum(running_disc_real_loss) / len(running_disc_real_loss))) 351 | return sum(running_sync_loss) / len(running_sync_loss) 352 | 353 | 354 | def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, prefix=''): 355 | checkpoint_path = join( 356 | checkpoint_dir, "{}checkpoint_step{:09d}.pth".format(prefix, global_step)) 357 | optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None 358 | torch.save({ 359 | "state_dict": model.state_dict(), 360 | "optimizer": optimizer_state, 361 | "global_step": step, 362 | "global_epoch": epoch, 363 | }, checkpoint_path) 364 | print("Saved checkpoint:", checkpoint_path) 365 | 366 | def _load(checkpoint_path): 367 | if use_cuda: 368 | checkpoint = torch.load(checkpoint_path) 369 | else: 370 | checkpoint = torch.load(checkpoint_path, 371 | map_location=lambda storage, loc: storage) 372 | return checkpoint 373 | 374 | 375 | def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True): 376 | global global_step 377 | global global_epoch 378 | 379 | print("Load checkpoint from: {}".format(path)) 380 | checkpoint = _load(path) 381 | s = checkpoint["state_dict"] 382 | new_s = {} 383 | for k, v in s.items(): 384 | new_s[k.replace('module.', '')] = v 385 | model.load_state_dict(new_s) 386 | if not reset_optimizer: 387 | optimizer_state = checkpoint["optimizer"] 388 | if optimizer_state is not None: 389 | print("Load optimizer state from {}".format(path)) 390 | optimizer.load_state_dict(checkpoint["optimizer"]) 391 | if overwrite_global_states: 392 | global_step = checkpoint["global_step"] 393 | global_epoch = checkpoint["global_epoch"] 394 | 395 | return model 396 | 397 | if __name__ == "__main__": 398 | checkpoint_dir = args.checkpoint_dir 399 | 400 | # Dataset and Dataloader setup 401 | train_dataset = Dataset('train') 402 | test_dataset = Dataset('val') 403 | 404 | train_data_loader = data_utils.DataLoader( 405 | train_dataset, batch_size=hparams.batch_size, shuffle=True, 406 | num_workers=hparams.num_workers) 407 | 408 | test_data_loader = data_utils.DataLoader( 409 | test_dataset, batch_size=hparams.batch_size, 410 | num_workers=4) 411 | 412 | device = torch.device("cuda" if use_cuda else "cpu") 413 | 414 | # Model 415 | model = Wav2Lip().to(device) 416 | disc = Wav2Lip_disc_qual().to(device) 417 | 418 | print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))) 419 | print('total DISC trainable params {}'.format(sum(p.numel() for p in disc.parameters() if p.requires_grad))) 420 | 421 | optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], 422 | lr=hparams.initial_learning_rate, betas=(0.5, 0.999)) 423 | disc_optimizer = optim.Adam([p for p in disc.parameters() if p.requires_grad], 424 | lr=hparams.disc_initial_learning_rate, betas=(0.5, 0.999)) 425 | 426 | if args.checkpoint_path is 
not None: 427 | load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False) 428 | 429 | if args.disc_checkpoint_path is not None: 430 | load_checkpoint(args.disc_checkpoint_path, disc, disc_optimizer, 431 | reset_optimizer=False, overwrite_global_states=False) 432 | 433 | load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, 434 | overwrite_global_states=False) 435 | 436 | if not os.path.exists(checkpoint_dir): 437 | os.mkdir(checkpoint_dir) 438 | 439 | # Train! 440 | train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer, 441 | checkpoint_dir=checkpoint_dir, 442 | checkpoint_interval=hparams.checkpoint_interval, 443 | nepochs=hparams.nepochs) 444 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from os import listdir, path 2 | import numpy as np 3 | import scipy, cv2, os, sys, argparse, audio 4 | import json, subprocess, random, string 5 | from tqdm import tqdm 6 | from glob import glob 7 | import torch, face_detection 8 | from models import Wav2Lip 9 | import platform 10 | 11 | parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models') 12 | 13 | parser.add_argument('--checkpoint_path', type=str, 14 | help='Name of saved checkpoint to load weights from', required=True) 15 | 16 | parser.add_argument('--face', type=str, 17 | help='Filepath of video/image that contains faces to use', required=True) 18 | parser.add_argument('--audio', type=str, 19 | help='Filepath of video/audio file to use as raw audio source', required=True) 20 | parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.', 21 | default='results/result_voice.mp4') 22 | 23 | parser.add_argument('--static', type=bool, 24 | help='If True, then use only first video frame for inference', default=False) 25 | parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)', 26 | default=25., required=False) 27 | 28 | parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0], 29 | help='Padding (top, bottom, left, right). Please adjust to include chin at least') 30 | 31 | parser.add_argument('--face_det_batch_size', type=int, 32 | help='Batch size for face detection', default=16) 33 | parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128) 34 | 35 | parser.add_argument('--resize_factor', default=1, type=int, 36 | help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p') 37 | 38 | parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1], 39 | help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. ' 40 | 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width') 41 | 42 | parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1], 43 | help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.' 44 | 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).') 45 | 46 | parser.add_argument('--rotate', default=False, action='store_true', 47 | help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.' 
48 | 'Use if you get a flipped result, despite feeding a normal looking video') 49 | 50 | parser.add_argument('--nosmooth', default=False, action='store_true', 51 | help='Prevent smoothing face detections over a short temporal window') 52 | 53 | args = parser.parse_args() 54 | args.img_size = 96 55 | 56 | if os.path.isfile(args.face) and args.face.split('.')[1] in ['jpg', 'png', 'jpeg']: 57 | args.static = True 58 | 59 | def get_smoothened_boxes(boxes, T): 60 | for i in range(len(boxes)): 61 | if i + T > len(boxes): 62 | window = boxes[len(boxes) - T:] 63 | else: 64 | window = boxes[i : i + T] 65 | boxes[i] = np.mean(window, axis=0) 66 | return boxes 67 | 68 | def face_detect(images): 69 | detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, 70 | flip_input=False, device=device) 71 | 72 | batch_size = args.face_det_batch_size 73 | 74 | while 1: 75 | predictions = [] 76 | try: 77 | for i in tqdm(range(0, len(images), batch_size)): 78 | predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size]))) 79 | except RuntimeError: 80 | if batch_size == 1: 81 | raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument') 82 | batch_size //= 2 83 | print('Recovering from OOM error; New batch size: {}'.format(batch_size)) 84 | continue 85 | break 86 | 87 | results = [] 88 | pady1, pady2, padx1, padx2 = args.pads 89 | for rect, image in zip(predictions, images): 90 | if rect is None: 91 | cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected. 92 | raise ValueError('Face not detected! Ensure the video contains a face in all the frames.') 93 | 94 | y1 = max(0, rect[1] - pady1) 95 | y2 = min(image.shape[0], rect[3] + pady2) 96 | x1 = max(0, rect[0] - padx1) 97 | x2 = min(image.shape[1], rect[2] + padx2) 98 | 99 | results.append([x1, y1, x2, y2]) 100 | 101 | boxes = np.array(results) 102 | if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5) 103 | results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)] 104 | 105 | del detector 106 | return results 107 | 108 | def datagen(frames, mels): 109 | img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] 110 | 111 | if args.box[0] == -1: 112 | if not args.static: 113 | face_det_results = face_detect(frames) # BGR2RGB for CNN face detection 114 | else: 115 | face_det_results = face_detect([frames[0]]) 116 | else: 117 | print('Using the specified bounding box instead of face detection...') 118 | y1, y2, x1, x2 = args.box 119 | face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames] 120 | 121 | for i, m in enumerate(mels): 122 | idx = 0 if args.static else i%len(frames) 123 | frame_to_save = frames[idx].copy() 124 | face, coords = face_det_results[idx].copy() 125 | 126 | face = cv2.resize(face, (args.img_size, args.img_size)) 127 | 128 | img_batch.append(face) 129 | mel_batch.append(m) 130 | frame_batch.append(frame_to_save) 131 | coords_batch.append(coords) 132 | 133 | if len(img_batch) >= args.wav2lip_batch_size: 134 | img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) 135 | 136 | img_masked = img_batch.copy() 137 | img_masked[:, args.img_size//2:] = 0 138 | 139 | img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 
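# img_masked has its lower half (mouth region) zeroed; concatenating it with the
# unmasked crops along the channel axis gives the 6-channel input the Wav2Lip
# face encoder expects, scaled to [0, 1].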
140 | mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) 141 | 142 | yield img_batch, mel_batch, frame_batch, coords_batch 143 | img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] 144 | 145 | if len(img_batch) > 0: 146 | img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) 147 | 148 | img_masked = img_batch.copy() 149 | img_masked[:, args.img_size//2:] = 0 150 | 151 | img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 152 | mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) 153 | 154 | yield img_batch, mel_batch, frame_batch, coords_batch 155 | 156 | mel_step_size = 16 157 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 158 | print('Using {} for inference.'.format(device)) 159 | 160 | def _load(checkpoint_path): 161 | if device == 'cuda': 162 | checkpoint = torch.load(checkpoint_path) 163 | else: 164 | checkpoint = torch.load(checkpoint_path, 165 | map_location=lambda storage, loc: storage) 166 | return checkpoint 167 | 168 | def load_model(path): 169 | model = Wav2Lip() 170 | print("Load checkpoint from: {}".format(path)) 171 | checkpoint = _load(path) 172 | s = checkpoint["state_dict"] 173 | new_s = {} 174 | for k, v in s.items(): 175 | new_s[k.replace('module.', '')] = v 176 | model.load_state_dict(new_s) 177 | 178 | model = model.to(device) 179 | return model.eval() 180 | 181 | def main(): 182 | if not os.path.isfile(args.face): 183 | raise ValueError('--face argument must be a valid path to video/image file') 184 | 185 | elif args.face.split('.')[1] in ['jpg', 'png', 'jpeg']: 186 | full_frames = [cv2.imread(args.face)] 187 | fps = args.fps 188 | 189 | else: 190 | video_stream = cv2.VideoCapture(args.face) 191 | fps = video_stream.get(cv2.CAP_PROP_FPS) 192 | 193 | print('Reading video frames...') 194 | 195 | full_frames = [] 196 | while 1: 197 | still_reading, frame = video_stream.read() 198 | if not still_reading: 199 | video_stream.release() 200 | break 201 | if args.resize_factor > 1: 202 | frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor)) 203 | 204 | if args.rotate: 205 | frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE) 206 | 207 | y1, y2, x1, x2 = args.crop 208 | if x2 == -1: x2 = frame.shape[1] 209 | if y2 == -1: y2 = frame.shape[0] 210 | 211 | frame = frame[y1:y2, x1:x2] 212 | 213 | full_frames.append(frame) 214 | 215 | print ("Number of frames available for inference: "+str(len(full_frames))) 216 | 217 | if not args.audio.endswith('.wav'): 218 | print('Extracting raw audio...') 219 | command = 'ffmpeg -y -i {} -strict -2 {}'.format(args.audio, 'temp/temp.wav') 220 | 221 | subprocess.call(command, shell=True) 222 | args.audio = 'temp/temp.wav' 223 | 224 | wav = audio.load_wav(args.audio, 16000) 225 | mel = audio.melspectrogram(wav) 226 | print(mel.shape) 227 | 228 | if np.isnan(mel.reshape(-1)).sum() > 0: 229 | raise ValueError('Mel contains nan! Using a TTS voice? 
Add a small epsilon noise to the wav file and try again') 230 | 231 | mel_chunks = [] 232 | mel_idx_multiplier = 80./fps 233 | i = 0 234 | while 1: 235 | start_idx = int(i * mel_idx_multiplier) 236 | if start_idx + mel_step_size > len(mel[0]): 237 | mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:]) 238 | break 239 | mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size]) 240 | i += 1 241 | 242 | print("Length of mel chunks: {}".format(len(mel_chunks))) 243 | 244 | full_frames = full_frames[:len(mel_chunks)] 245 | 246 | batch_size = args.wav2lip_batch_size 247 | gen = datagen(full_frames.copy(), mel_chunks) 248 | 249 | for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen, 250 | total=int(np.ceil(float(len(mel_chunks))/batch_size)))): 251 | if i == 0: 252 | model = load_model(args.checkpoint_path) 253 | print ("Model loaded") 254 | 255 | frame_h, frame_w = full_frames[0].shape[:-1] 256 | out = cv2.VideoWriter('temp/result.avi', 257 | cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h)) 258 | 259 | img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device) 260 | mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device) 261 | 262 | with torch.no_grad(): 263 | pred = model(mel_batch, img_batch) 264 | 265 | pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255. 266 | 267 | for p, f, c in zip(pred, frames, coords): 268 | y1, y2, x1, x2 = c 269 | p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1)) 270 | 271 | f[y1:y2, x1:x2] = p 272 | out.write(f) 273 | 274 | out.release() 275 | 276 | command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(args.audio, 'temp/result.avi', args.outfile) 277 | subprocess.call(command, shell=platform.system() != 'Windows') 278 | 279 | if __name__ == '__main__': 280 | main() 281 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .wav2lip import Wav2Lip, Wav2Lip_disc_qual 2 | from .syncnet import SyncNet_color -------------------------------------------------------------------------------- /models/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | class Conv2d(nn.Module): 6 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | self.conv_block = nn.Sequential( 9 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 10 | nn.BatchNorm2d(cout) 11 | ) 12 | self.act = nn.ReLU() 13 | self.residual = residual 14 | 15 | def forward(self, x): 16 | out = self.conv_block(x) 17 | if self.residual: 18 | out += x 19 | return self.act(out) 20 | 21 | class nonorm_Conv2d(nn.Module): 22 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self.conv_block = nn.Sequential( 25 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 26 | ) 27 | self.act = nn.LeakyReLU(0.01, inplace=True) 28 | 29 | def forward(self, x): 30 | out = self.conv_block(x) 31 | return self.act(out) 32 | 33 | class Conv2dTranspose(nn.Module): 34 | def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs): 35 | super().__init__(*args, **kwargs) 36 | self.conv_block = nn.Sequential( 37 | nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding), 38 
| nn.BatchNorm2d(cout) 39 | ) 40 | self.act = nn.ReLU() 41 | 42 | def forward(self, x): 43 | out = self.conv_block(x) 44 | return self.act(out) 45 | -------------------------------------------------------------------------------- /models/quantize_vq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torch import einsum 6 | from einops import rearrange 7 | 8 | 9 | class VectorQuantizer(nn.Module): 10 | """ 11 | see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py 12 | ____________________________________________ 13 | Discretization bottleneck part of the VQ-VAE. 14 | Inputs: 15 | - n_e : number of embeddings 16 | - e_dim : dimension of embedding 17 | - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 18 | _____________________________________________ 19 | """ 20 | 21 | # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for 22 | # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be 23 | # used wherever VectorQuantizer has been used before and is additionally 24 | # more efficient. 25 | def __init__(self, n_e, e_dim, beta): 26 | super(VectorQuantizer, self).__init__() 27 | self.n_e = n_e 28 | self.e_dim = e_dim 29 | self.beta = beta 30 | 31 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 32 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 33 | 34 | def forward(self, z): 35 | """ 36 | Inputs the output of the encoder network z and maps it to a discrete 37 | one-hot vector that is the index of the closest embedding vector e_j 38 | z (continuous) -> z_q (discrete) 39 | z.shape = (batch, channel, height, width) 40 | quantization pipeline: 41 | 1. get encoder input (B,C,H,W) 42 | 2. flatten input to (B*H*W,C) 43 | """ 44 | # reshape z -> (batch, height, width, channel) and flatten 45 | z = z.permute(0, 2, 3, 1).contiguous() 46 | z_flattened = z.view(-1, self.e_dim) 47 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 48 | 49 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 50 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 51 | torch.matmul(z_flattened, self.embedding.weight.t()) 52 | 53 | ## could possible replace this here 54 | # #\start... 55 | # find closest encodings 56 | min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) 57 | 58 | min_encodings = torch.zeros( 59 | min_encoding_indices.shape[0], self.n_e).to(z) 60 | min_encodings.scatter_(1, min_encoding_indices, 1) 61 | 62 | # dtype min encodings: torch.float32 63 | # min_encodings shape: torch.Size([2048, 512]) 64 | # min_encoding_indices.shape: torch.Size([2048, 1]) 65 | 66 | # get quantized latent vectors 67 | z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) 68 | #.........\end 69 | 70 | # with: 71 | # .........\start 72 | #min_encoding_indices = torch.argmin(d, dim=1) 73 | #z_q = self.embedding(min_encoding_indices) 74 | # ......\end......... 
(TODO) 75 | 76 | # compute loss for embedding 77 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 78 | torch.mean((z_q - z.detach()) ** 2) 79 | 80 | # preserve gradients 81 | z_q = z + (z_q - z).detach() 82 | 83 | # perplexity 84 | e_mean = torch.mean(min_encodings, dim=0) 85 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) 86 | 87 | # reshape back to match original input shape 88 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 89 | 90 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 91 | 92 | def get_codebook_entry(self, indices, shape): 93 | # shape specifying (batch, height, width, channel) 94 | # TODO: check for more easy handling with nn.Embedding 95 | min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) 96 | min_encodings.scatter_(1, indices[:,None], 1) 97 | 98 | # get quantized latent vectors 99 | z_q = torch.matmul(min_encodings.float(), self.embedding.weight) 100 | 101 | if shape is not None: 102 | z_q = z_q.view(shape) 103 | 104 | # reshape back to match original input shape 105 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 106 | 107 | return z_q 108 | 109 | 110 | class GumbelQuantize(nn.Module): 111 | """ 112 | credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) 113 | Gumbel Softmax trick quantizer 114 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 115 | https://arxiv.org/abs/1611.01144 116 | """ 117 | def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, 118 | kl_weight=5e-4, temp_init=1.0, use_vqinterface=True, 119 | remap=None, unknown_index="random"): 120 | super().__init__() 121 | 122 | self.embedding_dim = embedding_dim 123 | self.n_embed = n_embed 124 | 125 | self.straight_through = straight_through 126 | self.temperature = temp_init 127 | self.kl_weight = kl_weight 128 | 129 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1) 130 | self.embed = nn.Embedding(n_embed, embedding_dim) 131 | 132 | self.use_vqinterface = use_vqinterface 133 | 134 | self.remap = remap 135 | if self.remap is not None: 136 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 137 | self.re_embed = self.used.shape[0] 138 | self.unknown_index = unknown_index # "random" or "extra" or integer 139 | if self.unknown_index == "extra": 140 | self.unknown_index = self.re_embed 141 | self.re_embed = self.re_embed+1 142 | print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. 
" 143 | f"Using {self.unknown_index} for unknown indices.") 144 | else: 145 | self.re_embed = n_embed 146 | 147 | def remap_to_used(self, inds): 148 | ishape = inds.shape 149 | assert len(ishape)>1 150 | inds = inds.reshape(ishape[0],-1) 151 | used = self.used.to(inds) 152 | match = (inds[:,:,None]==used[None,None,...]).long() 153 | new = match.argmax(-1) 154 | unknown = match.sum(2)<1 155 | if self.unknown_index == "random": 156 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 157 | else: 158 | new[unknown] = self.unknown_index 159 | return new.reshape(ishape) 160 | 161 | def unmap_to_all(self, inds): 162 | ishape = inds.shape 163 | assert len(ishape)>1 164 | inds = inds.reshape(ishape[0],-1) 165 | used = self.used.to(inds) 166 | if self.re_embed > self.used.shape[0]: # extra token 167 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 168 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 169 | return back.reshape(ishape) 170 | 171 | def forward(self, z, temp=None, return_logits=False): 172 | # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work 173 | hard = self.straight_through if self.training else True 174 | temp = self.temperature if temp is None else temp 175 | 176 | logits = self.proj(z) 177 | if self.remap is not None: 178 | # continue only with used logits 179 | full_zeros = torch.zeros_like(logits) 180 | logits = logits[:,self.used,...] 181 | 182 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) 183 | if self.remap is not None: 184 | # go back to all entries but unused set to zero 185 | full_zeros[:,self.used,...] = soft_one_hot 186 | soft_one_hot = full_zeros 187 | z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) 188 | 189 | # + kl divergence to the prior loss 190 | qy = F.softmax(logits, dim=1) 191 | diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() 192 | 193 | ind = soft_one_hot.argmax(dim=1) 194 | if self.remap is not None: 195 | ind = self.remap_to_used(ind) 196 | if self.use_vqinterface: 197 | if return_logits: 198 | return z_q, diff, (None, None, ind), logits 199 | return z_q, diff, (None, None, ind) 200 | return z_q, diff, ind 201 | 202 | def get_codebook_entry(self, indices, shape): 203 | b, h, w, c = shape 204 | assert b*h*w == indices.shape[0] 205 | indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w) 206 | if self.remap is not None: 207 | indices = self.unmap_to_all(indices) 208 | one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() 209 | z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight) 210 | return z_q 211 | 212 | 213 | class VectorQuantizer2(nn.Module): 214 | """ 215 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly 216 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 217 | """ 218 | # NOTE: due to a bug the beta term was applied to the wrong term. for 219 | # backwards compatibility we use the buggy version by default, but you can 220 | # specify legacy=False to fix it. 
221 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", 222 | sane_index_shape=False, legacy=True): 223 | super().__init__() 224 | self.n_e = n_e 225 | self.e_dim = e_dim 226 | self.beta = beta 227 | self.legacy = legacy 228 | 229 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 230 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 231 | 232 | self.remap = remap 233 | if self.remap is not None: 234 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 235 | self.re_embed = self.used.shape[0] 236 | self.unknown_index = unknown_index # "random" or "extra" or integer 237 | if self.unknown_index == "extra": 238 | self.unknown_index = self.re_embed 239 | self.re_embed = self.re_embed+1 240 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " 241 | f"Using {self.unknown_index} for unknown indices.") 242 | else: 243 | self.re_embed = n_e 244 | 245 | self.sane_index_shape = sane_index_shape 246 | 247 | def remap_to_used(self, inds): 248 | ishape = inds.shape 249 | assert len(ishape)>1 250 | inds = inds.reshape(ishape[0],-1) 251 | used = self.used.to(inds) 252 | match = (inds[:,:,None]==used[None,None,...]).long() 253 | new = match.argmax(-1) 254 | unknown = match.sum(2)<1 255 | if self.unknown_index == "random": 256 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 257 | else: 258 | new[unknown] = self.unknown_index 259 | return new.reshape(ishape) 260 | 261 | def unmap_to_all(self, inds): 262 | ishape = inds.shape 263 | assert len(ishape)>1 264 | inds = inds.reshape(ishape[0],-1) 265 | used = self.used.to(inds) 266 | if self.re_embed > self.used.shape[0]: # extra token 267 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 268 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 269 | return back.reshape(ishape) 270 | 271 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False): 272 | assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" 273 | assert rescale_logits==False, "Only for interface compatible with Gumbel" 274 | assert return_logits==False, "Only for interface compatible with Gumbel" 275 | # reshape z -> (batch, height, width, channel) and flatten 276 | z = rearrange(z, 'b c h w -> b h w c').contiguous() 277 | z_flattened = z.view(-1, self.e_dim) 278 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 279 | 280 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 281 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 282 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 283 | 284 | min_encoding_indices = torch.argmin(d, dim=1) 285 | z_q = self.embedding(min_encoding_indices).view(z.shape) 286 | perplexity = None 287 | min_encodings = None 288 | 289 | # compute loss for embedding 290 | if not self.legacy: 291 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ 292 | torch.mean((z_q - z.detach()) ** 2) 293 | else: 294 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 295 | torch.mean((z_q - z.detach()) ** 2) 296 | 297 | # preserve gradients 298 | z_q = z + (z_q - z).detach() 299 | 300 | # reshape back to match original input shape 301 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() 302 | 303 | if self.remap is not None: 304 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis 305 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 306 | min_encoding_indices = 
min_encoding_indices.reshape(-1,1) # flatten 307 | 308 | if self.sane_index_shape: 309 | min_encoding_indices = min_encoding_indices.reshape( 310 | z_q.shape[0], z_q.shape[2], z_q.shape[3]) 311 | 312 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 313 | 314 | def get_codebook_entry(self, indices, shape): 315 | # shape specifying (batch, height, width, channel) 316 | if self.remap is not None: 317 | indices = indices.reshape(shape[0],-1) # add batch axis 318 | indices = self.unmap_to_all(indices) 319 | indices = indices.reshape(-1) # flatten again 320 | 321 | # get quantized latent vectors 322 | z_q = self.embedding(indices) 323 | 324 | if shape is not None: 325 | z_q = z_q.view(shape) 326 | # reshape back to match original input shape 327 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 328 | 329 | return z_q 330 | 331 | class EmbeddingEMA(nn.Module): 332 | def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): 333 | super().__init__() 334 | self.decay = decay 335 | self.eps = eps 336 | weight = torch.randn(num_tokens, codebook_dim) 337 | self.weight = nn.Parameter(weight, requires_grad = False) 338 | self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False) 339 | self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False) 340 | self.update = True 341 | 342 | def forward(self, embed_id): 343 | return F.embedding(embed_id, self.weight) 344 | 345 | def cluster_size_ema_update(self, new_cluster_size): 346 | self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) 347 | 348 | def embed_avg_ema_update(self, new_embed_avg): 349 | self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) 350 | 351 | def weight_update(self, num_tokens): 352 | n = self.cluster_size.sum() 353 | smoothed_cluster_size = ( 354 | (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n 355 | ) 356 | #normalize embedding average with smoothed cluster size 357 | embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) 358 | self.weight.data.copy_(embed_normalized) 359 | 360 | 361 | class EMAVectorQuantizer(nn.Module): 362 | def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, 363 | remap=None, unknown_index="random"): 364 | super().__init__() 365 | self.codebook_dim = codebook_dim 366 | self.num_tokens = num_tokens 367 | self.beta = beta 368 | self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) 369 | 370 | self.remap = remap 371 | if self.remap is not None: 372 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 373 | self.re_embed = self.used.shape[0] 374 | self.unknown_index = unknown_index # "random" or "extra" or integer 375 | if self.unknown_index == "extra": 376 | self.unknown_index = self.re_embed 377 | self.re_embed = self.re_embed+1 378 | print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. 
" 379 | f"Using {self.unknown_index} for unknown indices.") 380 | else: 381 | self.re_embed = n_embed 382 | 383 | def remap_to_used(self, inds): 384 | ishape = inds.shape 385 | assert len(ishape)>1 386 | inds = inds.reshape(ishape[0],-1) 387 | used = self.used.to(inds) 388 | match = (inds[:,:,None]==used[None,None,...]).long() 389 | new = match.argmax(-1) 390 | unknown = match.sum(2)<1 391 | if self.unknown_index == "random": 392 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 393 | else: 394 | new[unknown] = self.unknown_index 395 | return new.reshape(ishape) 396 | 397 | def unmap_to_all(self, inds): 398 | ishape = inds.shape 399 | assert len(ishape)>1 400 | inds = inds.reshape(ishape[0],-1) 401 | used = self.used.to(inds) 402 | if self.re_embed > self.used.shape[0]: # extra token 403 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 404 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 405 | return back.reshape(ishape) 406 | 407 | def forward(self, z): 408 | # reshape z -> (batch, height, width, channel) and flatten 409 | #z, 'b c h w -> b h w c' 410 | z = rearrange(z, 'b c h w -> b h w c') 411 | z_flattened = z.reshape(-1, self.codebook_dim) 412 | 413 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 414 | d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ 415 | self.embedding.weight.pow(2).sum(dim=1) - 2 * \ 416 | torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' 417 | 418 | 419 | encoding_indices = torch.argmin(d, dim=1) 420 | 421 | z_q = self.embedding(encoding_indices).view(z.shape) 422 | encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) 423 | avg_probs = torch.mean(encodings, dim=0) 424 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 425 | 426 | if self.training and self.embedding.update: 427 | #EMA cluster size 428 | encodings_sum = encodings.sum(0) 429 | self.embedding.cluster_size_ema_update(encodings_sum) 430 | #EMA embedding average 431 | embed_sum = encodings.transpose(0,1) @ z_flattened 432 | self.embedding.embed_avg_ema_update(embed_sum) 433 | #normalize embed_avg and update weight 434 | self.embedding.weight_update(self.num_tokens) 435 | 436 | # compute loss for embedding 437 | loss = self.beta * F.mse_loss(z_q.detach(), z) 438 | 439 | # preserve gradients 440 | z_q = z + (z_q - z).detach() 441 | 442 | # reshape back to match original input shape 443 | #z_q, 'b h w c -> b c h w' 444 | z_q = rearrange(z_q, 'b h w c -> b c h w') 445 | return z_q, loss, (perplexity, encodings, encoding_indices) 446 | -------------------------------------------------------------------------------- /models/syncnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from .conv import Conv2d 6 | 7 | class SyncNet_color(nn.Module): 8 | def __init__(self): 9 | super(SyncNet_color, self).__init__() 10 | 11 | self.face_encoder = nn.Sequential( 12 | Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3), 13 | 14 | Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1), 15 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 16 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 17 | 18 | Conv2d(64, 128, kernel_size=3, stride=2, padding=1), 19 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 20 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, 
residual=True), 21 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 22 | 23 | Conv2d(128, 256, kernel_size=3, stride=2, padding=1), 24 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 25 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 26 | 27 | Conv2d(256, 512, kernel_size=3, stride=2, padding=1), 28 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 29 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 30 | 31 | Conv2d(512, 512, kernel_size=3, stride=2, padding=1), 32 | Conv2d(512, 512, kernel_size=3, stride=1, padding=0), 33 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 34 | 35 | self.audio_encoder = nn.Sequential( 36 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 37 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 38 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 39 | 40 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 41 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 42 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 43 | 44 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1), 45 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 46 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 47 | 48 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 49 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 50 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 51 | 52 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0), 53 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 54 | 55 | def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T) 56 | face_embedding = self.face_encoder(face_sequences) 57 | audio_embedding = self.audio_encoder(audio_sequences) 58 | 59 | audio_embedding = audio_embedding.view(audio_embedding.size(0), -1) 60 | face_embedding = face_embedding.view(face_embedding.size(0), -1) 61 | 62 | audio_embedding = F.normalize(audio_embedding, p=2, dim=1) 63 | face_embedding = F.normalize(face_embedding, p=2, dim=1) 64 | 65 | 66 | return audio_embedding, face_embedding 67 | -------------------------------------------------------------------------------- /models/syncnet_vq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from omegaconf import OmegaConf 5 | 6 | try: 7 | from .conv import Conv2d 8 | from .vqgan import VQModel 9 | except: 10 | from conv import Conv2d 11 | from vqgan import VQModel 12 | 13 | 14 | class SyncNet_color(nn.Module): 15 | def __init__(self, config_path, ckpt_path=None, syncnet_T=5): 16 | super(SyncNet_color, self).__init__() 17 | self.T = syncnet_T 18 | 19 | # (B, 5 x 256, 16, 16) -> (B, 512, 1, 1) 20 | self.face_encoder = nn.Sequential( 21 | Conv2d(self.T * 256, 256, kernel_size=3, stride=2, padding=1), # 16, 16 -> 8, 8 22 | Conv2d(256, 64, kernel_size=3, stride=1, padding=1), 23 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 24 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 25 | 26 | Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 8, 8 -> 4, 4 27 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 28 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 29 | 30 | Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 
4, 4 -> 2, 2 31 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 32 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 33 | 34 | Conv2d(256, 512, kernel_size=2, stride=1, padding=0), # 2, 2 -> 1, 1 35 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0) 36 | ) 37 | 38 | # (B, 1, 80, 16) -> (B, 512, 1, 1) 39 | self.audio_encoder = nn.Sequential( 40 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 41 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 42 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 43 | 44 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 45 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 46 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 47 | 48 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1), 49 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 50 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 51 | 52 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 53 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 54 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 55 | 56 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0), 57 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0) 58 | ) 59 | 60 | config = OmegaConf.load(config_path) 61 | self.vq_model = VQModel(ckpt_path=ckpt_path, **config.model.params) 62 | for parameter in self.vq_model.parameters(): 63 | parameter.requires_grad = False 64 | self.vq_model.eval() 65 | 66 | def forward(self, audio_sequences, face_sequences, vq_encoded=False): 67 | batch_size = face_sequences.size(0) # audio_sequences := (B, dim, T) 68 | input_dim_size = len(face_sequences.size()) 69 | if input_dim_size > 4: 70 | # (2, 5, 3, 256, 256) -> (2 x 5, 3, 256, 256) 71 | face_sequences = torch.cat([face_sequences[:, i] for i in range(face_sequences.size(1))], dim=0) 72 | else: 73 | batch_size = batch_size // self.T 74 | 75 | # (2 x 5, 3, 256, 256) -> (2 x 5, 256, 16, 16) 76 | if not vq_encoded: 77 | face_sequences, _, _ = self.vq_model.encode(face_sequences) 78 | 79 | # (2 x 5, 256, 16, 16) resize (2, 5 x 256, 16, 16) 80 | # face_sequences = face_sequences.view(batch_size, 5 * 256, 16, 16) 81 | face_sequences = torch.split(face_sequences, batch_size, dim=0) 82 | face_sequences = torch.cat(face_sequences, dim=1) 83 | 84 | # (2, 5 x 256, 16, 16) -> (2, 512, 1, 1) 85 | face_embedding = self.face_encoder(face_sequences) 86 | 87 | # (2, 1, 80, 16) -> (2, 512, 1, 1) 88 | audio_embedding = self.audio_encoder(audio_sequences) 89 | 90 | audio_embedding = audio_embedding.view(audio_embedding.size(0), -1) # (2, 512) 91 | face_embedding = face_embedding.view(face_embedding.size(0), -1) # (2, 512) 92 | 93 | audio_embedding = F.normalize(audio_embedding, p=2, dim=1) 94 | face_embedding = F.normalize(face_embedding, p=2, dim=1) 95 | 96 | return audio_embedding, face_embedding 97 | 98 | 99 | if __name__ == '__main__': 100 | config_path = '../data/vqgan-project.yaml' 101 | 102 | model = SyncNet_color(config_path) 103 | model.eval() 104 | print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))) 105 | 106 | audio_sequences = torch.randn(size=(2, 1, 80, 16)) 107 | face_sequences = torch.randn(size=(2, 5, 3, 256, 256)) 108 | 109 | audio_embedding, face_embedding = model(audio_sequences, face_sequences) 110 | 111 | print(audio_embedding.shape) # (B, 512) 112 | print(face_embedding.shape) 
# (B, 512) 113 | -------------------------------------------------------------------------------- /models/vqgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytorch_lightning as pl 4 | 5 | # from main import instantiate_from_config 6 | 7 | try: 8 | from .encoder_vq import Encoder, Decoder 9 | from .quantize_vq import VectorQuantizer2 as VectorQuantizer 10 | except: 11 | from encoder_vq import Encoder, Decoder 12 | from quantize_vq import VectorQuantizer2 as VectorQuantizer 13 | 14 | class VQModel(pl.LightningModule): 15 | def __init__(self, 16 | ddconfig, 17 | lossconfig, 18 | n_embed, 19 | embed_dim, 20 | ckpt_path=None, 21 | ignore_keys=[], 22 | image_key="image", 23 | colorize_nlabels=None, 24 | monitor=None, 25 | remap=None, 26 | sane_index_shape=False, # tell vector quantizer to return indices as bhw 27 | ): 28 | super().__init__() 29 | self.image_key = image_key 30 | self.encoder = Encoder(**ddconfig) 31 | self.decoder = Decoder(**ddconfig) 32 | # self.loss = instantiate_from_config(lossconfig) 33 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, 34 | remap=remap, sane_index_shape=sane_index_shape) 35 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) 36 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 37 | if ckpt_path is not None: 38 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 39 | self.image_key = image_key 40 | if colorize_nlabels is not None: 41 | assert type(colorize_nlabels)==int 42 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 43 | if monitor is not None: 44 | self.monitor = monitor 45 | 46 | def init_from_ckpt(self, path, ignore_keys=list()): 47 | sd = torch.load(path, map_location="cpu")["state_dict"] 48 | keys = list(sd.keys()) 49 | for k in keys: 50 | for ik in ignore_keys: 51 | if k.startswith(ik): 52 | print("Deleting key {} from state_dict.".format(k)) 53 | del sd[k] 54 | self.load_state_dict(sd, strict=False) 55 | print(f"Restored from {path}") 56 | 57 | def encode(self, x): 58 | h = self.encoder(x) 59 | h = self.quant_conv(h) 60 | quant, emb_loss, info = self.quantize(h) 61 | return quant, emb_loss, info 62 | 63 | def decode(self, quant): 64 | quant = self.post_quant_conv(quant) 65 | dec = self.decoder(quant) 66 | return dec 67 | 68 | def decode_code(self, code_b): 69 | quant_b = self.quantize.embed_code(code_b) 70 | dec = self.decode(quant_b) 71 | return dec 72 | 73 | def forward(self, input): 74 | quant, diff, _ = self.encode(input) 75 | dec = self.decode(quant) 76 | return dec, diff 77 | 78 | def get_input(self, batch, k): 79 | x = batch[k] 80 | if len(x.shape) == 3: 81 | x = x.unsqueeze(0) 82 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) 83 | return x.float() 84 | 85 | # def training_step(self, batch, batch_idx, optimizer_idx): 86 | # x = self.get_input(batch, self.image_key) 87 | # xrec, qloss = self(x) 88 | # 89 | # if optimizer_idx == 0: 90 | # # autoencode 91 | # aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, 92 | # last_layer=self.get_last_layer(), split="train") 93 | # 94 | # self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 95 | # self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) 96 | # return aeloss 97 | # 98 | # if optimizer_idx == 1: 99 | # # discriminator 100 | # discloss, log_dict_disc = self.loss(qloss, x, 
xrec, optimizer_idx, self.global_step, 101 | # last_layer=self.get_last_layer(), split="train") 102 | # self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) 103 | # self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) 104 | # return discloss 105 | # 106 | # def validation_step(self, batch, batch_idx): 107 | # x = self.get_input(batch, self.image_key) 108 | # xrec, qloss = self(x) 109 | # aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, 110 | # last_layer=self.get_last_layer(), split="val") 111 | # 112 | # discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, 113 | # last_layer=self.get_last_layer(), split="val") 114 | # rec_loss = log_dict_ae["val/rec_loss"] 115 | # # self.log("val/rec_loss", rec_loss, 116 | # # prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) 117 | # self.log("val/aeloss", aeloss, 118 | # prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) 119 | # self.log_dict(log_dict_ae) 120 | # self.log_dict(log_dict_disc) 121 | # return self.log_dict 122 | # 123 | # def configure_optimizers(self): 124 | # lr = self.learning_rate 125 | # opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ 126 | # list(self.decoder.parameters())+ 127 | # list(self.quantize.parameters())+ 128 | # list(self.quant_conv.parameters())+ 129 | # list(self.post_quant_conv.parameters()), 130 | # lr=lr, betas=(0.5, 0.9)) 131 | # opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 132 | # lr=lr, betas=(0.5, 0.9)) 133 | # return [opt_ae, opt_disc], [] 134 | 135 | def get_last_layer(self): 136 | return self.decoder.conv_out.weight 137 | 138 | def log_images(self, batch, **kwargs): 139 | log = dict() 140 | x = self.get_input(batch, self.image_key) 141 | x = x.to(self.device) 142 | xrec, _ = self(x) 143 | if x.shape[1] > 3: 144 | # colorize with random projection 145 | assert xrec.shape[1] > 3 146 | x = self.to_rgb(x) 147 | xrec = self.to_rgb(xrec) 148 | log["inputs"] = x 149 | log["reconstructions"] = xrec 150 | return log 151 | 152 | def to_rgb(self, x): 153 | assert self.image_key == "segmentation" 154 | if not hasattr(self, "colorize"): 155 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 156 | x = F.conv2d(x, weight=self.colorize) 157 | x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
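# note: the random colorize projection is rescaled here to [-1, 1] for logging;
# to_rgb is only reachable when image_key == "segmentation" (inputs with more than
# 3 channels), so it is not exercised for RGB face crops in this repo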
158 | return x 159 | 160 | -------------------------------------------------------------------------------- /models/wav2lip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import math 5 | 6 | from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d 7 | 8 | class Wav2Lip(nn.Module): 9 | def __init__(self): 10 | super(Wav2Lip, self).__init__() 11 | 12 | self.face_encoder_blocks = nn.ModuleList([ 13 | nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96 14 | 15 | nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48 16 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 17 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)), 18 | 19 | nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24 20 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 21 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 22 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)), 23 | 24 | nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12 25 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 26 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)), 27 | 28 | nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6 29 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 30 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)), 31 | 32 | nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3 33 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), 34 | 35 | nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1 36 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),]) 37 | 38 | self.audio_encoder = nn.Sequential( 39 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 40 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 41 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 42 | 43 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 44 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 45 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 46 | 47 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1), 48 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 49 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 50 | 51 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 52 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 53 | 54 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0), 55 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 56 | 57 | self.face_decoder_blocks = nn.ModuleList([ 58 | nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),), 59 | 60 | nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3 61 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), 62 | 63 | nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1), 64 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 65 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6 66 | 67 | nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1), 68 | Conv2d(384, 384, kernel_size=3, stride=1, 
padding=1, residual=True), 69 | Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12 70 | 71 | nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1), 72 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 73 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24 74 | 75 | nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1), 76 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 77 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48 78 | 79 | nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1), 80 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 81 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96 82 | 83 | self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1), 84 | nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0), 85 | nn.Sigmoid()) 86 | 87 | def forward(self, audio_sequences, face_sequences): 88 | # audio_sequences = (B, T, 1, 80, 16) 89 | B = audio_sequences.size(0) 90 | 91 | input_dim_size = len(face_sequences.size()) 92 | if input_dim_size > 4: 93 | audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) 94 | face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0) 95 | 96 | audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1 97 | 98 | feats = [] 99 | x = face_sequences 100 | for f in self.face_encoder_blocks: 101 | x = f(x) 102 | feats.append(x) 103 | 104 | x = audio_embedding 105 | for f in self.face_decoder_blocks: 106 | x = f(x) 107 | try: 108 | x = torch.cat((x, feats[-1]), dim=1) 109 | except Exception as e: 110 | print(x.size()) 111 | print(feats[-1].size()) 112 | raise e 113 | 114 | feats.pop() 115 | 116 | x = self.output_block(x) 117 | 118 | if input_dim_size > 4: 119 | x = torch.split(x, B, dim=0) # [(B, C, H, W)] 120 | outputs = torch.stack(x, dim=2) # (B, C, T, H, W) 121 | 122 | else: 123 | outputs = x 124 | 125 | return outputs 126 | 127 | class Wav2Lip_disc_qual(nn.Module): 128 | def __init__(self): 129 | super(Wav2Lip_disc_qual, self).__init__() 130 | 131 | self.face_encoder_blocks = nn.ModuleList([ 132 | nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96 133 | 134 | nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48 135 | nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)), 136 | 137 | nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24 138 | nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)), 139 | 140 | nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12 141 | nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)), 142 | 143 | nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6 144 | nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)), 145 | 146 | nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3 147 | nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),), 148 | 149 | nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1 150 | nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),]) 151 | 152 | self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), 
nn.Sigmoid()) 153 | self.label_noise = .0 154 | 155 | def get_lower_half(self, face_sequences): 156 | return face_sequences[:, :, face_sequences.size(2)//2:] 157 | 158 | def to_2d(self, face_sequences): 159 | B = face_sequences.size(0) 160 | face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0) 161 | return face_sequences 162 | 163 | def perceptual_forward(self, false_face_sequences): 164 | false_face_sequences = self.to_2d(false_face_sequences) 165 | false_face_sequences = self.get_lower_half(false_face_sequences) 166 | 167 | false_feats = false_face_sequences 168 | for f in self.face_encoder_blocks: 169 | false_feats = f(false_feats) 170 | 171 | false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1), 172 | torch.ones((len(false_feats), 1)).cuda()) 173 | 174 | return false_pred_loss 175 | 176 | def forward(self, face_sequences): 177 | face_sequences = self.to_2d(face_sequences) 178 | face_sequences = self.get_lower_half(face_sequences) 179 | 180 | x = face_sequences 181 | for f in self.face_encoder_blocks: 182 | x = f(x) 183 | 184 | return self.binary_pred(x).view(len(x), -1) 185 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if sys.version_info[0] < 3 and sys.version_info[1] < 2: 4 | raise Exception("Must be using >= Python 3.2") 5 | 6 | from os import listdir, path 7 | 8 | if not path.isfile('face_detection/detection/sfd/s3fd.pth'): 9 | raise FileNotFoundError('Save the s3fd model to face_detection/detection/sfd/s3fd.pth \ 10 | before running this script!') 11 | 12 | import multiprocessing as mp 13 | from concurrent.futures import ThreadPoolExecutor, as_completed 14 | import numpy as np 15 | import argparse, os, cv2, traceback, subprocess 16 | from tqdm import tqdm 17 | from glob import glob 18 | import audio 19 | from hparams import hparams as hp 20 | 21 | import face_detection 22 | 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument('--ngpu', help='Number of GPUs across which to run in parallel', default=1, type=int) 26 | parser.add_argument('--batch_size', help='Single GPU Face detection batch size', default=32, type=int) 27 | parser.add_argument("--data_root", help="Root folder of the LRS2 dataset", required=True) 28 | parser.add_argument("--preprocessed_root", help="Root folder of the preprocessed dataset", required=True) 29 | 30 | args = parser.parse_args() 31 | 32 | fa = [face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, 33 | device='cuda:{}'.format(id)) for id in range(args.ngpu)] 34 | 35 | template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}' 36 | # template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}' 37 | 38 | def process_video_file(vfile, args, gpu_id): 39 | video_stream = cv2.VideoCapture(vfile) 40 | 41 | frames = [] 42 | while 1: 43 | still_reading, frame = video_stream.read() 44 | if not still_reading: 45 | video_stream.release() 46 | break 47 | frames.append(frame) 48 | 49 | vidname = os.path.basename(vfile).split('.')[0] 50 | dirname = vfile.split('/')[-2] 51 | 52 | fulldir = path.join(args.preprocessed_root, dirname, vidname) 53 | os.makedirs(fulldir, exist_ok=True) 54 | 55 | batches = [frames[i:i + args.batch_size] for i in range(0, len(frames), args.batch_size)] 56 | 57 | i = -1 58 | for fb in batches: 59 | 
preds = fa[gpu_id].get_detections_for_batch(np.asarray(fb)) 60 | 61 | for j, f in enumerate(preds): 62 | i += 1 63 | if f is None: 64 | continue 65 | 66 | x1, y1, x2, y2 = f 67 | cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2, x1:x2]) 68 | 69 | def process_audio_file(vfile, args): 70 | vidname = os.path.basename(vfile).split('.')[0] 71 | dirname = vfile.split('/')[-2] 72 | 73 | fulldir = path.join(args.preprocessed_root, dirname, vidname) 74 | os.makedirs(fulldir, exist_ok=True) 75 | 76 | wavpath = path.join(fulldir, 'audio.wav') 77 | 78 | command = template.format(vfile, wavpath) 79 | subprocess.call(command, shell=True) 80 | 81 | 82 | def mp_handler(job): 83 | vfile, args, gpu_id = job 84 | try: 85 | process_video_file(vfile, args, gpu_id) 86 | except KeyboardInterrupt: 87 | exit(0) 88 | except: 89 | traceback.print_exc() 90 | 91 | def main(args): 92 | print('Started processing for {} with {} GPUs'.format(args.data_root, args.ngpu)) 93 | 94 | filelist = glob(path.join(args.data_root, '*/*.mp4')) 95 | 96 | jobs = [(vfile, args, i%args.ngpu) for i, vfile in enumerate(filelist)] 97 | p = ThreadPoolExecutor(args.ngpu) 98 | futures = [p.submit(mp_handler, j) for j in jobs] 99 | _ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))] 100 | 101 | print('Dumping audios...') 102 | 103 | for vfile in tqdm(filelist): 104 | try: 105 | process_audio_file(vfile, args) 106 | except KeyboardInterrupt: 107 | exit(0) 108 | except: 109 | traceback.print_exc() 110 | continue 111 | 112 | if __name__ == '__main__': 113 | main(args) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa==0.7.0 2 | numpy==1.17.1 3 | opencv-contrib-python>=4.2.0.34 4 | opencv-python==4.1.0.25 5 | torch==1.1.0 6 | torchvision==0.3.0 7 | tqdm==4.45.0 8 | numba==0.48 9 | -------------------------------------------------------------------------------- /wav2lip_train.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, join, basename, isfile 2 | from tqdm import tqdm 3 | 4 | from models import SyncNet_color as SyncNet 5 | from models import Wav2Lip as Wav2Lip 6 | import audio 7 | 8 | import torch 9 | from torch import nn 10 | from torch import optim 11 | import torch.backends.cudnn as cudnn 12 | from torch.utils import data as data_utils 13 | import numpy as np 14 | 15 | from glob import glob 16 | 17 | import os, random, cv2, argparse 18 | from hparams import hparams, get_image_list 19 | 20 | parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model without the visual quality discriminator') 21 | 22 | parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str) 23 | 24 | parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str) 25 | parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str) 26 | 27 | parser.add_argument('--checkpoint_path', help='Resume from this checkpoint', default=None, type=str) 28 | 29 | args = parser.parse_args() 30 | 31 | 32 | global_step = 0 33 | global_epoch = 0 34 | use_cuda = torch.cuda.is_available() 35 | print('use_cuda: {}'.format(use_cuda)) 36 | 37 | syncnet_T = 5 38 | syncnet_mel_step_size = 16 39 | 40 | class Dataset(object): 41 | def __init__(self, split): 42 | self.all_videos = 
get_image_list(args.data_root, split) 43 | 44 | def get_frame_id(self, frame): 45 | return int(basename(frame).split('.')[0]) 46 | 47 | def get_window(self, start_frame): 48 | start_id = self.get_frame_id(start_frame) 49 | vidname = dirname(start_frame) 50 | 51 | window_fnames = [] 52 | for frame_id in range(start_id, start_id + syncnet_T): 53 | frame = join(vidname, '{}.jpg'.format(frame_id)) 54 | if not isfile(frame): 55 | return None 56 | window_fnames.append(frame) 57 | return window_fnames 58 | 59 | def read_window(self, window_fnames): 60 | if window_fnames is None: return None 61 | window = [] 62 | for fname in window_fnames: 63 | img = cv2.imread(fname) 64 | if img is None: 65 | return None 66 | try: 67 | img = cv2.resize(img, (hparams.img_size, hparams.img_size)) 68 | except Exception as e: 69 | return None 70 | 71 | window.append(img) 72 | 73 | return window 74 | 75 | def crop_audio_window(self, spec, start_frame): 76 | if type(start_frame) == int: 77 | start_frame_num = start_frame 78 | else: 79 | start_frame_num = self.get_frame_id(start_frame) # 0-indexing ---> 1-indexing 80 | start_idx = int(80. * (start_frame_num / float(hparams.fps))) 81 | 82 | end_idx = start_idx + syncnet_mel_step_size 83 | 84 | return spec[start_idx : end_idx, :] 85 | 86 | def get_segmented_mels(self, spec, start_frame): 87 | mels = [] 88 | assert syncnet_T == 5 89 | start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing 90 | if start_frame_num - 2 < 0: return None 91 | for i in range(start_frame_num, start_frame_num + syncnet_T): 92 | m = self.crop_audio_window(spec, i - 2) 93 | if m.shape[0] != syncnet_mel_step_size: 94 | return None 95 | mels.append(m.T) 96 | 97 | mels = np.asarray(mels) 98 | 99 | return mels 100 | 101 | def prepare_window(self, window): 102 | # 3 x T x H x W 103 | x = np.asarray(window) / 255. 104 | x = np.transpose(x, (3, 0, 1, 2)) 105 | 106 | return x 107 | 108 | def __len__(self): 109 | return len(self.all_videos) 110 | 111 | def __getitem__(self, idx): 112 | while 1: 113 | idx = random.randint(0, len(self.all_videos) - 1) 114 | vidname = self.all_videos[idx] 115 | img_names = list(glob(join(vidname, '*.jpg'))) 116 | if len(img_names) <= 3 * syncnet_T: 117 | continue 118 | 119 | img_name = random.choice(img_names) 120 | wrong_img_name = random.choice(img_names) 121 | while wrong_img_name == img_name: 122 | wrong_img_name = random.choice(img_names) 123 | 124 | window_fnames = self.get_window(img_name) 125 | wrong_window_fnames = self.get_window(wrong_img_name) 126 | if window_fnames is None or wrong_window_fnames is None: 127 | continue 128 | 129 | window = self.read_window(window_fnames) 130 | if window is None: 131 | continue 132 | 133 | wrong_window = self.read_window(wrong_window_fnames) 134 | if wrong_window is None: 135 | continue 136 | 137 | try: 138 | wavpath = join(vidname, "audio.wav") 139 | wav = audio.load_wav(wavpath, hparams.sample_rate) 140 | 141 | orig_mel = audio.melspectrogram(wav).T 142 | except Exception as e: 143 | continue 144 | 145 | mel = self.crop_audio_window(orig_mel.copy(), img_name) 146 | 147 | if (mel.shape[0] != syncnet_mel_step_size): 148 | continue 149 | 150 | indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name) 151 | if indiv_mels is None: continue 152 | 153 | window = self.prepare_window(window) 154 | y = window.copy() 155 | window[:, :, window.shape[2]//2:] = 0. 
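# y keeps the untouched ground-truth window; the input window has its lower half
# (rows H//2 and below, after the (3, T, H, W) transpose in prepare_window) zeroed,
# so the generator must inpaint the mouth region from audio alone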
156 | 157 | wrong_window = self.prepare_window(wrong_window) 158 | x = np.concatenate([window, wrong_window], axis=0) 159 | 160 | x = torch.FloatTensor(x) 161 | mel = torch.FloatTensor(mel.T).unsqueeze(0) 162 | indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1) 163 | y = torch.FloatTensor(y) 164 | return x, indiv_mels, mel, y 165 | 166 | def save_sample_images(x, g, gt, global_step, checkpoint_dir): 167 | x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8) 168 | g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8) 169 | gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8) 170 | 171 | refs, inps = x[..., 3:], x[..., :3] 172 | folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step)) 173 | if not os.path.exists(folder): os.mkdir(folder) 174 | collage = np.concatenate((refs, inps, g, gt), axis=-2) 175 | for batch_idx, c in enumerate(collage): 176 | for t in range(len(c)): 177 | cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t]) 178 | 179 | logloss = nn.BCELoss() 180 | def cosine_loss(a, v, y): 181 | d = nn.functional.cosine_similarity(a, v) 182 | loss = logloss(d.unsqueeze(1), y) 183 | 184 | return loss 185 | 186 | device = torch.device("cuda" if use_cuda else "cpu") 187 | syncnet = SyncNet().to(device) 188 | for p in syncnet.parameters(): 189 | p.requires_grad = False 190 | 191 | recon_loss = nn.L1Loss() 192 | def get_sync_loss(mel, g): 193 | g = g[:, :, :, g.size(3)//2:] 194 | g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1) 195 | # B, 3 * T, H//2, W 196 | a, v = syncnet(mel, g) 197 | y = torch.ones(g.size(0), 1).float().to(device) 198 | return cosine_loss(a, v, y) 199 | 200 | def train(device, model, train_data_loader, test_data_loader, optimizer, 201 | checkpoint_dir=None, checkpoint_interval=None, nepochs=None): 202 | 203 | global global_step, global_epoch 204 | resumed_step = global_step 205 | 206 | while global_epoch < nepochs: 207 | print('Starting Epoch: {}'.format(global_epoch)) 208 | running_sync_loss, running_l1_loss = 0., 0. 209 | prog_bar = tqdm(enumerate(train_data_loader)) 210 | for step, (x, indiv_mels, mel, gt) in prog_bar: 211 | model.train() 212 | optimizer.zero_grad() 213 | 214 | # Move data to CUDA device 215 | x = x.to(device) 216 | mel = mel.to(device) 217 | indiv_mels = indiv_mels.to(device) 218 | gt = gt.to(device) 219 | 220 | g = model(indiv_mels, x) 221 | 222 | if hparams.syncnet_wt > 0.: 223 | sync_loss = get_sync_loss(mel, g) 224 | else: 225 | sync_loss = 0. 226 | 227 | l1loss = recon_loss(g, gt) 228 | 229 | loss = hparams.syncnet_wt * sync_loss + (1 - hparams.syncnet_wt) * l1loss 230 | loss.backward() 231 | optimizer.step() 232 | 233 | if global_step % checkpoint_interval == 0: 234 | save_sample_images(x, g, gt, global_step, checkpoint_dir) 235 | 236 | global_step += 1 237 | cur_session_steps = global_step - resumed_step 238 | 239 | running_l1_loss += l1loss.item() 240 | if hparams.syncnet_wt > 0.: 241 | running_sync_loss += sync_loss.item() 242 | else: 243 | running_sync_loss += 0. 
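# bookkeeping below: a checkpoint is written every checkpoint_interval steps and
# eval runs every eval_interval steps; once the averaged eval sync loss drops
# below 0.75, the sync-loss weight is set via hparams.set_hparam('syncnet_wt', 0.01)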
244 | 245 | if global_step == 1 or global_step % checkpoint_interval == 0: 246 | save_checkpoint( 247 | model, optimizer, global_step, checkpoint_dir, global_epoch) 248 | 249 | if global_step == 1 or global_step % hparams.eval_interval == 0: 250 | with torch.no_grad(): 251 | average_sync_loss = eval_model(test_data_loader, global_step, device, model, checkpoint_dir) 252 | 253 | if average_sync_loss < .75: 254 | hparams.set_hparam('syncnet_wt', 0.01) # without image GAN a lesser weight is sufficient 255 | 256 | prog_bar.set_description('L1: {}, Sync Loss: {}'.format(running_l1_loss / (step + 1), 257 | running_sync_loss / (step + 1))) 258 | 259 | global_epoch += 1 260 | 261 | 262 | def eval_model(test_data_loader, global_step, device, model, checkpoint_dir): 263 | eval_steps = 700 264 | print('Evaluating for {} steps'.format(eval_steps)) 265 | sync_losses, recon_losses = [], [] 266 | step = 0 267 | while 1: 268 | for x, indiv_mels, mel, gt in test_data_loader: 269 | step += 1 270 | model.eval() 271 | 272 | # Move data to CUDA device 273 | x = x.to(device) 274 | gt = gt.to(device) 275 | indiv_mels = indiv_mels.to(device) 276 | mel = mel.to(device) 277 | 278 | g = model(indiv_mels, x) 279 | 280 | sync_loss = get_sync_loss(mel, g) 281 | l1loss = recon_loss(g, gt) 282 | 283 | sync_losses.append(sync_loss.item()) 284 | recon_losses.append(l1loss.item()) 285 | 286 | if step > eval_steps: 287 | averaged_sync_loss = sum(sync_losses) / len(sync_losses) 288 | averaged_recon_loss = sum(recon_losses) / len(recon_losses) 289 | 290 | print('L1: {}, Sync loss: {}'.format(averaged_recon_loss, averaged_sync_loss)) 291 | 292 | return averaged_sync_loss 293 | 294 | def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch): 295 | 296 | checkpoint_path = join( 297 | checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step)) 298 | optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None 299 | torch.save({ 300 | "state_dict": model.state_dict(), 301 | "optimizer": optimizer_state, 302 | "global_step": step, 303 | "global_epoch": epoch, 304 | }, checkpoint_path) 305 | print("Saved checkpoint:", checkpoint_path) 306 | 307 | 308 | def _load(checkpoint_path): 309 | if use_cuda: 310 | checkpoint = torch.load(checkpoint_path) 311 | else: 312 | checkpoint = torch.load(checkpoint_path, 313 | map_location=lambda storage, loc: storage) 314 | return checkpoint 315 | 316 | def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True): 317 | global global_step 318 | global global_epoch 319 | 320 | print("Load checkpoint from: {}".format(path)) 321 | checkpoint = _load(path) 322 | s = checkpoint["state_dict"] 323 | new_s = {} 324 | for k, v in s.items(): 325 | new_s[k.replace('module.', '')] = v 326 | model.load_state_dict(new_s) 327 | if not reset_optimizer: 328 | optimizer_state = checkpoint["optimizer"] 329 | if optimizer_state is not None: 330 | print("Load optimizer state from {}".format(path)) 331 | optimizer.load_state_dict(checkpoint["optimizer"]) 332 | if overwrite_global_states: 333 | global_step = checkpoint["global_step"] 334 | global_epoch = checkpoint["global_epoch"] 335 | 336 | return model 337 | 338 | if __name__ == "__main__": 339 | checkpoint_dir = args.checkpoint_dir 340 | 341 | # Dataset and Dataloader setup 342 | train_dataset = Dataset('train') 343 | test_dataset = Dataset('val') 344 | 345 | train_data_loader = data_utils.DataLoader( 346 | train_dataset, batch_size=hparams.batch_size, shuffle=True, 347 | 
num_workers=hparams.num_workers) 348 | 349 | test_data_loader = data_utils.DataLoader( 350 | test_dataset, batch_size=hparams.batch_size, 351 | num_workers=4) 352 | 353 | device = torch.device("cuda" if use_cuda else "cpu") 354 | 355 | # Model 356 | model = Wav2Lip().to(device) 357 | print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad))) 358 | 359 | optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], 360 | lr=hparams.initial_learning_rate) 361 | 362 | if args.checkpoint_path is not None: 363 | load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False) 364 | 365 | load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, overwrite_global_states=False) 366 | 367 | if not os.path.exists(checkpoint_dir): 368 | os.mkdir(checkpoint_dir) 369 | 370 | # Train! 371 | train(device, model, train_data_loader, test_data_loader, optimizer, 372 | checkpoint_dir=checkpoint_dir, 373 | checkpoint_interval=hparams.checkpoint_interval, 374 | nepochs=hparams.nepochs) 375 | --------------------------------------------------------------------------------
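A minimal sketch of how the expert sync loss used above is computed. This is illustrative only and not a file from the repository: the batch size B, the dummy embeddings, and the print are assumptions standing in for real data, and the ReLU/normalize lines exist only to keep the toy cosine similarities inside [0, 1] for BCELoss. In color_syncnet_train_vq.py the (B, 512) embeddings would instead come from the VQ SyncNet in models/syncnet_vq.py, which already returns L2-normalised vectors.

import torch
import torch.nn as nn
import torch.nn.functional as F

logloss = nn.BCELoss()

def cosine_loss(a, v, y):
    # same formulation as in wav2lip_train.py: BCE over the cosine similarity
    # of the audio and face embeddings
    d = F.cosine_similarity(a, v)
    return logloss(d.unsqueeze(1), y)

# shapes used by the VQ sync expert:
#   audio: (B, 1, 80, 16) mel window            -> audio_encoder -> (B, 512)
#   faces: (B, 5, 3, 256, 256) frame window     -> VQ encode -> (B*5, 256, 16, 16)
#          -> regrouped to (B, 5*256, 16, 16)   -> face_encoder -> (B, 512)
B = 2
audio_embedding = F.normalize(torch.relu(torch.randn(B, 512)), p=2, dim=1)
face_embedding = F.normalize(torch.relu(torch.randn(B, 512)), p=2, dim=1)

y = torch.ones(B, 1)  # 1.0 marks an in-sync (audio, face-window) pair
loss = cosine_loss(audio_embedding, face_embedding, y)
print('toy sync loss:', loss.item())

Note that wav2lip_train.py only ever calls cosine_loss with y = 1: there, get_sync_loss is a penalty that pushes the generated mouth crops toward the audio, with the expert SyncNet kept frozen (all parameters have requires_grad = False).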