├── .gitignore ├── .gitattributes ├── packages.txt ├── wav2lip ├── face_detection │ ├── detection │ │ ├── __init__.py │ │ ├── sfd │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── bbox.cpython-311.pyc │ │ │ │ ├── detect.cpython-311.pyc │ │ │ │ ├── __init__.cpython-311.pyc │ │ │ │ ├── net_s3fd.cpython-311.pyc │ │ │ │ └── sfd_detector.cpython-311.pyc │ │ │ ├── sfd_detector.py │ │ │ ├── detect.py │ │ │ ├── bbox.py │ │ │ └── net_s3fd.py │ │ ├── __pycache__ │ │ │ ├── core.cpython-311.pyc │ │ │ └── __init__.cpython-311.pyc │ │ └── core.py │ ├── README.md │ ├── __init__.py │ ├── api.py │ ├── models.py │ └── utils.py ├── results │ └── README.md ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── conv.cpython-311.pyc │ │ ├── __init__.cpython-311.pyc │ │ ├── syncnet.cpython-311.pyc │ │ └── wav2lip.cpython-311.pyc │ ├── conv.py │ ├── syncnet.py │ └── wav2lip.py ├── temp │ └── README.md ├── hparams.py ├── audio.py └── inference.py ├── .streamlit └── config.toml ├── avatars_images ├── avatar1.jpg ├── avatar2.jpg └── avatar3.png ├── README.md ├── requirements.txt └── app.py /.gitignore: -------------------------------------------------------------------------------- 1 | sound.wav 2 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /packages.txt: -------------------------------------------------------------------------------- 1 | python3-opencv 2 | libgl1-mesa-dev 3 | ffmpeg 4 | -------------------------------------------------------------------------------- /wav2lip/face_detection/detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import FaceDetector -------------------------------------------------------------------------------- /wav2lip/results/README.md: -------------------------------------------------------------------------------- 1 | Generated results will be placed in this folder by default. -------------------------------------------------------------------------------- /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [theme] 2 | base="dark" 3 | primaryColor="#865bf1" 4 | font="monospace" -------------------------------------------------------------------------------- /wav2lip/face_detection/detection/sfd/__init__.py: -------------------------------------------------------------------------------- 1 | from .sfd_detector import SFDDetector as FaceDetector -------------------------------------------------------------------------------- /wav2lip/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .wav2lip import Wav2Lip, Wav2Lip_disc_qual 2 | from .syncnet import SyncNet_color -------------------------------------------------------------------------------- /wav2lip/temp/README.md: -------------------------------------------------------------------------------- 1 | Temporary files at the time of inference/testing will be saved here. You can ignore them. 
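The wav2lip/models/__init__.py listed above re-exports Wav2Lip, Wav2Lip_disc_qual and SyncNet_color. Below is a minimal sketch, not part of the repository, of how the Wav2Lip export is typically consumed; it mirrors the checkpoint-loading pattern in app.py further down, and the checkpoint path and CPU device are assumptions (checkpoints saved with nn.DataParallel prefix their keys with "module.", which app.py strips before loading):

import torch
from wav2lip.models import Wav2Lip

device = "cpu"  # assumption: CPU inference, as in app.py
model = Wav2Lip()

# Load the GAN checkpoint and strip the DataParallel "module." prefix,
# the same remapping app.py's load_model() performs.
checkpoint = torch.load("wav2lip_checkpoints/wav2lip_gan.pth",
                        map_location=lambda storage, loc: storage)
state_dict = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}
model.load_state_dict(state_dict)
model = model.to(device).eval()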
-------------------------------------------------------------------------------- /avatars_images/avatar1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashigharti/ai-lip-sync-app/main/avatars_images/avatar1.jpg -------------------------------------------------------------------------------- /avatars_images/avatar2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashigharti/ai-lip-sync-app/main/avatars_images/avatar2.jpg -------------------------------------------------------------------------------- /avatars_images/avatar3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashigharti/ai-lip-sync-app/main/avatars_images/avatar3.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AI Lip Sync 2 | 3 | This library is forked from https://github.com/Aml-Hassan-Abd-El-hamid/ai-lip-sync-app 4 | -------------------------------------------------------------------------------- /wav2lip/models/__pycache__/conv.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashigharti/ai-lip-sync-app/main/wav2lip/models/__pycache__/conv.cpython-311.pyc -------------------------------------------------------------------------------- /wav2lip/models/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashigharti/ai-lip-sync-app/main/wav2lip/models/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /wav2lip/models/__pycache__/syncnet.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashigharti/ai-lip-sync-app/main/wav2lip/models/__pycache__/syncnet.cpython-311.pyc -------------------------------------------------------------------------------- /wav2lip/models/__pycache__/wav2lip.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashigharti/ai-lip-sync-app/main/wav2lip/models/__pycache__/wav2lip.cpython-311.pyc -------------------------------------------------------------------------------- /wav2lip/face_detection/detection/__pycache__/core.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashigharti/ai-lip-sync-app/main/wav2lip/face_detection/detection/__pycache__/core.cpython-311.pyc -------------------------------------------------------------------------------- /wav2lip/face_detection/detection/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashigharti/ai-lip-sync-app/main/wav2lip/face_detection/detection/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /wav2lip/face_detection/detection/sfd/__pycache__/bbox.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shashigharti/ai-lip-sync-app/main/wav2lip/face_detection/detection/sfd/__pycache__/bbox.cpython-311.pyc -------------------------------------------------------------------------------- /wav2lip/face_detection/detection/sfd/__pycache__/detect.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashigharti/ai-lip-sync-app/main/wav2lip/face_detection/detection/sfd/__pycache__/detect.cpython-311.pyc -------------------------------------------------------------------------------- /wav2lip/face_detection/detection/sfd/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashigharti/ai-lip-sync-app/main/wav2lip/face_detection/detection/sfd/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /wav2lip/face_detection/detection/sfd/__pycache__/net_s3fd.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashigharti/ai-lip-sync-app/main/wav2lip/face_detection/detection/sfd/__pycache__/net_s3fd.cpython-311.pyc -------------------------------------------------------------------------------- /wav2lip/face_detection/detection/sfd/__pycache__/sfd_detector.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shashigharti/ai-lip-sync-app/main/wav2lip/face_detection/detection/sfd/__pycache__/sfd_detector.cpython-311.pyc -------------------------------------------------------------------------------- /wav2lip/face_detection/README.md: -------------------------------------------------------------------------------- 1 | The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time. 
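A hedged usage sketch of this batched detector, driven through the FaceAlignment wrapper that api.py (further below) builds around the SFD detector; the frame paths are placeholders, CPU is assumed, and the import path follows the same style app.py uses for other wav2lip modules:

import cv2
import numpy as np
from wav2lip import face_detection

# get_detections_for_batch() expects one numpy array of same-sized BGR frames,
# shaped (N, H, W, 3), and returns one (x1, y1, x2, y2) box or None per frame.
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                        flip_input=False, device="cpu")
frames = np.asarray([cv2.imread(p) for p in ["frame0.jpg", "frame1.jpg"]])  # placeholder paths
for i, rect in enumerate(detector.get_detections_for_batch(frames)):
    if rect is None:
        print(f"frame {i}: no face detected")
    else:
        x1, y1, x2, y2 = rect
        print(f"frame {i}: face box", x1, y1, x2, y2)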
-------------------------------------------------------------------------------- /wav2lip/face_detection/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | __author__ = """Adrian Bulat""" 4 | __email__ = 'adrian.bulat@nottingham.ac.uk' 5 | __version__ = '1.0.1' 6 | 7 | from .api import FaceAlignment, LandmarksType, NetworkSize 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.26.3 2 | scipy==1.12.0 3 | iou==0.1.0 4 | librosa==0.10.1 5 | opencv_contrib_python==4.9.0.80 6 | streamlit >= 1.9.2 7 | streamlit_image_select==0.6.0 8 | streamlit_mic_recorder==0.0.4 9 | torch==2.1.2 10 | tqdm==4.64.1 11 | gdown -------------------------------------------------------------------------------- /wav2lip/models/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | class Conv2d(nn.Module): 6 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | self.conv_block = nn.Sequential( 9 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 10 | nn.BatchNorm2d(cout) 11 | ) 12 | self.act = nn.ReLU() 13 | self.residual = residual 14 | 15 | def forward(self, x): 16 | out = self.conv_block(x) 17 | if self.residual: 18 | out += x 19 | return self.act(out) 20 | 21 | class nonorm_Conv2d(nn.Module): 22 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self.conv_block = nn.Sequential( 25 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 26 | ) 27 | self.act = nn.LeakyReLU(0.01, inplace=True) 28 | 29 | def forward(self, x): 30 | out = self.conv_block(x) 31 | return self.act(out) 32 | 33 | class Conv2dTranspose(nn.Module): 34 | def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs): 35 | super().__init__(*args, **kwargs) 36 | self.conv_block = nn.Sequential( 37 | nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding), 38 | nn.BatchNorm2d(cout) 39 | ) 40 | self.act = nn.ReLU() 41 | 42 | def forward(self, x): 43 | out = self.conv_block(x) 44 | return self.act(out) 45 | -------------------------------------------------------------------------------- /wav2lip/face_detection/detection/sfd/sfd_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from torch.utils.model_zoo import load_url 4 | 5 | from ..core import FaceDetector 6 | 7 | from .net_s3fd import s3fd 8 | from .bbox import * 9 | from .detect import * 10 | 11 | models_urls = { 12 | 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth', 13 | } 14 | 15 | 16 | class SFDDetector(FaceDetector): 17 | def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False): 18 | super(SFDDetector, self).__init__(device, verbose) 19 | 20 | # Initialise the face detector 21 | if not os.path.isfile(path_to_detector): 22 | model_weights = load_url(models_urls['s3fd']) 23 | else: 24 | model_weights = torch.load(path_to_detector) 25 | 26 | self.face_detector = s3fd() 27 | self.face_detector.load_state_dict(model_weights) 28 | 
self.face_detector.to(device) 29 | self.face_detector.eval() 30 | 31 | def detect_from_image(self, tensor_or_path): 32 | image = self.tensor_or_path_to_ndarray(tensor_or_path) 33 | 34 | bboxlist = detect(self.face_detector, image, device=self.device) 35 | keep = nms(bboxlist, 0.3) 36 | bboxlist = bboxlist[keep, :] 37 | bboxlist = [x for x in bboxlist if x[-1] > 0.5] 38 | 39 | return bboxlist 40 | 41 | def detect_from_batch(self, images): 42 | bboxlists = batch_detect(self.face_detector, images, device=self.device) 43 | keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])] 44 | bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)] 45 | bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists] 46 | 47 | return bboxlists 48 | 49 | @property 50 | def reference_scale(self): 51 | return 195 52 | 53 | @property 54 | def reference_x_shift(self): 55 | return 0 56 | 57 | @property 58 | def reference_y_shift(self): 59 | return 0 60 | -------------------------------------------------------------------------------- /wav2lip/face_detection/api.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import torch 4 | from torch.utils.model_zoo import load_url 5 | from enum import Enum 6 | import numpy as np 7 | import cv2 8 | from .detection import sfd 9 | try: 10 | import urllib.request as request_file 11 | except BaseException: 12 | import urllib as request_file 13 | 14 | from .models import FAN, ResNetDepth 15 | from .utils import * 16 | 17 | 18 | class LandmarksType(Enum): 19 | """Enum class defining the type of landmarks to detect. 20 | 21 | ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face 22 | ``_2halfD`` - this points represent the projection of the 3D points into 3D 23 | ``_3D`` - detect the points ``(x,y,z)``` in a 3D space 24 | 25 | """ 26 | _2D = 1 27 | _2halfD = 2 28 | _3D = 3 29 | 30 | 31 | class NetworkSize(Enum): 32 | # TINY = 1 33 | # SMALL = 2 34 | # MEDIUM = 3 35 | LARGE = 4 36 | 37 | def __new__(cls, value): 38 | member = object.__new__(cls) 39 | member._value_ = value 40 | return member 41 | 42 | def __int__(self): 43 | return self.value 44 | 45 | ROOT = os.path.dirname(os.path.abspath(__file__)) 46 | 47 | class FaceAlignment: 48 | def __init__(self, landmarks_type, network_size=NetworkSize.LARGE, 49 | device='cuda', flip_input=False, face_detector='sfd', verbose=False): 50 | self.device = device 51 | self.flip_input = flip_input 52 | self.landmarks_type = landmarks_type 53 | self.verbose = verbose 54 | 55 | network_size = int(network_size) 56 | 57 | if 'cuda' in device: 58 | torch.backends.cudnn.benchmark = True 59 | 60 | # Get the face detector 61 | #face_detector_module = __import__('from .detection. 
import' + face_detector, 62 | # globals(), locals(), [face_detector], 0) 63 | #self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose) 64 | self.face_detector = sfd.FaceDetector(device=device, verbose=verbose) 65 | 66 | def get_detections_for_batch(self, images): 67 | images = images[..., ::-1] 68 | detected_faces = self.face_detector.detect_from_batch(images.copy()) 69 | results = [] 70 | 71 | for i, d in enumerate(detected_faces): 72 | if len(d) == 0: 73 | results.append(None) 74 | continue 75 | d = d[0] 76 | d = np.clip(d, 0, None) 77 | 78 | x1, y1, x2, y2 = map(int, d[:-1]) 79 | results.append((x1, y1, x2, y2)) 80 | 81 | return results -------------------------------------------------------------------------------- /wav2lip/models/syncnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from .conv import Conv2d 6 | 7 | class SyncNet_color(nn.Module): 8 | def __init__(self): 9 | super(SyncNet_color, self).__init__() 10 | 11 | self.face_encoder = nn.Sequential( 12 | Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3), 13 | 14 | Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1), 15 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 16 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 17 | 18 | Conv2d(64, 128, kernel_size=3, stride=2, padding=1), 19 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 20 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 21 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 22 | 23 | Conv2d(128, 256, kernel_size=3, stride=2, padding=1), 24 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 25 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 26 | 27 | Conv2d(256, 512, kernel_size=3, stride=2, padding=1), 28 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 29 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 30 | 31 | Conv2d(512, 512, kernel_size=3, stride=2, padding=1), 32 | Conv2d(512, 512, kernel_size=3, stride=1, padding=0), 33 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 34 | 35 | self.audio_encoder = nn.Sequential( 36 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 37 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 38 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 39 | 40 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 41 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 42 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 43 | 44 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1), 45 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 46 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 47 | 48 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 49 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 50 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 51 | 52 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0), 53 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 54 | 55 | def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T) 56 | face_embedding = self.face_encoder(face_sequences) 57 | audio_embedding = self.audio_encoder(audio_sequences) 58 | 59 | audio_embedding = 
audio_embedding.view(audio_embedding.size(0), -1) 60 | face_embedding = face_embedding.view(face_embedding.size(0), -1) 61 | 62 | audio_embedding = F.normalize(audio_embedding, p=2, dim=1) 63 | face_embedding = F.normalize(face_embedding, p=2, dim=1) 64 | 65 | 66 | return audio_embedding, face_embedding 67 | -------------------------------------------------------------------------------- /wav2lip/hparams.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import os 3 | 4 | def get_image_list(data_root, split): 5 | filelist = [] 6 | 7 | with open('filelists/{}.txt'.format(split)) as f: 8 | for line in f: 9 | line = line.strip() 10 | if ' ' in line: line = line.split()[0] 11 | filelist.append(os.path.join(data_root, line)) 12 | 13 | return filelist 14 | 15 | class HParams: 16 | def __init__(self, **kwargs): 17 | self.data = {} 18 | 19 | for key, value in kwargs.items(): 20 | self.data[key] = value 21 | 22 | def __getattr__(self, key): 23 | if key not in self.data: 24 | raise AttributeError("'HParams' object has no attribute %s" % key) 25 | return self.data[key] 26 | 27 | def set_hparam(self, key, value): 28 | self.data[key] = value 29 | 30 | 31 | # Default hyperparameters 32 | hparams = HParams( 33 | num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality 34 | # network 35 | rescale=True, # Whether to rescale audio prior to preprocessing 36 | rescaling_max=0.9, # Rescaling value 37 | 38 | # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction 39 | # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder 40 | # Does not work if n_ffit is not multiple of hop_size!! 41 | use_lws=False, 42 | 43 | n_fft=800, # Extra window size is filled with 0 paddings to match this parameter 44 | hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) 45 | win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) 46 | sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) 47 | 48 | frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) 49 | 50 | # Mel and Linear spectrograms normalization/scaling and clipping 51 | signal_normalization=True, 52 | # Whether to normalize mel spectrograms to some predefined range (following below parameters) 53 | allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True 54 | symmetric_mels=True, 55 | # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, 56 | # faster and cleaner convergence) 57 | max_abs_value=4., 58 | # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not 59 | # be too big to avoid gradient explosion, 60 | # not too small for fast convergence) 61 | # Contribution by @begeekmyfriend 62 | # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude 63 | # levels. Also allows for better G&L phase reconstruction) 64 | preemphasize=True, # whether to apply filter 65 | preemphasis=0.97, # filter coefficient. 66 | 67 | # Limits 68 | min_level_db=-100, 69 | ref_level_db=20, 70 | fmin=55, 71 | # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To 72 | # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) 73 | fmax=7600, # To be increased/reduced depending on data. 
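# Worked example of the timing parameters above (illustrative; assumes the 16000 Hz default):
#   hop_size = 200 samples  -> 200 / 16000 = 0.0125 s = 12.5 ms per frame shift
#   win_size = 800 samples  -> 800 / 16000 = 0.05 s = 50 ms analysis window
# If frame_shift_ms=12.5 were set instead of hop_size, audio.get_hop_size()
# would recover the same value: int(12.5 / 1000 * 16000) = 200.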
74 | 75 | ###################### Our training parameters ################################# 76 | img_size=96, 77 | fps=25, 78 | 79 | batch_size=16, 80 | initial_learning_rate=1e-4, 81 | nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs 82 | num_workers=16, 83 | checkpoint_interval=3000, 84 | eval_interval=3000, 85 | save_optimizer_state=True, 86 | 87 | syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. 88 | syncnet_batch_size=64, 89 | syncnet_lr=1e-4, 90 | syncnet_eval_interval=10000, 91 | syncnet_checkpoint_interval=10000, 92 | 93 | disc_wt=0.07, 94 | disc_initial_learning_rate=1e-4, 95 | ) 96 | 97 | 98 | def hparams_debug_string(): 99 | values = hparams.values() 100 | hp = [" %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"] 101 | return "Hyperparameters:\n" + "\n".join(hp) 102 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import streamlit as st 3 | from streamlit_image_select import image_select 4 | import torch 5 | from streamlit_mic_recorder import mic_recorder 6 | from wav2lip import inference 7 | from wav2lip.models import Wav2Lip 8 | import gdown 9 | 10 | device='cpu' 11 | #@st.cache_data is used to only load the model once 12 | #@st.cache_data 13 | @st.cache_resource 14 | def load_model(path): 15 | st.write("Please wait for the model to be loaded or it will cause an error") 16 | wav2lip_checkpoints_url = "https://drive.google.com/drive/folders/1Sy5SHRmI3zgg2RJaOttNsN3iJS9VVkbg?usp=sharing" 17 | if not os.path.exists(path): 18 | gdown.download_folder(wav2lip_checkpoints_url, quiet=True, use_cookies=False) 19 | st.write("Please wait") 20 | model = Wav2Lip() 21 | print("Load checkpoint from: {}".format(path)) 22 | checkpoint = torch.load(path,map_location=lambda storage, loc: storage) 23 | s = checkpoint["state_dict"] 24 | new_s = {} 25 | for k, v in s.items(): 26 | new_s[k.replace('module.', '')] = v 27 | model.load_state_dict(new_s) 28 | model = model.to(device) 29 | st.write("model is loaded!") 30 | return model.eval() 31 | @st.cache_resource 32 | def load_avatar_videos_for_slow_animation(path): 33 | avatar_videos_url = "https://drive.google.com/drive/folders/1h9pkU5wenrS2vmKqXBfFmrg-1hYw5s4q?usp=sharing" 34 | if not os.path.exists(path): 35 | gdown.download_folder(avatar_videos_url, quiet=True, use_cookies=False) 36 | 37 | 38 | 39 | image_video_map = { 40 | "avatars_images/avatar1.jpg":"avatars_videos/avatar1.mp4", 41 | "avatars_images/avatar2.jpg":"avatars_videos/avatar2.mp4", 42 | "avatars_images/avatar3.png":"avatars_videos/avatar3.mp4" 43 | } 44 | def streamlit_look(): 45 | """ 46 | Modest front-end code:) 47 | """ 48 | data={} 49 | st.title("Welcome to AI Lip Sync :)") 50 | st.write("Please choose your avatar from the following options:") 51 | avatar_img = image_select("", 52 | ["avatars_images/avatar1.jpg", 53 | "avatars_images/avatar2.jpg", 54 | "avatars_images/avatar3.png", 55 | ]) 56 | data["imge_path"] = avatar_img 57 | audio=mic_recorder( 58 | start_prompt="Start recording", 59 | stop_prompt="Stop recording", 60 | just_once=False, 61 | use_container_width=False, 62 | callback=None, 63 | args=(), 64 | kwargs={}, 65 | key=None) 66 | if audio: 67 | st.audio(audio["bytes"]) 68 | data["audio"]= audio["bytes"] 69 | return data 70 | 71 | def main(): 72 | data=streamlit_look() 73 
| st.write("Don't forget to save the record or there will be an error!") 74 | save_record = st.button("save record") 75 | st.write("With fast animation, only the lips of the avatar will move, and a record of about 30 seconds will probably take less than a minute. With the slower animation choice, the full face of the avatar will move, and a record of about 30 seconds will take about 30 minutes to get ready.") 76 | model = load_model("wav2lip_checkpoints/wav2lip_gan.pth") 77 | fast_animate = st.button("fast animate") 78 | slower_animate = st.button("slower animate") 79 | if save_record: 80 | if os.path.exists('record.wav'): 81 | os.remove('record.wav') 82 | with open('record.wav', mode='bx') as f: 83 | f.write(data["audio"]) 84 | st.write("record saved!") 85 | if fast_animate: 86 | inference.main(data["imge_path"],"record.wav",model) 87 | if os.path.exists('wav2lip/results/result_voice.mp4'): 88 | st.video('wav2lip/results/result_voice.mp4') 89 | if slower_animate: 90 | load_avatar_videos_for_slow_animation("avatars_videos") 91 | inference.main(image_video_map[data["imge_path"]],"record.wav",model) 92 | if os.path.exists('wav2lip/results/result_voice.mp4'): 93 | st.video('wav2lip/results/result_voice.mp4') 94 | 95 | if __name__ == "__main__": 96 | main() -------------------------------------------------------------------------------- /wav2lip/face_detection/detection/sfd/detect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import os 5 | import sys 6 | import cv2 7 | import random 8 | import datetime 9 | import math 10 | import argparse 11 | import numpy as np 12 | 13 | import scipy.io as sio 14 | import zipfile 15 | from .net_s3fd import s3fd 16 | from .bbox import * 17 | 18 | 19 | def detect(net, img, device): 20 | img = img - np.array([104, 117, 123]) 21 | img = img.transpose(2, 0, 1) 22 | img = img.reshape((1,) + img.shape) 23 | 24 | if 'cuda' in device: 25 | torch.backends.cudnn.benchmark = True 26 | 27 | img = torch.from_numpy(img).float().to(device) 28 | BB, CC, HH, WW = img.size() 29 | with torch.no_grad(): 30 | olist = net(img) 31 | 32 | bboxlist = [] 33 | for i in range(len(olist) // 2): 34 | olist[i * 2] = F.softmax(olist[i * 2], dim=1) 35 | olist = [oelem.data.cpu() for oelem in olist] 36 | for i in range(len(olist) // 2): 37 | ocls, oreg = olist[i * 2], olist[i * 2 + 1] 38 | FB, FC, FH, FW = ocls.size() # feature map size 39 | stride = 2**(i + 2) # 4,8,16,32,64,128 40 | anchor = stride * 4 41 | poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) 42 | for Iindex, hindex, windex in poss: 43 | axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride 44 | score = ocls[0, 1, hindex, windex] 45 | loc = oreg[0, :, hindex, windex].contiguous().view(1, 4) 46 | priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]) 47 | variances = [0.1, 0.2] 48 | box = decode(loc, priors, variances) 49 | x1, y1, x2, y2 = box[0] * 1.0 50 | # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1) 51 | bboxlist.append([x1, y1, x2, y2, score]) 52 | bboxlist = np.array(bboxlist) 53 | if 0 == len(bboxlist): 54 | bboxlist = np.zeros((1, 5)) 55 | 56 | return bboxlist 57 | 58 | def batch_detect(net, imgs, device): 59 | imgs = imgs - np.array([104, 117, 123]) 60 | imgs = imgs.transpose(0, 3, 1, 2) 61 | 62 | if 'cuda' in device: 63 | torch.backends.cudnn.benchmark = True 64 | 65 | imgs = torch.from_numpy(imgs).float().to(device) 66 | BB, CC, HH, WW = 
imgs.size() 67 | with torch.no_grad(): 68 | olist = net(imgs) 69 | 70 | bboxlist = [] 71 | for i in range(len(olist) // 2): 72 | olist[i * 2] = F.softmax(olist[i * 2], dim=1) 73 | olist = [oelem.data.cpu() for oelem in olist] 74 | for i in range(len(olist) // 2): 75 | ocls, oreg = olist[i * 2], olist[i * 2 + 1] 76 | FB, FC, FH, FW = ocls.size() # feature map size 77 | stride = 2**(i + 2) # 4,8,16,32,64,128 78 | anchor = stride * 4 79 | poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) 80 | for Iindex, hindex, windex in poss: 81 | axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride 82 | score = ocls[:, 1, hindex, windex] 83 | loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4) 84 | priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4) 85 | variances = [0.1, 0.2] 86 | box = batch_decode(loc, priors, variances) 87 | box = box[:, 0] * 1.0 88 | # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1) 89 | bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy()) 90 | bboxlist = np.array(bboxlist) 91 | if 0 == len(bboxlist): 92 | bboxlist = np.zeros((1, BB, 5)) 93 | 94 | return bboxlist 95 | 96 | def flip_detect(net, img, device): 97 | img = cv2.flip(img, 1) 98 | b = detect(net, img, device) 99 | 100 | bboxlist = np.zeros(b.shape) 101 | bboxlist[:, 0] = img.shape[1] - b[:, 2] 102 | bboxlist[:, 1] = b[:, 1] 103 | bboxlist[:, 2] = img.shape[1] - b[:, 0] 104 | bboxlist[:, 3] = b[:, 3] 105 | bboxlist[:, 4] = b[:, 4] 106 | return bboxlist 107 | 108 | 109 | def pts_to_bb(pts): 110 | min_x, min_y = np.min(pts, axis=0) 111 | max_x, max_y = np.max(pts, axis=0) 112 | return np.array([min_x, min_y, max_x, max_y]) 113 | -------------------------------------------------------------------------------- /wav2lip/face_detection/detection/sfd/bbox.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import cv2 5 | import random 6 | import datetime 7 | import time 8 | import math 9 | import argparse 10 | import numpy as np 11 | import torch 12 | 13 | try: 14 | from iou import IOU 15 | except BaseException: 16 | # IOU cython speedup 10x 17 | def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2): 18 | sa = abs((ax2 - ax1) * (ay2 - ay1)) 19 | sb = abs((bx2 - bx1) * (by2 - by1)) 20 | x1, y1 = max(ax1, bx1), max(ay1, by1) 21 | x2, y2 = min(ax2, bx2), min(ay2, by2) 22 | w = x2 - x1 23 | h = y2 - y1 24 | if w < 0 or h < 0: 25 | return 0.0 26 | else: 27 | return 1.0 * w * h / (sa + sb - w * h) 28 | 29 | 30 | def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh): 31 | xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1 32 | dx, dy = (xc - axc) / aww, (yc - ayc) / ahh 33 | dw, dh = math.log(ww / aww), math.log(hh / ahh) 34 | return dx, dy, dw, dh 35 | 36 | 37 | def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh): 38 | xc, yc = dx * aww + axc, dy * ahh + ayc 39 | ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh 40 | x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2 41 | return x1, y1, x2, y2 42 | 43 | 44 | def nms(dets, thresh): 45 | if 0 == len(dets): 46 | return [] 47 | x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4] 48 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 49 | order = scores.argsort()[::-1] 50 | 51 | keep = [] 52 | while order.size > 0: 53 | i = order[0] 54 | keep.append(i) 55 | xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]]) 56 | xx2, 
yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]]) 57 | 58 | w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1) 59 | ovr = w * h / (areas[i] + areas[order[1:]] - w * h) 60 | 61 | inds = np.where(ovr <= thresh)[0] 62 | order = order[inds + 1] 63 | 64 | return keep 65 | 66 | 67 | def encode(matched, priors, variances): 68 | """Encode the variances from the priorbox layers into the ground truth boxes 69 | we have matched (based on jaccard overlap) with the prior boxes. 70 | Args: 71 | matched: (tensor) Coords of ground truth for each prior in point-form 72 | Shape: [num_priors, 4]. 73 | priors: (tensor) Prior boxes in center-offset form 74 | Shape: [num_priors,4]. 75 | variances: (list[float]) Variances of priorboxes 76 | Return: 77 | encoded boxes (tensor), Shape: [num_priors, 4] 78 | """ 79 | 80 | # dist b/t match center and prior's center 81 | g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] 82 | # encode variance 83 | g_cxcy /= (variances[0] * priors[:, 2:]) 84 | # match wh / prior wh 85 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] 86 | g_wh = torch.log(g_wh) / variances[1] 87 | # return target for smooth_l1_loss 88 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] 89 | 90 | 91 | def decode(loc, priors, variances): 92 | """Decode locations from predictions using priors to undo 93 | the encoding we did for offset regression at train time. 94 | Args: 95 | loc (tensor): location predictions for loc layers, 96 | Shape: [num_priors,4] 97 | priors (tensor): Prior boxes in center-offset form. 98 | Shape: [num_priors,4]. 99 | variances: (list[float]) Variances of priorboxes 100 | Return: 101 | decoded bounding box predictions 102 | """ 103 | 104 | boxes = torch.cat(( 105 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], 106 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) 107 | boxes[:, :2] -= boxes[:, 2:] / 2 108 | boxes[:, 2:] += boxes[:, :2] 109 | return boxes 110 | 111 | def batch_decode(loc, priors, variances): 112 | """Decode locations from predictions using priors to undo 113 | the encoding we did for offset regression at train time. 114 | Args: 115 | loc (tensor): location predictions for loc layers, 116 | Shape: [num_priors,4] 117 | priors (tensor): Prior boxes in center-offset form. 118 | Shape: [num_priors,4]. 
119 | variances: (list[float]) Variances of priorboxes 120 | Return: 121 | decoded bounding box predictions 122 | """ 123 | 124 | boxes = torch.cat(( 125 | priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:], 126 | priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2) 127 | boxes[:, :, :2] -= boxes[:, :, 2:] / 2 128 | boxes[:, :, 2:] += boxes[:, :, :2] 129 | return boxes 130 | -------------------------------------------------------------------------------- /wav2lip/audio.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import librosa.filters 3 | import numpy as np 4 | # import tensorflow as tf 5 | from scipy import signal 6 | from scipy.io import wavfile 7 | from .hparams import hparams as hp 8 | 9 | def load_wav(path, sr): 10 | return librosa.core.load(path, sr=sr)[0] 11 | 12 | def save_wav(wav, path, sr): 13 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 14 | #proposed by @dsmiller 15 | wavfile.write(path, sr, wav.astype(np.int16)) 16 | 17 | def save_wavenet_wav(wav, path, sr): 18 | librosa.output.write_wav(path, wav, sr=sr) 19 | 20 | def preemphasis(wav, k, preemphasize=True): 21 | if preemphasize: 22 | return signal.lfilter([1, -k], [1], wav) 23 | return wav 24 | 25 | def inv_preemphasis(wav, k, inv_preemphasize=True): 26 | if inv_preemphasize: 27 | return signal.lfilter([1], [1, -k], wav) 28 | return wav 29 | 30 | def get_hop_size(): 31 | hop_size = hp.hop_size 32 | if hop_size is None: 33 | assert hp.frame_shift_ms is not None 34 | hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate) 35 | return hop_size 36 | 37 | def linearspectrogram(wav): 38 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 39 | S = _amp_to_db(np.abs(D)) - hp.ref_level_db 40 | 41 | if hp.signal_normalization: 42 | return _normalize(S) 43 | return S 44 | 45 | def melspectrogram(wav): 46 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 47 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db 48 | 49 | if hp.signal_normalization: 50 | return _normalize(S) 51 | return S 52 | 53 | def _lws_processor(): 54 | import lws 55 | return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech") 56 | 57 | def _stft(y): 58 | if hp.use_lws: 59 | return _lws_processor(hp).stft(y).T 60 | else: 61 | return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size) 62 | 63 | ########################################################## 64 | #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) 
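# Illustrative usage sketch -- not part of audio.py; the import paths are
# assumptions that mirror how app.py imports other wav2lip modules.
from wav2lip import audio
from wav2lip.hparams import hparams as hp

wav = audio.load_wav("record.wav", hp.sample_rate)  # 16 kHz mono, the file app.py saves
mel = audio.melspectrogram(wav)                     # numpy array of shape (num_mels=80, T)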
65 | def num_frames(length, fsize, fshift): 66 | """Compute number of time frames of spectrogram 67 | """ 68 | pad = (fsize - fshift) 69 | if length % fshift == 0: 70 | M = (length + pad * 2 - fsize) // fshift + 1 71 | else: 72 | M = (length + pad * 2 - fsize) // fshift + 2 73 | return M 74 | 75 | 76 | def pad_lr(x, fsize, fshift): 77 | """Compute left and right padding 78 | """ 79 | M = num_frames(len(x), fsize, fshift) 80 | pad = (fsize - fshift) 81 | T = len(x) + 2 * pad 82 | r = (M - 1) * fshift + fsize - T 83 | return pad, pad + r 84 | ########################################################## 85 | #Librosa correct padding 86 | def librosa_pad_lr(x, fsize, fshift): 87 | return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] 88 | 89 | # Conversions 90 | _mel_basis = None 91 | 92 | def _linear_to_mel(spectogram): 93 | global _mel_basis 94 | if _mel_basis is None: 95 | _mel_basis = _build_mel_basis() 96 | return np.dot(_mel_basis, spectogram) 97 | 98 | 99 | def _build_mel_basis(): 100 | assert hp.fmax <= hp.sample_rate // 2 101 | return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, 102 | fmin=hp.fmin, fmax=hp.fmax) 103 | def _amp_to_db(x): 104 | min_level = np.exp(hp.min_level_db / 20 * np.log(10)) 105 | return 20 * np.log10(np.maximum(min_level, x)) 106 | 107 | def _db_to_amp(x): 108 | return np.power(10.0, (x) * 0.05) 109 | 110 | def _normalize(S): 111 | if hp.allow_clipping_in_normalization: 112 | if hp.symmetric_mels: 113 | return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value, 114 | -hp.max_abs_value, hp.max_abs_value) 115 | else: 116 | return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value) 117 | 118 | assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 119 | if hp.symmetric_mels: 120 | return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value 121 | else: 122 | return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) 123 | 124 | def _denormalize(D): 125 | if hp.allow_clipping_in_normalization: 126 | if hp.symmetric_mels: 127 | return (((np.clip(D, -hp.max_abs_value, 128 | hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) 129 | + hp.min_level_db) 130 | else: 131 | return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 132 | 133 | if hp.symmetric_mels: 134 | return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) 135 | else: 136 | return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 137 | -------------------------------------------------------------------------------- /wav2lip/face_detection/detection/core.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import glob 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | import cv2 7 | 8 | 9 | class FaceDetector(object): 10 | """An abstract class representing a face detector. 11 | 12 | Any other face detection implementation must subclass it. All subclasses 13 | must implement ``detect_from_image``, that return a list of detected 14 | bounding boxes. Optionally, for speed considerations detect from path is 15 | recommended. 
16 | """ 17 | 18 | def __init__(self, device, verbose): 19 | self.device = device 20 | self.verbose = verbose 21 | 22 | if verbose: 23 | if 'cpu' in device: 24 | logger = logging.getLogger(__name__) 25 | logger.warning("Detection running on CPU, this may be potentially slow.") 26 | 27 | if 'cpu' not in device and 'cuda' not in device: 28 | if verbose: 29 | logger.error("Expected values for device are: {cpu, cuda} but got: %s", device) 30 | raise ValueError 31 | 32 | def detect_from_image(self, tensor_or_path): 33 | """Detects faces in a given image. 34 | 35 | This function detects the faces present in a provided BGR(usually) 36 | image. The input can be either the image itself or the path to it. 37 | 38 | Arguments: 39 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path 40 | to an image or the image itself. 41 | 42 | Example:: 43 | 44 | >>> path_to_image = 'data/image_01.jpg' 45 | ... detected_faces = detect_from_image(path_to_image) 46 | [A list of bounding boxes (x1, y1, x2, y2)] 47 | >>> image = cv2.imread(path_to_image) 48 | ... detected_faces = detect_from_image(image) 49 | [A list of bounding boxes (x1, y1, x2, y2)] 50 | 51 | """ 52 | raise NotImplementedError 53 | 54 | def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True): 55 | """Detects faces from all the images present in a given directory. 56 | 57 | Arguments: 58 | path {string} -- a string containing a path that points to the folder containing the images 59 | 60 | Keyword Arguments: 61 | extensions {list} -- list of string containing the extensions to be 62 | consider in the following format: ``.extension_name`` (default: 63 | {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the 64 | folder recursively (default: {False}) show_progress_bar {bool} -- 65 | display a progressbar (default: {True}) 66 | 67 | Example: 68 | >>> directory = 'data' 69 | ... detected_faces = detect_from_directory(directory) 70 | {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]} 71 | 72 | """ 73 | if self.verbose: 74 | logger = logging.getLogger(__name__) 75 | 76 | if len(extensions) == 0: 77 | if self.verbose: 78 | logger.error("Expected at list one extension, but none was received.") 79 | raise ValueError 80 | 81 | if self.verbose: 82 | logger.info("Constructing the list of images.") 83 | additional_pattern = '/**/*' if recursive else '/*' 84 | files = [] 85 | for extension in extensions: 86 | files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive)) 87 | 88 | if self.verbose: 89 | logger.info("Finished searching for images. 
%s images found", len(files)) 90 | logger.info("Preparing to run the detection.") 91 | 92 | predictions = {} 93 | for image_path in tqdm(files, disable=not show_progress_bar): 94 | if self.verbose: 95 | logger.info("Running the face detector on image: %s", image_path) 96 | predictions[image_path] = self.detect_from_image(image_path) 97 | 98 | if self.verbose: 99 | logger.info("The detector was successfully run on all %s images", len(files)) 100 | 101 | return predictions 102 | 103 | @property 104 | def reference_scale(self): 105 | raise NotImplementedError 106 | 107 | @property 108 | def reference_x_shift(self): 109 | raise NotImplementedError 110 | 111 | @property 112 | def reference_y_shift(self): 113 | raise NotImplementedError 114 | 115 | @staticmethod 116 | def tensor_or_path_to_ndarray(tensor_or_path, rgb=True): 117 | """Convert path (represented as a string) or torch.tensor to a numpy.ndarray 118 | 119 | Arguments: 120 | tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself 121 | """ 122 | if isinstance(tensor_or_path, str): 123 | return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1] 124 | elif torch.is_tensor(tensor_or_path): 125 | # Call cpu in case its coming from cuda 126 | return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy() 127 | elif isinstance(tensor_or_path, np.ndarray): 128 | return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path 129 | else: 130 | raise TypeError 131 | -------------------------------------------------------------------------------- /wav2lip/face_detection/detection/sfd/net_s3fd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class L2Norm(nn.Module): 7 | def __init__(self, n_channels, scale=1.0): 8 | super(L2Norm, self).__init__() 9 | self.n_channels = n_channels 10 | self.scale = scale 11 | self.eps = 1e-10 12 | self.weight = nn.Parameter(torch.Tensor(self.n_channels)) 13 | self.weight.data *= 0.0 14 | self.weight.data += self.scale 15 | 16 | def forward(self, x): 17 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps 18 | x = x / norm * self.weight.view(1, -1, 1, 1) 19 | return x 20 | 21 | 22 | class s3fd(nn.Module): 23 | def __init__(self): 24 | super(s3fd, self).__init__() 25 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 26 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 27 | 28 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) 29 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 30 | 31 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) 32 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 33 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 34 | 35 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) 36 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 37 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 38 | 39 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 40 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 41 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 42 | 43 | self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3) 44 | self.fc7 = nn.Conv2d(1024, 1024, 
kernel_size=1, stride=1, padding=0) 45 | 46 | self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) 47 | self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) 48 | 49 | self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0) 50 | self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) 51 | 52 | self.conv3_3_norm = L2Norm(256, scale=10) 53 | self.conv4_3_norm = L2Norm(512, scale=8) 54 | self.conv5_3_norm = L2Norm(512, scale=5) 55 | 56 | self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 57 | self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 58 | self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 59 | self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 60 | self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 61 | self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 62 | 63 | self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1) 64 | self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1) 65 | self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 66 | self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) 67 | self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) 68 | self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) 69 | 70 | def forward(self, x): 71 | h = F.relu(self.conv1_1(x)) 72 | h = F.relu(self.conv1_2(h)) 73 | h = F.max_pool2d(h, 2, 2) 74 | 75 | h = F.relu(self.conv2_1(h)) 76 | h = F.relu(self.conv2_2(h)) 77 | h = F.max_pool2d(h, 2, 2) 78 | 79 | h = F.relu(self.conv3_1(h)) 80 | h = F.relu(self.conv3_2(h)) 81 | h = F.relu(self.conv3_3(h)) 82 | f3_3 = h 83 | h = F.max_pool2d(h, 2, 2) 84 | 85 | h = F.relu(self.conv4_1(h)) 86 | h = F.relu(self.conv4_2(h)) 87 | h = F.relu(self.conv4_3(h)) 88 | f4_3 = h 89 | h = F.max_pool2d(h, 2, 2) 90 | 91 | h = F.relu(self.conv5_1(h)) 92 | h = F.relu(self.conv5_2(h)) 93 | h = F.relu(self.conv5_3(h)) 94 | f5_3 = h 95 | h = F.max_pool2d(h, 2, 2) 96 | 97 | h = F.relu(self.fc6(h)) 98 | h = F.relu(self.fc7(h)) 99 | ffc7 = h 100 | h = F.relu(self.conv6_1(h)) 101 | h = F.relu(self.conv6_2(h)) 102 | f6_2 = h 103 | h = F.relu(self.conv7_1(h)) 104 | h = F.relu(self.conv7_2(h)) 105 | f7_2 = h 106 | 107 | f3_3 = self.conv3_3_norm(f3_3) 108 | f4_3 = self.conv4_3_norm(f4_3) 109 | f5_3 = self.conv5_3_norm(f5_3) 110 | 111 | cls1 = self.conv3_3_norm_mbox_conf(f3_3) 112 | reg1 = self.conv3_3_norm_mbox_loc(f3_3) 113 | cls2 = self.conv4_3_norm_mbox_conf(f4_3) 114 | reg2 = self.conv4_3_norm_mbox_loc(f4_3) 115 | cls3 = self.conv5_3_norm_mbox_conf(f5_3) 116 | reg3 = self.conv5_3_norm_mbox_loc(f5_3) 117 | cls4 = self.fc7_mbox_conf(ffc7) 118 | reg4 = self.fc7_mbox_loc(ffc7) 119 | cls5 = self.conv6_2_mbox_conf(f6_2) 120 | reg5 = self.conv6_2_mbox_loc(f6_2) 121 | cls6 = self.conv7_2_mbox_conf(f7_2) 122 | reg6 = self.conv7_2_mbox_loc(f7_2) 123 | 124 | # max-out background label 125 | chunk = torch.chunk(cls1, 4, 1) 126 | bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2]) 127 | cls1 = torch.cat([bmax, chunk[3]], dim=1) 128 | 129 | return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6] 130 | -------------------------------------------------------------------------------- /wav2lip/models/wav2lip.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import math 5 | 6 | from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d 7 | 8 | class Wav2Lip(nn.Module): 9 | def __init__(self): 10 | super(Wav2Lip, self).__init__() 11 | 12 | self.face_encoder_blocks = nn.ModuleList([ 13 | nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96 14 | 15 | nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48 16 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 17 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)), 18 | 19 | nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24 20 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 21 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 22 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)), 23 | 24 | nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12 25 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 26 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)), 27 | 28 | nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6 29 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 30 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)), 31 | 32 | nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3 33 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), 34 | 35 | nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1 36 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),]) 37 | 38 | self.audio_encoder = nn.Sequential( 39 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 40 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 41 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 42 | 43 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 44 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 45 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 46 | 47 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1), 48 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 49 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 50 | 51 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 52 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 53 | 54 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0), 55 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 56 | 57 | self.face_decoder_blocks = nn.ModuleList([ 58 | nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),), 59 | 60 | nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3 61 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), 62 | 63 | nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1), 64 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 65 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6 66 | 67 | nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1), 68 | Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True), 69 | Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12 70 | 71 | 
nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1), 72 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 73 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24 74 | 75 | nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1), 76 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 77 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48 78 | 79 | nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1), 80 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 81 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96 82 | 83 | self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1), 84 | nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0), 85 | nn.Sigmoid()) 86 | 87 | def forward(self, audio_sequences, face_sequences): 88 | # audio_sequences = (B, T, 1, 80, 16) 89 | B = audio_sequences.size(0) 90 | 91 | input_dim_size = len(face_sequences.size()) 92 | if input_dim_size > 4: 93 | audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) 94 | face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0) 95 | 96 | audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1 97 | 98 | feats = [] 99 | x = face_sequences 100 | for f in self.face_encoder_blocks: 101 | x = f(x) 102 | feats.append(x) 103 | 104 | x = audio_embedding 105 | for f in self.face_decoder_blocks: 106 | x = f(x) 107 | try: 108 | x = torch.cat((x, feats[-1]), dim=1) 109 | except Exception as e: 110 | print(x.size()) 111 | print(feats[-1].size()) 112 | raise e 113 | 114 | feats.pop() 115 | 116 | x = self.output_block(x) 117 | 118 | if input_dim_size > 4: 119 | x = torch.split(x, B, dim=0) # [(B, C, H, W)] 120 | outputs = torch.stack(x, dim=2) # (B, C, T, H, W) 121 | 122 | else: 123 | outputs = x 124 | 125 | return outputs 126 | 127 | class Wav2Lip_disc_qual(nn.Module): 128 | def __init__(self): 129 | super(Wav2Lip_disc_qual, self).__init__() 130 | 131 | self.face_encoder_blocks = nn.ModuleList([ 132 | nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96 133 | 134 | nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48 135 | nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)), 136 | 137 | nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24 138 | nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)), 139 | 140 | nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12 141 | nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)), 142 | 143 | nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6 144 | nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)), 145 | 146 | nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3 147 | nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),), 148 | 149 | nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1 150 | nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),]) 151 | 152 | self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid()) 153 | self.label_noise = .0 154 | 155 | def get_lower_half(self, face_sequences): 156 | return face_sequences[:, 
:, face_sequences.size(2)//2:] 157 | 158 | def to_2d(self, face_sequences): 159 | B = face_sequences.size(0) 160 | face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0) 161 | return face_sequences 162 | 163 | def perceptual_forward(self, false_face_sequences): 164 | false_face_sequences = self.to_2d(false_face_sequences) 165 | false_face_sequences = self.get_lower_half(false_face_sequences) 166 | 167 | false_feats = false_face_sequences 168 | for f in self.face_encoder_blocks: 169 | false_feats = f(false_feats) 170 | 171 | false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1), 172 | torch.ones((len(false_feats), 1)).cuda()) 173 | 174 | return false_pred_loss 175 | 176 | def forward(self, face_sequences): 177 | face_sequences = self.to_2d(face_sequences) 178 | face_sequences = self.get_lower_half(face_sequences) 179 | 180 | x = face_sequences 181 | for f in self.face_encoder_blocks: 182 | x = f(x) 183 | 184 | return self.binary_pred(x).view(len(x), -1) 185 | -------------------------------------------------------------------------------- /wav2lip/face_detection/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, 10 | stride=strd, padding=padding, bias=bias) 11 | 12 | 13 | class ConvBlock(nn.Module): 14 | def __init__(self, in_planes, out_planes): 15 | super(ConvBlock, self).__init__() 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv1 = conv3x3(in_planes, int(out_planes / 2)) 18 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) 19 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) 20 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) 21 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4)) 22 | 23 | if in_planes != out_planes: 24 | self.downsample = nn.Sequential( 25 | nn.BatchNorm2d(in_planes), 26 | nn.ReLU(True), 27 | nn.Conv2d(in_planes, out_planes, 28 | kernel_size=1, stride=1, bias=False), 29 | ) 30 | else: 31 | self.downsample = None 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out1 = self.bn1(x) 37 | out1 = F.relu(out1, True) 38 | out1 = self.conv1(out1) 39 | 40 | out2 = self.bn2(out1) 41 | out2 = F.relu(out2, True) 42 | out2 = self.conv2(out2) 43 | 44 | out3 = self.bn3(out2) 45 | out3 = F.relu(out3, True) 46 | out3 = self.conv3(out3) 47 | 48 | out3 = torch.cat((out1, out2, out3), 1) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(residual) 52 | 53 | out3 += residual 54 | 55 | return out3 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 67 | padding=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(planes) 69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 70 | self.bn3 = nn.BatchNorm2d(planes * 4) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = 
self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class HourGlass(nn.Module): 99 | def __init__(self, num_modules, depth, num_features): 100 | super(HourGlass, self).__init__() 101 | self.num_modules = num_modules 102 | self.depth = depth 103 | self.features = num_features 104 | 105 | self._generate_network(self.depth) 106 | 107 | def _generate_network(self, level): 108 | self.add_module('b1_' + str(level), ConvBlock(self.features, self.features)) 109 | 110 | self.add_module('b2_' + str(level), ConvBlock(self.features, self.features)) 111 | 112 | if level > 1: 113 | self._generate_network(level - 1) 114 | else: 115 | self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features)) 116 | 117 | self.add_module('b3_' + str(level), ConvBlock(self.features, self.features)) 118 | 119 | def _forward(self, level, inp): 120 | # Upper branch 121 | up1 = inp 122 | up1 = self._modules['b1_' + str(level)](up1) 123 | 124 | # Lower branch 125 | low1 = F.avg_pool2d(inp, 2, stride=2) 126 | low1 = self._modules['b2_' + str(level)](low1) 127 | 128 | if level > 1: 129 | low2 = self._forward(level - 1, low1) 130 | else: 131 | low2 = low1 132 | low2 = self._modules['b2_plus_' + str(level)](low2) 133 | 134 | low3 = low2 135 | low3 = self._modules['b3_' + str(level)](low3) 136 | 137 | up2 = F.interpolate(low3, scale_factor=2, mode='nearest') 138 | 139 | return up1 + up2 140 | 141 | def forward(self, x): 142 | return self._forward(self.depth, x) 143 | 144 | 145 | class FAN(nn.Module): 146 | 147 | def __init__(self, num_modules=1): 148 | super(FAN, self).__init__() 149 | self.num_modules = num_modules 150 | 151 | # Base part 152 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 153 | self.bn1 = nn.BatchNorm2d(64) 154 | self.conv2 = ConvBlock(64, 128) 155 | self.conv3 = ConvBlock(128, 128) 156 | self.conv4 = ConvBlock(128, 256) 157 | 158 | # Stacking part 159 | for hg_module in range(self.num_modules): 160 | self.add_module('m' + str(hg_module), HourGlass(1, 4, 256)) 161 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) 162 | self.add_module('conv_last' + str(hg_module), 163 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 164 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) 165 | self.add_module('l' + str(hg_module), nn.Conv2d(256, 166 | 68, kernel_size=1, stride=1, padding=0)) 167 | 168 | if hg_module < self.num_modules - 1: 169 | self.add_module( 170 | 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 171 | self.add_module('al' + str(hg_module), nn.Conv2d(68, 172 | 256, kernel_size=1, stride=1, padding=0)) 173 | 174 | def forward(self, x): 175 | x = F.relu(self.bn1(self.conv1(x)), True) 176 | x = F.avg_pool2d(self.conv2(x), 2, stride=2) 177 | x = self.conv3(x) 178 | x = self.conv4(x) 179 | 180 | previous = x 181 | 182 | outputs = [] 183 | for i in range(self.num_modules): 184 | hg = self._modules['m' + str(i)](previous) 185 | 186 | ll = hg 187 | ll = self._modules['top_m_' + str(i)](ll) 188 | 189 | ll = F.relu(self._modules['bn_end' + str(i)] 190 | (self._modules['conv_last' + str(i)](ll)), True) 191 | 192 | # Predict heatmaps 193 | tmp_out = self._modules['l' + str(i)](ll) 194 | outputs.append(tmp_out) 195 | 196 | if 
i < self.num_modules - 1: 197 | ll = self._modules['bl' + str(i)](ll) 198 | tmp_out_ = self._modules['al' + str(i)](tmp_out) 199 | previous = previous + ll + tmp_out_ 200 | 201 | return outputs 202 | 203 | 204 | class ResNetDepth(nn.Module): 205 | 206 | def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68): 207 | self.inplanes = 64 208 | super(ResNetDepth, self).__init__() 209 | self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3, 210 | bias=False) 211 | self.bn1 = nn.BatchNorm2d(64) 212 | self.relu = nn.ReLU(inplace=True) 213 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 214 | self.layer1 = self._make_layer(block, 64, layers[0]) 215 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 216 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 217 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 218 | self.avgpool = nn.AvgPool2d(7) 219 | self.fc = nn.Linear(512 * block.expansion, num_classes) 220 | 221 | for m in self.modules(): 222 | if isinstance(m, nn.Conv2d): 223 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 224 | m.weight.data.normal_(0, math.sqrt(2. / n)) 225 | elif isinstance(m, nn.BatchNorm2d): 226 | m.weight.data.fill_(1) 227 | m.bias.data.zero_() 228 | 229 | def _make_layer(self, block, planes, blocks, stride=1): 230 | downsample = None 231 | if stride != 1 or self.inplanes != planes * block.expansion: 232 | downsample = nn.Sequential( 233 | nn.Conv2d(self.inplanes, planes * block.expansion, 234 | kernel_size=1, stride=stride, bias=False), 235 | nn.BatchNorm2d(planes * block.expansion), 236 | ) 237 | 238 | layers = [] 239 | layers.append(block(self.inplanes, planes, stride, downsample)) 240 | self.inplanes = planes * block.expansion 241 | for i in range(1, blocks): 242 | layers.append(block(self.inplanes, planes)) 243 | 244 | return nn.Sequential(*layers) 245 | 246 | def forward(self, x): 247 | x = self.conv1(x) 248 | x = self.bn1(x) 249 | x = self.relu(x) 250 | x = self.maxpool(x) 251 | 252 | x = self.layer1(x) 253 | x = self.layer2(x) 254 | x = self.layer3(x) 255 | x = self.layer4(x) 256 | 257 | x = self.avgpool(x) 258 | x = x.view(x.size(0), -1) 259 | x = self.fc(x) 260 | 261 | return x 262 | -------------------------------------------------------------------------------- /wav2lip/inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import argparse 5 | import subprocess 6 | from tqdm import tqdm 7 | from .audio import load_wav, melspectrogram 8 | from .face_detection import FaceAlignment,LandmarksType 9 | import torch 10 | import platform 11 | 12 | parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models') 13 | 14 | parser.add_argument('--outfile', type=str, help='Video path to save result. See default for an e.g.', 15 | default='wav2lip/results/result_voice.mp4') 16 | 17 | parser.add_argument('--static', type=bool, 18 | help='If True, then use only first video frame for inference', default=False) 19 | parser.add_argument('--fps', type=float, help='Can be specified only if input is a static image (default: 25)', 20 | default=25., required=False) 21 | 22 | parser.add_argument('--pads', nargs='+', type=int, default=[0, 10, 0, 0], 23 | help='Padding (top, bottom, left, right). 
Please adjust to include chin at least') 24 | 25 | parser.add_argument('--face_det_batch_size', type=int, 26 | help='Batch size for face detection', default=16) 27 | parser.add_argument('--wav2lip_batch_size', type=int, help='Batch size for Wav2Lip model(s)', default=128) 28 | 29 | parser.add_argument('--resize_factor', default=1, type=int, 30 | help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p') 31 | 32 | parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1], 33 | help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. ' 34 | 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width') 35 | 36 | parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1], 37 | help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.' 38 | 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).') 39 | 40 | parser.add_argument('--rotate', default=False, action='store_true', 41 | help='Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg.' 42 | 'Use if you get a flipped result, despite feeding a normal looking video') 43 | 44 | parser.add_argument('--nosmooth', default=False, action='store_true', 45 | help='Prevent smoothing face detections over a short temporal window') 46 | 47 | args = parser.parse_args() 48 | args.img_size = 96 49 | 50 | 51 | def get_smoothened_boxes(boxes, T): 52 | for i in range(len(boxes)): 53 | if i + T > len(boxes): 54 | window = boxes[len(boxes) - T:] 55 | else: 56 | window = boxes[i : i + T] 57 | boxes[i] = np.mean(window, axis=0) 58 | return boxes 59 | 60 | def face_detect(images): 61 | detector = FaceAlignment(LandmarksType._2D, 62 | flip_input=False, device=device) 63 | 64 | batch_size = args.face_det_batch_size 65 | 66 | while 1: 67 | predictions = [] 68 | try: 69 | for i in tqdm(range(0, len(images), batch_size)): 70 | predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size]))) 71 | except RuntimeError: 72 | if batch_size == 1: 73 | raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument') 74 | batch_size //= 2 75 | print('Recovering from OOM error; New batch size: {}'.format(batch_size)) 76 | continue 77 | break 78 | 79 | results = [] 80 | pady1, pady2, padx1, padx2 = args.pads 81 | for rect, image in zip(predictions, images): 82 | if rect is None: 83 | cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected. 84 | raise ValueError('Face not detected! 
Ensure the video contains a face in all the frames.') 85 | 86 | y1 = max(0, rect[1] - pady1) 87 | y2 = min(image.shape[0], rect[3] + pady2) 88 | x1 = max(0, rect[0] - padx1) 89 | x2 = min(image.shape[1], rect[2] + padx2) 90 | 91 | results.append([x1, y1, x2, y2]) 92 | 93 | boxes = np.array(results) 94 | if not args.nosmooth: boxes = get_smoothened_boxes(boxes, T=5) 95 | results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)] 96 | 97 | del detector 98 | return results 99 | 100 | def datagen(frames, mels): 101 | img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] 102 | 103 | if args.box[0] == -1: 104 | if not args.static: 105 | face_det_results = face_detect(frames) # BGR2RGB for CNN face detection 106 | else: 107 | face_det_results = face_detect([frames[0]]) 108 | else: 109 | print('Using the specified bounding box instead of face detection...') 110 | y1, y2, x1, x2 = args.box 111 | face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames] 112 | 113 | for i, m in enumerate(mels): 114 | idx = 0 if args.static else i%len(frames) 115 | frame_to_save = frames[idx].copy() 116 | face, coords = face_det_results[idx].copy() 117 | 118 | face = cv2.resize(face, (args.img_size, args.img_size)) 119 | 120 | img_batch.append(face) 121 | mel_batch.append(m) 122 | frame_batch.append(frame_to_save) 123 | coords_batch.append(coords) 124 | 125 | if len(img_batch) >= args.wav2lip_batch_size: 126 | img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) 127 | 128 | img_masked = img_batch.copy() 129 | img_masked[:, args.img_size//2:] = 0 130 | 131 | img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 132 | mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) 133 | 134 | yield img_batch, mel_batch, frame_batch, coords_batch 135 | img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] 136 | 137 | if len(img_batch) > 0: 138 | img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) 139 | 140 | img_masked = img_batch.copy() 141 | img_masked[:, args.img_size//2:] = 0 142 | 143 | img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. 
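# Note (added comment, same idea as the full-batch branch above): the mouth-masked copy
# is stacked with the untouched face crop along the channel axis (axis 3 of the
# (N, H, W, C) batch) and scaled to [0, 1], so each sample reaches the generator as a
# 6-channel image for this final partial batch as well.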
144 | mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) 145 | 146 | yield img_batch, mel_batch, frame_batch, coords_batch 147 | 148 | mel_step_size = 16 149 | device = 'cpu' #'cuda' if torch.cuda.is_available() else 'cpu' 150 | print('Using {} for inference.'.format(device)) 151 | 152 | def _load(checkpoint_path): 153 | if device == 'cuda': 154 | checkpoint = torch.load(checkpoint_path) 155 | else: 156 | checkpoint = torch.load(checkpoint_path, 157 | map_location=lambda storage, loc: storage) 158 | return checkpoint 159 | 160 | 161 | def main(face,audio,model): 162 | if not os.path.isfile(face): 163 | raise ValueError('--face argument must be a valid path to video/image file') 164 | 165 | elif face.split('.')[1] in ['jpg', 'png', 'jpeg']: 166 | full_frames = [cv2.imread(face)] 167 | fps = args.fps 168 | 169 | else: 170 | video_stream = cv2.VideoCapture(face) 171 | fps = video_stream.get(cv2.CAP_PROP_FPS) 172 | 173 | print('Reading video frames...') 174 | 175 | full_frames = [] 176 | while 1: 177 | still_reading, frame = video_stream.read() 178 | if not still_reading: 179 | video_stream.release() 180 | break 181 | if args.resize_factor > 1: 182 | frame = cv2.resize(frame, (frame.shape[1]//args.resize_factor, frame.shape[0]//args.resize_factor)) 183 | 184 | if args.rotate: 185 | frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE) 186 | 187 | y1, y2, x1, x2 = args.crop 188 | if x2 == -1: x2 = frame.shape[1] 189 | if y2 == -1: y2 = frame.shape[0] 190 | 191 | frame = frame[y1:y2, x1:x2] 192 | 193 | full_frames.append(frame) 194 | 195 | print ("Number of frames available for inference: "+str(len(full_frames))) 196 | 197 | if not audio.endswith('.wav'): 198 | print('Extracting raw audio...') 199 | command = 'ffmpeg -y -i {} -strict -2 {}'.format(audio, 'temp/temp.wav') 200 | 201 | subprocess.call(command, shell=True) 202 | audio = 'temp/temp.wav' 203 | 204 | wav = load_wav(audio, 16000) 205 | mel = melspectrogram(wav) 206 | print(mel.shape) 207 | 208 | if np.isnan(mel.reshape(-1)).sum() > 0: 209 | raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again') 210 | 211 | mel_chunks = [] 212 | mel_idx_multiplier = 80./fps 213 | i = 0 214 | while 1: 215 | start_idx = int(i * mel_idx_multiplier) 216 | if start_idx + mel_step_size > len(mel[0]): 217 | mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:]) 218 | break 219 | mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size]) 220 | i += 1 221 | 222 | print("Length of mel chunks: {}".format(len(mel_chunks))) 223 | 224 | full_frames = full_frames[:len(mel_chunks)] 225 | 226 | batch_size = args.wav2lip_batch_size 227 | gen = datagen(full_frames.copy(), mel_chunks) 228 | 229 | for i, (img_batch, mel_batch, frames, coords) in enumerate(tqdm(gen, 230 | total=int(np.ceil(float(len(mel_chunks))/batch_size)))): 231 | if i == 0: 232 | #model = load_model(checkpoint_path) 233 | print ("Model loaded") 234 | 235 | frame_h, frame_w = full_frames[0].shape[:-1] 236 | out = cv2.VideoWriter('wav2lip/temp/result.avi', 237 | cv2.VideoWriter_fourcc(*'DIVX'), fps, (frame_w, frame_h)) 238 | 239 | img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device) 240 | mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device) 241 | 242 | with torch.no_grad(): 243 | pred = model(mel_batch, img_batch) 244 | 245 | pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255. 
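# Note (added comment): pred is now back on the CPU as an (N, H, W, C) array scaled to
# 0-255; in the loop below each generated 96x96 face crop is resized to its detected box
# (x1:x2, y1:y2) and pasted into the corresponding original frame before being written
# to the intermediate video.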
246 | 247 | for p, f, c in zip(pred, frames, coords): 248 | y1, y2, x1, x2 = c 249 | p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1)) 250 | 251 | f[y1:y2, x1:x2] = p 252 | out.write(f) 253 | 254 | out.release() 255 | 256 | command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio, 'wav2lip/temp/result.avi', args.outfile) 257 | subprocess.call(command, shell=platform.system() != 'Windows') 258 | print("done :)") 259 | -------------------------------------------------------------------------------- /wav2lip/face_detection/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import time 5 | import torch 6 | import math 7 | import numpy as np 8 | import cv2 9 | 10 | 11 | def _gaussian( 12 | size=3, sigma=0.25, amplitude=1, normalize=False, width=None, 13 | height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5, 14 | mean_vert=0.5): 15 | # handle some defaults 16 | if width is None: 17 | width = size 18 | if height is None: 19 | height = size 20 | if sigma_horz is None: 21 | sigma_horz = sigma 22 | if sigma_vert is None: 23 | sigma_vert = sigma 24 | center_x = mean_horz * width + 0.5 25 | center_y = mean_vert * height + 0.5 26 | gauss = np.empty((height, width), dtype=np.float32) 27 | # generate kernel 28 | for i in range(height): 29 | for j in range(width): 30 | gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / ( 31 | sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0)) 32 | if normalize: 33 | gauss = gauss / np.sum(gauss) 34 | return gauss 35 | 36 | 37 | def draw_gaussian(image, point, sigma): 38 | # Check if the gaussian is inside 39 | ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)] 40 | br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)] 41 | if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1): 42 | return image 43 | size = 6 * sigma + 1 44 | g = _gaussian(size) 45 | g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))] 46 | g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))] 47 | img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))] 48 | img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))] 49 | assert (g_x[0] > 0 and g_y[1] > 0) 50 | image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1] 51 | ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]] 52 | image[image > 1] = 1 53 | return image 54 | 55 | 56 | def transform(point, center, scale, resolution, invert=False): 57 | """Generate and affine transformation matrix. 58 | 59 | Given a set of points, a center, a scale and a targer resolution, the 60 | function generates and affine transformation matrix. If invert is ``True`` 61 | it will produce the inverse transformation. 
62 | 63 | Arguments: 64 | point {torch.tensor} -- the input 2D point 65 | center {torch.tensor or numpy.array} -- the center around which to perform the transformations 66 | scale {float} -- the scale of the face/object 67 | resolution {float} -- the output resolution 68 | 69 | Keyword Arguments: 70 | invert {bool} -- define wherever the function should produce the direct or the 71 | inverse transformation matrix (default: {False}) 72 | """ 73 | _pt = torch.ones(3) 74 | _pt[0] = point[0] 75 | _pt[1] = point[1] 76 | 77 | h = 200.0 * scale 78 | t = torch.eye(3) 79 | t[0, 0] = resolution / h 80 | t[1, 1] = resolution / h 81 | t[0, 2] = resolution * (-center[0] / h + 0.5) 82 | t[1, 2] = resolution * (-center[1] / h + 0.5) 83 | 84 | if invert: 85 | t = torch.inverse(t) 86 | 87 | new_point = (torch.matmul(t, _pt))[0:2] 88 | 89 | return new_point.int() 90 | 91 | 92 | def crop(image, center, scale, resolution=256.0): 93 | """Center crops an image or set of heatmaps 94 | 95 | Arguments: 96 | image {numpy.array} -- an rgb image 97 | center {numpy.array} -- the center of the object, usually the same as of the bounding box 98 | scale {float} -- scale of the face 99 | 100 | Keyword Arguments: 101 | resolution {float} -- the size of the output cropped image (default: {256.0}) 102 | 103 | Returns: 104 | [type] -- [description] 105 | """ # Crop around the center point 106 | """ Crops the image around the center. Input is expected to be an np.ndarray """ 107 | ul = transform([1, 1], center, scale, resolution, True) 108 | br = transform([resolution, resolution], center, scale, resolution, True) 109 | # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0) 110 | if image.ndim > 2: 111 | newDim = np.array([br[1] - ul[1], br[0] - ul[0], 112 | image.shape[2]], dtype=np.int32) 113 | newImg = np.zeros(newDim, dtype=np.uint8) 114 | else: 115 | newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int) 116 | newImg = np.zeros(newDim, dtype=np.uint8) 117 | ht = image.shape[0] 118 | wd = image.shape[1] 119 | newX = np.array( 120 | [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32) 121 | newY = np.array( 122 | [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32) 123 | oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32) 124 | oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32) 125 | newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1] 126 | ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :] 127 | newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)), 128 | interpolation=cv2.INTER_LINEAR) 129 | return newImg 130 | 131 | 132 | def get_preds_fromhm(hm, center=None, scale=None): 133 | """Obtain (x,y) coordinates given a set of N heatmaps. If the center 134 | and the scale is provided the function will return the points also in 135 | the original coordinate frame. 
136 | 137 | Arguments: 138 | hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H] 139 | 140 | Keyword Arguments: 141 | center {torch.tensor} -- the center of the bounding box (default: {None}) 142 | scale {float} -- face scale (default: {None}) 143 | """ 144 | max, idx = torch.max( 145 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) 146 | idx += 1 147 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() 148 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) 149 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) 150 | 151 | for i in range(preds.size(0)): 152 | for j in range(preds.size(1)): 153 | hm_ = hm[i, j, :] 154 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 155 | if pX > 0 and pX < 63 and pY > 0 and pY < 63: 156 | diff = torch.FloatTensor( 157 | [hm_[pY, pX + 1] - hm_[pY, pX - 1], 158 | hm_[pY + 1, pX] - hm_[pY - 1, pX]]) 159 | preds[i, j].add_(diff.sign_().mul_(.25)) 160 | 161 | preds.add_(-.5) 162 | 163 | preds_orig = torch.zeros(preds.size()) 164 | if center is not None and scale is not None: 165 | for i in range(hm.size(0)): 166 | for j in range(hm.size(1)): 167 | preds_orig[i, j] = transform( 168 | preds[i, j], center, scale, hm.size(2), True) 169 | 170 | return preds, preds_orig 171 | 172 | def get_preds_fromhm_batch(hm, centers=None, scales=None): 173 | """Obtain (x,y) coordinates given a set of N heatmaps. If the centers 174 | and the scales is provided the function will return the points also in 175 | the original coordinate frame. 176 | 177 | Arguments: 178 | hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H] 179 | 180 | Keyword Arguments: 181 | centers {torch.tensor} -- the centers of the bounding box (default: {None}) 182 | scales {float} -- face scales (default: {None}) 183 | """ 184 | max, idx = torch.max( 185 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) 186 | idx += 1 187 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() 188 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) 189 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) 190 | 191 | for i in range(preds.size(0)): 192 | for j in range(preds.size(1)): 193 | hm_ = hm[i, j, :] 194 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 195 | if pX > 0 and pX < 63 and pY > 0 and pY < 63: 196 | diff = torch.FloatTensor( 197 | [hm_[pY, pX + 1] - hm_[pY, pX - 1], 198 | hm_[pY + 1, pX] - hm_[pY - 1, pX]]) 199 | preds[i, j].add_(diff.sign_().mul_(.25)) 200 | 201 | preds.add_(-.5) 202 | 203 | preds_orig = torch.zeros(preds.size()) 204 | if centers is not None and scales is not None: 205 | for i in range(hm.size(0)): 206 | for j in range(hm.size(1)): 207 | preds_orig[i, j] = transform( 208 | preds[i, j], centers[i], scales[i], hm.size(2), True) 209 | 210 | return preds, preds_orig 211 | 212 | def shuffle_lr(parts, pairs=None): 213 | """Shuffle the points left-right according to the axis of symmetry 214 | of the object. 215 | 216 | Arguments: 217 | parts {torch.tensor} -- a 3D or 4D object containing the 218 | heatmaps. 
219 | 220 | Keyword Arguments: 221 | pairs {list of integers} -- [order of the flipped points] (default: {None}) 222 | """ 223 | if pairs is None: 224 | pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 225 | 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35, 226 | 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41, 227 | 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63, 228 | 62, 61, 60, 67, 66, 65] 229 | if parts.ndimension() == 3: 230 | parts = parts[pairs, ...] 231 | else: 232 | parts = parts[:, pairs, ...] 233 | 234 | return parts 235 | 236 | 237 | def flip(tensor, is_label=False): 238 | """Flip an image or a set of heatmaps left-right 239 | 240 | Arguments: 241 | tensor {numpy.array or torch.tensor} -- [the input image or heatmaps] 242 | 243 | Keyword Arguments: 244 | is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False}) 245 | """ 246 | if not torch.is_tensor(tensor): 247 | tensor = torch.from_numpy(tensor) 248 | 249 | if is_label: 250 | tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1) 251 | else: 252 | tensor = tensor.flip(tensor.ndimension() - 1) 253 | 254 | return tensor 255 | 256 | # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py) 257 | 258 | 259 | def appdata_dir(appname=None, roaming=False): 260 | """ appdata_dir(appname=None, roaming=False) 261 | 262 | Get the path to the application directory, where applications are allowed 263 | to write user specific files (e.g. configurations). For non-user specific 264 | data, consider using common_appdata_dir(). 265 | If appname is given, a subdir is appended (and created if necessary). 266 | If roaming is True, will prefer a roaming directory (Windows Vista/7). 267 | """ 268 | 269 | # Define default user directory 270 | userDir = os.getenv('FACEALIGNMENT_USERDIR', None) 271 | if userDir is None: 272 | userDir = os.path.expanduser('~') 273 | if not os.path.isdir(userDir): # pragma: no cover 274 | userDir = '/var/tmp' # issue #54 275 | 276 | # Get system app data dir 277 | path = None 278 | if sys.platform.startswith('win'): 279 | path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA') 280 | path = (path2 or path1) if roaming else (path1 or path2) 281 | elif sys.platform.startswith('darwin'): 282 | path = os.path.join(userDir, 'Library', 'Application Support') 283 | # On Linux and as fallback 284 | if not (path and os.path.isdir(path)): 285 | path = userDir 286 | 287 | # Maybe we should store things local to the executable (in case of a 288 | # portable distro or a frozen application that wants to be portable) 289 | prefix = sys.prefix 290 | if getattr(sys, 'frozen', None): 291 | prefix = os.path.abspath(os.path.dirname(sys.executable)) 292 | for reldir in ('settings', '../settings'): 293 | localpath = os.path.abspath(os.path.join(prefix, reldir)) 294 | if os.path.isdir(localpath): # pragma: no cover 295 | try: 296 | open(os.path.join(localpath, 'test.write'), 'wb').close() 297 | os.remove(os.path.join(localpath, 'test.write')) 298 | except IOError: 299 | pass # We cannot write in this directory 300 | else: 301 | path = localpath 302 | break 303 | 304 | # Get path specific for this app 305 | if appname: 306 | if path == userDir: 307 | appname = '.' 
+ appname.lstrip('.') # Make it a hidden directory 308 | path = os.path.join(path, appname) 309 | if not os.path.isdir(path): # pragma: no cover 310 | os.mkdir(path) 311 | 312 | # Done 313 | return path 314 | --------------------------------------------------------------------------------
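In this fork, `inference.main(face, audio, model)` expects the caller to pass in an already-loaded generator (note the commented-out `load_model(checkpoint_path)` call inside `main`). The snippet below is a minimal, hypothetical driver showing one way a caller such as `app.py` might wire this up. The checkpoint path and input file names are placeholders, and it assumes the standard Wav2Lip checkpoint layout: weights stored under a `state_dict` key, possibly with `module.` prefixes left over from `nn.DataParallel`. Because `inference.py` calls `parser.parse_args()` at import time, the sketch also assumes it is launched without extra command-line flags.

# Hypothetical driver (sketch only): load a Wav2Lip checkpoint and run inference.
# The checkpoint path and media file names below are placeholders, not files in this repo.
import torch

from wav2lip import inference
from wav2lip.models import Wav2Lip


def load_wav2lip(checkpoint_path, device='cpu'):
    # Assumed checkpoint layout: a dict whose weights live under 'state_dict',
    # with keys possibly prefixed by 'module.' when saved from nn.DataParallel.
    model = Wav2Lip()
    ckpt = torch.load(checkpoint_path, map_location=device)
    state = ckpt.get('state_dict', ckpt)
    state = {k.replace('module.', ''): v for k, v in state.items()}
    model.load_state_dict(state)
    return model.to(device).eval()


if __name__ == '__main__':
    model = load_wav2lip('wav2lip/wav2lip_gan.pth')      # placeholder checkpoint path
    inference.main(face='avatars_images/avatar1.jpg',    # any image or video containing a face
                   audio='sound.wav',                    # any audio file; non-wav input is converted via ffmpeg
                   model=model)
    # By default the synced result is written to wav2lip/results/result_voice.mp4.

Since `device` is hard-coded to `'cpu'` in `inference.py`, the model should also be loaded on the CPU (the default above); ffmpeg must be available on the PATH for the audio extraction and final muxing steps.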