├── .gitattributes
├── .gitignore
├── README.md
├── example.gif
├── example
│   ├── audio.wav
│   └── image.bmp
├── sda
│   ├── .gitignore
│   ├── __init__.py
│   ├── data
│   │   ├── crema.dat
│   │   ├── grid.dat
│   │   └── timit.dat
│   ├── encoder_audio.py
│   ├── encoder_image.py
│   ├── img_generator.py
│   ├── rnn_audio.py
│   ├── sda.py
│   └── utils.py
└── setup.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | #*.dat filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *.pyc
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Speech-Driven Animation

This library implements the end-to-end facial synthesis model described in this [paper](https://sites.google.com/view/facialsynthesis/home).

It is maintained by Konstantinos Vougioukas, Honglie Chen and Pingchuan Ma.

![speech-driven-animation](example.gif)

## Downloading the models
The models were originally hosted on Git LFS. However, demand was so high that I reached the quota for free Git LFS storage, so I have moved the models to Google Drive. They can be found [here](https://drive.google.com/drive/folders/17Dc2keVoNSrlrOdLL3kXdM8wjb20zkbF?usp=sharing).
Place the model file(s) under *`sda/data/`*.

## Installing

To install the library, run:
```
$ pip install .
```

## Running the example

To create an animation, instantiate the VideoAnimator class and then call it with an image and an audio clip (or the paths to the files); a video will be produced, as shown in the examples below.


## Choosing the model
The model has been trained on the GRID, TCD-TIMIT, CREMA-D and LRW datasets. The default model is GRID. To load another pretrained model, simply instantiate the VideoAnimator with the following arguments:

```
import sda
va = sda.VideoAnimator(gpu=0, model_path="crema") # Instantiate the animator
```

The models that are currently uploaded are:
- [x] GRID
- [x] TIMIT
- [x] CREMA
- [ ] LRW


### Example with image and audio paths
```
import sda
va = sda.VideoAnimator(gpu=0) # Instantiate the animator
vid, aud = va("example/image.bmp", "example/audio.wav")
```

### Example with numpy arrays
```
import sda
import numpy as np
import scipy.io.wavfile as wav
from PIL import Image

va = sda.VideoAnimator(gpu=0) # Instantiate the animator
fs, audio_clip = wav.read("example/audio.wav")
still_frame = np.array(Image.open("example/image.bmp"))
vid, aud = va(still_frame, audio_clip, fs=fs)
```

### Saving video with audio
```
va.save_video(vid, aud, "generated.mp4")
```

## Using the encodings
The encoders for audio and video are made available so that they can be used to produce features for classification tasks.

### Audio encoder
The audio encoder (made up of an audio-frame encoder followed by an RNN) is provided together with a dictionary that holds information such as the feature length (in seconds) required by the audio-frame encoder and the overlap between consecutive audio frames.
```
import sda
encoder, info = sda.get_audio_feature_extractor(gpu=0)
```
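
The returned `info` dictionary (with the keys `"rate"`, `"feature length"` and `"overlap"`) can be combined with `sda.cut_audio_sequence` to split a waveform into the overlapping windows the encoder expects. A minimal sketch, assuming the example clip is already sampled at the model's audio rate (resample it first if it is not):

```
import numpy as np
import scipy.io.wavfile as wav
import torch
import sda

encoder, info = sda.get_audio_feature_extractor(gpu=-1)  # gpu=-1 runs on the CPU
encoder.eval()

fs, audio = wav.read("example/audio.wav")
# Scale the integer samples to the [-1, 1] range expected by the encoder
audio = torch.from_numpy(audio / float(np.iinfo(audio.dtype).max)).float()

# Cut the waveform into overlapping windows of info["feature length"] seconds
clips = sda.cut_audio_sequence(audio, info["feature length"], info["overlap"], info["rate"])

# The encoder takes a batch of window sequences plus the sequence lengths
with torch.no_grad():
    features = encoder(clips.unsqueeze(0), [clips.size(0)])  # (1, num_windows, feature_dim)
```

The dimensionality of the returned features depends on the pretrained model that was loaded.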

## Citation

If you find this code useful in your research, please consider citing the following paper:

```bibtex
@inproceedings{vougioukas2019end,
  title={End-to-End Speech-Driven Realistic Facial Animation with Temporal GANs.},
  author={Vougioukas, Konstantinos and Petridis, Stavros and Pantic, Maja},
  booktitle={CVPR Workshops},
  pages={37--40},
  year={2019}
}
```
--------------------------------------------------------------------------------
/example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DinoMan/speech-driven-animation/d85adfeba6459edbd2c3ad12656baebc59e13c74/example.gif
--------------------------------------------------------------------------------
/example/audio.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DinoMan/speech-driven-animation/d85adfeba6459edbd2c3ad12656baebc59e13c74/example/audio.wav
--------------------------------------------------------------------------------
/example/image.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DinoMan/speech-driven-animation/d85adfeba6459edbd2c3ad12656baebc59e13c74/example/image.bmp
--------------------------------------------------------------------------------
/sda/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
--------------------------------------------------------------------------------
/sda/__init__.py:
--------------------------------------------------------------------------------
1 | from .sda import VideoAnimator, get_audio_feature_extractor, cut_audio_sequence
2 |
--------------------------------------------------------------------------------
/sda/data/crema.dat:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:bd81f9c671f510c683c610c211522de1c5c675e77ffca659a7abc8d5535156a7
3 | size 221171678
4 |
--------------------------------------------------------------------------------
/sda/data/grid.dat:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:cfd703ecb1bbbb2a81ea62a0df948627a6e4e6481e9c444fccb4a41701867ca7
3 | size 221179358
4 |
--------------------------------------------------------------------------------
/sda/data/timit.dat:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:e725836f8da8ea1951bbeff885839b273f503a39f0d4bf75f4686f16844ad94b
3 | size 221181916
4 |
--------------------------------------------------------------------------------
/sda/encoder_audio.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn 2 | import math 3 | from .utils import calculate_padding, prime_factors, calculate_output_size 4 | 5 | class Encoder(nn.Module): 6 | def __init__(self, code_size, rate, feat_length, init_kernel=None, init_stride=None, num_feature_maps=16, 7 | increasing_stride=True): 8 | super(Encoder, self).__init__() 9 | 10 | self.code_size = code_size 11 | self.cl = nn.ModuleList() 12 | self.activations = nn.ModuleList() 13 | self.strides = []
14 | self.kernels = [] 15 | 16 | features = feat_length * rate 17 | strides = prime_factors(features) 18 | kernels = [2 * s for s in strides] 19 | 20 | if init_kernel is not None and init_stride is not None: 21 | self.strides.append(int(init_stride * rate)) 22 | self.kernels.append(int(init_kernel * rate)) 23 | padding = calculate_padding(init_kernel * rate, stride=init_stride * rate, in_size=features) 24 | init_features = calculate_output_size(features, init_kernel * rate, stride=init_stride * rate, 25 | padding=padding) 26 | strides = prime_factors(init_features) 27 | kernels = [2 * s for s in strides] 28 | 29 | if not increasing_stride: 30 | strides.reverse() 31 | kernels.reverse() 32 | 33 | self.strides.extend(strides) 34 | self.kernels.extend(kernels) 35 | 36 | for i in range(len(self.strides) - 1): 37 | padding = calculate_padding(self.kernels[i], stride=self.strides[i], in_size=features) 38 | features = calculate_output_size(features, self.kernels[i], stride=self.strides[i], padding=padding) 39 | pad = int(math.ceil(padding / 2.0)) 40 | 41 | if i == 0: 42 | self.cl.append( 43 | nn.Conv1d(1, num_feature_maps, self.kernels[i], stride=self.strides[i], padding=pad)) 44 | self.activations.append(nn.Sequential(nn.BatchNorm1d(num_feature_maps), nn.ReLU(True))) 45 | else: 46 | self.cl.append(nn.Conv1d(num_feature_maps, 2 * num_feature_maps, self.kernels[i], 47 | stride=self.strides[i], padding=pad)) 48 | self.activations.append(nn.Sequential(nn.BatchNorm1d(2 * num_feature_maps), nn.ReLU(True))) 49 | 50 | num_feature_maps *= 2 51 | 52 | self.cl.append(nn.Conv1d(num_feature_maps, self.code_size, features)) 53 | self.activations.append(nn.Tanh()) 54 | 55 | def forward(self, x): 56 | for i in range(len(self.strides)): 57 | x = self.cl[i](x) 58 | x = self.activations[i](x) 59 | 60 | return x.squeeze() 61 | -------------------------------------------------------------------------------- /sda/encoder_image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | from .utils import calculate_padding, is_power2 4 | 5 | class Encoder(nn.Module): 6 | def __init__(self, code_size, img_size, kernel_size=4, num_input_channels=3, num_feature_maps=64, batch_norm=True): 7 | super(Encoder, self).__init__() 8 | 9 | # Get the dimension which is a power of 2 10 | if is_power2(max(img_size)): 11 | stable_dim = max(img_size) 12 | else: 13 | stable_dim = min(img_size) 14 | 15 | if isinstance(img_size, tuple): 16 | self.img_size = img_size 17 | self.final_size = tuple(int(4 * x // stable_dim) for x in self.img_size) 18 | else: 19 | self.img_size = (img_size, img_size) 20 | self.final_size = (4, 4) 21 | 22 | self.code_size = code_size 23 | self.num_feature_maps = num_feature_maps 24 | self.cl = nn.ModuleList() 25 | self.num_layers = int(np.log2(max(self.img_size))) - 2 26 | 27 | stride = 2 28 | # This ensures that we have same padding no matter if we have even or odd kernels 29 | padding = calculate_padding(kernel_size, stride) 30 | 31 | if batch_norm: 32 | self.cl.append(nn.Sequential( 33 | nn.Conv2d(num_input_channels, self.num_feature_maps, kernel_size, stride=stride, padding=padding // 2, 34 | bias=False), 35 | nn.BatchNorm2d(self.num_feature_maps), 36 | nn.ReLU(inplace=True))) 37 | else: 38 | self.cl.append(nn.Sequential( 39 | nn.Conv2d(num_input_channels, self.num_feature_maps, kernel_size, stride=stride, padding=padding // 2, 40 | bias=False), 41 | nn.ReLU(inplace=True))) 42 | 43 | self.channels = [self.num_feature_maps] 44 | for 
i in range(self.num_layers - 1): 45 | 46 | if batch_norm: 47 | self.cl.append(nn.Sequential( 48 | nn.Conv2d(self.channels[-1], self.channels[-1] * 2, kernel_size, stride=stride, 49 | padding=padding // 2, 50 | bias=False), 51 | nn.BatchNorm2d(self.channels[-1] * 2), 52 | nn.ReLU(inplace=True))) 53 | else: 54 | self.cl.append(nn.Sequential( 55 | nn.Conv2d(self.channels[-1], self.channels[-1] * 2, kernel_size, stride=stride, 56 | padding=padding // 2, bias=False), 57 | nn.ReLU(inplace=True))) 58 | 59 | self.channels.append(2 * self.channels[-1]) 60 | 61 | self.cl.append(nn.Sequential( 62 | nn.Conv2d(self.channels[-1], code_size, self.final_size, stride=1, padding=0, bias=False), 63 | nn.Tanh())) 64 | 65 | def forward(self, x, retain_intermediate=False): 66 | if retain_intermediate: 67 | h = [x] 68 | for conv_layer in self.cl: 69 | h.append(conv_layer(h[-1])) 70 | return h[-1].view(-1, self.code_size), h[1:-1] 71 | else: 72 | for conv_layer in self.cl: 73 | x = conv_layer(x) 74 | 75 | return x.view(-1, self.code_size) 76 | -------------------------------------------------------------------------------- /sda/img_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | from .utils import calculate_padding 5 | 6 | 7 | class Deconv(nn.Module): 8 | def __init__(self, in_channels, out_channels, in_size, kernel_size, stride=1, batch_norm=True): 9 | super(Deconv, self).__init__() 10 | # This ensures that we have same padding no matter if we have even or odd kernels 11 | padding = calculate_padding(kernel_size, stride) 12 | self.dcl = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding // 2, 13 | bias=False) 14 | 15 | if batch_norm: 16 | self.activation = nn.Sequential(nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)) 17 | else: 18 | self.activation = nn.ReLU(inplace=True) 19 | 20 | self.required_channels = out_channels 21 | self.out_size_required = tuple(x * stride for x in in_size) 22 | 23 | def forward(self, x): 24 | x = self.dcl(x, 25 | output_size=[-1, self.required_channels, self.out_size_required[0], self.out_size_required[1]]) 26 | 27 | return self.activation(x) 28 | 29 | 30 | class UnetBlock(nn.Module): 31 | def __init__(self, in_channels, out_channels, skip_channels, in_size, kernel_size, stride=1, batch_norm=True): 32 | super(UnetBlock, self).__init__() 33 | # This ensures that we have same padding no matter if we have even or odd kernels 34 | padding = calculate_padding(kernel_size, stride) 35 | self.dcl1 = nn.ConvTranspose2d(in_channels + skip_channels, in_channels, 3, padding=1, bias=False) 36 | self.dcl2 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, 37 | padding=padding // 2, bias=False) 38 | if batch_norm: 39 | self.activation1 = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True)) 40 | self.activation2 = nn.Sequential(nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)) 41 | else: 42 | self.activation1 = nn.ReLU(inplace=True) 43 | self.activation2 = nn.ReLU(inplace=True) 44 | 45 | self.required_channels = out_channels 46 | self.out_size_required = tuple(x * stride for x in in_size) 47 | 48 | def forward(self, x, s): 49 | s = s.view(x.size()) 50 | 51 | x = torch.cat([x, s], 1) 52 | 53 | x = self.dcl1(x) 54 | x = self.activation1(x) 55 | 56 | x = self.dcl2(x, output_size=[-1, self.required_channels, self.out_size_required[0], self.out_size_required[1]]) 57 | x = self.activation2(x) 58 | return x 59 
| 60 | 61 | class Generator(nn.Module): 62 | def __init__(self, img_size, latent_size, condition_size=0, aux_size=0, kernel_size=4, num_channels=3, 63 | num_gen_channels=1024, skip_channels=[], batch_norm=True, sequential_noise=False, 64 | aux_only_on_top=False): 65 | super(Generator, self).__init__() 66 | # If we have a tuple make sure we maintain the aspect ratio 67 | if isinstance(img_size, tuple): 68 | self.img_size = img_size 69 | self.init_size = tuple(int(4 * x / max(img_size)) for x in self.img_size) 70 | else: 71 | self.img_size = (img_size, img_size) 72 | self.init_size = (4, 4) 73 | 74 | self.latent_size = latent_size 75 | self.condition_size = condition_size 76 | self.aux_size = aux_size 77 | 78 | self.rnn_noise = None 79 | if self.aux_size > 0 and sequential_noise: 80 | self.rnn_noise = nn.GRU(self.aux_size, self.aux_size, batch_first=True) 81 | self.rnn_noise_squashing = nn.Tanh() 82 | 83 | self.num_layers = int(np.log2(max(self.img_size))) - 1 84 | self.num_channels = num_channels 85 | self.num_gen_channels = num_gen_channels 86 | 87 | self.dcl = nn.ModuleList() 88 | 89 | self.aux_only_on_top = aux_only_on_top 90 | self.total_latent_size = self.latent_size + self.condition_size 91 | 92 | if self.aux_size > 0 and self.aux_only_on_top: 93 | self.aux_dcl = nn.Sequential( 94 | nn.ConvTranspose2d(self.aux_size, num_gen_channels, (self.init_size[0] // 2, self.init_size[1]), 95 | bias=False), 96 | nn.BatchNorm2d(num_gen_channels), 97 | nn.ReLU(inplace=True), 98 | nn.ConstantPad2d((0, 0, 0, self.init_size[0] // 2), 0)) 99 | else: 100 | self.total_latent_size += self.aux_size 101 | 102 | stride = 2 103 | if batch_norm: 104 | self.dcl.append( 105 | nn.Sequential( 106 | nn.ConvTranspose2d(self.total_latent_size, num_gen_channels, self.init_size, bias=False), 107 | nn.BatchNorm2d(num_gen_channels), 108 | nn.ReLU(inplace=True))) 109 | else: 110 | self.dcl.append( 111 | nn.Sequential( 112 | nn.ConvTranspose2d(self.total_latent_size, num_gen_channels, self.init_size, bias=False), 113 | nn.ReLU(inplace=True))) 114 | 115 | num_input_channels = self.num_gen_channels 116 | in_size = self.init_size 117 | for i in range(self.num_layers - 2): 118 | if not skip_channels: 119 | self.dcl.append(Deconv(num_input_channels, num_input_channels // 2, in_size, kernel_size, stride=stride, 120 | batch_norm=batch_norm)) 121 | else: 122 | self.dcl.append( 123 | UnetBlock(num_input_channels, num_input_channels // 2, skip_channels[i], in_size, 124 | kernel_size, stride=stride, batch_norm=batch_norm)) 125 | 126 | num_input_channels //= 2 127 | in_size = tuple(2 * x for x in in_size) 128 | 129 | padding = calculate_padding(kernel_size, stride) 130 | self.dcl.append(nn.ConvTranspose2d(num_input_channels, self.num_channels, kernel_size, 131 | stride=stride, padding=padding // 2, bias=False)) 132 | self.final_activation = nn.Tanh() 133 | 134 | def forward(self, x, c=None, aux=None, skip=[]): 135 | if aux is not None: 136 | if self.rnn_noise is not None: 137 | aux, h = self.rnn_noise(aux) 138 | aux = self.rnn_noise_squashing(aux) 139 | 140 | if self.aux_only_on_top: 141 | aux = self.aux_dcl(aux.view(-1, self.aux_size, 1, 1)) 142 | else: 143 | x = torch.cat([x, aux], 2) 144 | 145 | if c is not None: 146 | x = torch.cat([x, c], 2) 147 | 148 | x = x.view(-1, self.total_latent_size, 1, 1) 149 | x = self.dcl[0](x) 150 | 151 | if self.aux_only_on_top: 152 | x = x + aux 153 | 154 | if not skip: 155 | for i in range(1, self.num_layers - 1): 156 | x = self.dcl[i](x) 157 | else: 158 | for i in range(1, self.num_layers - 1): 
159 | x = self.dcl[i](x, skip[i - 1]) 160 | 161 | x = self.dcl[-1](x, output_size=[-1, 3, self.img_size[0], self.img_size[1]]) 162 | return self.final_activation(x) 163 | -------------------------------------------------------------------------------- /sda/rnn_audio.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .encoder_audio import Encoder 3 | 4 | 5 | class RNN(nn.Module): 6 | def __init__(self, feat_length, enc_code_size, rnn_code_size, rate, n_layers=2, init_kernel=None, 7 | init_stride=None): 8 | super(RNN, self).__init__() 9 | self.audio_feat_samples = int(rate * feat_length) 10 | self.enc_code_size = enc_code_size 11 | self.rnn_code_size = rnn_code_size 12 | self.encoder = Encoder(self.enc_code_size, rate, feat_length, init_kernel=init_kernel, 13 | init_stride=init_stride) 14 | self.rnn = nn.GRU(self.enc_code_size, self.rnn_code_size, n_layers, batch_first=True) 15 | 16 | def forward(self, x, lengths): 17 | seq_length = x.size()[1] 18 | x = x.view(-1, 1, self.audio_feat_samples) 19 | x = self.encoder(x) 20 | x = x.view(-1, seq_length, self.enc_code_size) 21 | x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True) 22 | x, h = self.rnn(x) 23 | x, lengths = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) 24 | return x.contiguous() 25 | -------------------------------------------------------------------------------- /sda/sda.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import torch 3 | from .encoder_image import Encoder 4 | from .img_generator import Generator 5 | from .rnn_audio import RNN 6 | 7 | from scipy import signal 8 | from skimage import transform as tf 9 | import numpy as np 10 | from PIL import Image 11 | import contextlib 12 | import os 13 | import shutil 14 | import tempfile 15 | import skvideo.io as sio 16 | import scipy.io.wavfile as wav 17 | import ffmpeg 18 | import face_alignment 19 | from pydub import AudioSegment 20 | from pydub.utils import mediainfo 21 | 22 | 23 | @contextlib.contextmanager 24 | def cd(newdir, cleanup=lambda: True): 25 | prevdir = os.getcwd() 26 | os.chdir(os.path.expanduser(newdir)) 27 | try: 28 | yield 29 | finally: 30 | os.chdir(prevdir) 31 | cleanup() 32 | 33 | 34 | @contextlib.contextmanager 35 | def tempdir(): 36 | dirpath = tempfile.mkdtemp() 37 | 38 | def cleanup(): 39 | shutil.rmtree(dirpath) 40 | 41 | with cd(dirpath, cleanup): 42 | yield dirpath 43 | 44 | 45 | def get_audio_feature_extractor(model_path="grid", gpu=-1): 46 | if model_path == "grid": 47 | model_path = os.path.join(os.path.split(__file__)[0], "data", "grid.dat") 48 | elif model_path == "timit": 49 | model_path = os.path.join(os.path.split(__file__)[0], "data", "timit.dat") 50 | elif model_path == "crema": 51 | model_path = os.path.join(os.path.split(__file__)[0], "data", "crema.dat") 52 | 53 | if gpu < 0: 54 | device = torch.device("cpu") 55 | model_dict = torch.load(model_path, map_location=lambda storage, loc: storage) 56 | else: 57 | device = torch.device("cuda:" + str(gpu)) 58 | model_dict = torch.load(model_path, map_location=lambda storage, loc: storage.cuda(gpu)) 59 | 60 | audio_rate = model_dict["audio_rate"] 61 | audio_feat_len = model_dict['audio_feat_len'] 62 | rnn_gen_dim = model_dict['rnn_gen_dim'] 63 | aud_enc_dim = model_dict['aud_enc_dim'] 64 | video_rate = model_dict["video_rate"] 65 | 66 | encoder = RNN(audio_feat_len, aud_enc_dim, rnn_gen_dim, audio_rate, init_kernel=0.005, 
init_stride=0.001) 67 | encoder.to(device) 68 | encoder.load_state_dict(model_dict['encoder']) 69 | 70 | overlap = audio_feat_len - 1.0 / video_rate 71 | return encoder, {"rate": audio_rate, "feature length": audio_feat_len, "overlap": overlap} 72 | 73 | 74 | def cut_audio_sequence(seq, feature_length, overlap, rate): 75 | seq = seq.view(-1, 1) 76 | snip_length = int(feature_length * rate) 77 | cutting_stride = int((feature_length - overlap) * rate) 78 | pad_samples = snip_length - cutting_stride 79 | 80 | pad_left = torch.zeros(pad_samples // 2, 1, device=seq.device) 81 | pad_right = torch.zeros(pad_samples - pad_samples // 2, 1, device=seq.device) 82 | 83 | seq = torch.cat((pad_left, seq), 0) 84 | seq = torch.cat((seq, pad_right), 0) 85 | 86 | stacked = seq.narrow(0, 0, snip_length).unsqueeze(0) 87 | iterations = (seq.size()[0] - snip_length) // cutting_stride + 1 88 | for i in range(1, iterations): 89 | stacked = torch.cat((stacked, seq.narrow(0, i * cutting_stride, snip_length).unsqueeze(0))) 90 | return stacked 91 | 92 | 93 | class VideoAnimator(): 94 | def __init__(self, model_path="grid", gpu=-1): 95 | 96 | if model_path == "grid": 97 | model_path = os.path.join(os.path.split(__file__)[0], "data", "grid.dat") 98 | elif model_path == "timit": 99 | model_path = os.path.join(os.path.split(__file__)[0], "data", "timit.dat") 100 | elif model_path == "crema": 101 | model_path = os.path.join(os.path.split(__file__)[0], "data", "crema.dat") 102 | 103 | if gpu < 0: 104 | self.device = torch.device("cpu") 105 | model_dict = torch.load(model_path, map_location=lambda storage, loc: storage) 106 | self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, device="cpu", flip_input=False) 107 | else: 108 | self.device = torch.device("cuda:" + str(gpu)) 109 | model_dict = torch.load(model_path, map_location=lambda storage, loc: storage.cuda(gpu)) 110 | self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, device="cuda:" + str(gpu), 111 | flip_input=False) 112 | 113 | self.stablePntsIDs = [33, 36, 39, 42, 45] 114 | self.mean_face = model_dict["mean_face"] 115 | self.img_size = model_dict["img_size"] 116 | self.audio_rate = model_dict["audio_rate"] 117 | self.video_rate = model_dict["video_rate"] 118 | self.audio_feat_len = model_dict['audio_feat_len'] 119 | self.audio_feat_samples = model_dict['audio_feat_samples'] 120 | self.id_enc_dim = model_dict['id_enc_dim'] 121 | self.rnn_gen_dim = model_dict['rnn_gen_dim'] 122 | self.aud_enc_dim = model_dict['aud_enc_dim'] 123 | self.aux_latent = model_dict['aux_latent'] 124 | self.sequential_noise = model_dict['sequential_noise'] 125 | self.conversion_dict = {'s16': np.int16, 's32': np.int32} 126 | 127 | self.img_transform = transforms.Compose([ 128 | transforms.ToPILImage(), 129 | transforms.Resize((self.img_size[0], self.img_size[1])), 130 | transforms.ToTensor(), 131 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 132 | 133 | self.encoder = RNN(self.audio_feat_len, self.aud_enc_dim, self.rnn_gen_dim, 134 | self.audio_rate, init_kernel=0.005, init_stride=0.001) 135 | self.encoder.to(self.device) 136 | self.encoder.load_state_dict(model_dict['encoder']) 137 | 138 | self.encoder_id = Encoder(self.id_enc_dim, self.img_size) 139 | self.encoder_id.to(self.device) 140 | self.encoder_id.load_state_dict(model_dict['encoder_id']) 141 | 142 | skip_channels = list(self.encoder_id.channels) 143 | skip_channels.reverse() 144 | 145 | self.generator = Generator(self.img_size, self.rnn_gen_dim, 
condition_size=self.id_enc_dim, 146 | num_gen_channels=self.encoder_id.channels[-1], 147 | skip_channels=skip_channels, aux_size=self.aux_latent, 148 | sequential_noise=self.sequential_noise) 149 | 150 | self.generator.to(self.device) 151 | self.generator.load_state_dict(model_dict['generator']) 152 | 153 | self.encoder.eval() 154 | self.encoder_id.eval() 155 | self.generator.eval() 156 | 157 | def save_video(self, video, audio, path, overwrite=True, experimental_ffmpeg=False, scale=None): 158 | if not os.path.isabs(path): 159 | path = os.path.join(os.getcwd(), path) 160 | 161 | with tempdir() as dirpath: 162 | # Save the video file 163 | writer = sio.FFmpegWriter(os.path.join(dirpath, "tmp.avi"), 164 | inputdict={'-r': str(self.video_rate) + "/1", }, 165 | outputdict={'-r': str(self.video_rate) + "/1", } 166 | ) 167 | for i in range(video.shape[0]): 168 | frame = np.rollaxis(video[i, :, :, :], 0, 3) 169 | 170 | if scale is not None: 171 | frame = tf.rescale(frame, scale, anti_aliasing=True, multichannel=True, mode='reflect') 172 | 173 | writer.writeFrame(frame) 174 | writer.close() 175 | 176 | # Save the audio file 177 | wav.write(os.path.join(dirpath, "tmp.wav"), self.audio_rate, audio) 178 | 179 | in1 = ffmpeg.input(os.path.join(dirpath, "tmp.avi")) 180 | in2 = ffmpeg.input(os.path.join(dirpath, "tmp.wav")) 181 | if experimental_ffmpeg: 182 | out = ffmpeg.output(in1['v'], in2['a'], path, strict='-2', loglevel="panic") 183 | else: 184 | out = ffmpeg.output(in1['v'], in2['a'], path, loglevel="panic") 185 | 186 | if overwrite: 187 | out = out.overwrite_output() 188 | out.run() 189 | 190 | def preprocess_img(self, img): 191 | src = self.fa.get_landmarks(img)[0][self.stablePntsIDs, :] 192 | dst = self.mean_face[self.stablePntsIDs, :] 193 | tform = tf.estimate_transform('similarity', src, dst) # find the transformation matrix 194 | warped = tf.warp(img, inverse_map=tform.inverse, output_shape=self.img_size) # wrap the frame image 195 | warped = warped * 255 # note output from wrap is double image (value range [0,1]) 196 | warped = warped.astype('uint8') 197 | 198 | return warped 199 | 200 | def _cut_sequence_(self, seq, cutting_stride, pad_samples): 201 | pad_left = torch.zeros(pad_samples // 2, 1) 202 | pad_right = torch.zeros(pad_samples - pad_samples // 2, 1) 203 | 204 | seq = torch.cat((pad_left, seq), 0) 205 | seq = torch.cat((seq, pad_right), 0) 206 | 207 | stacked = seq.narrow(0, 0, self.audio_feat_samples).unsqueeze(0) 208 | iterations = (seq.size()[0] - self.audio_feat_samples) // cutting_stride + 1 209 | for i in range(1, iterations): 210 | stacked = torch.cat((stacked, seq.narrow(0, i * cutting_stride, self.audio_feat_samples).unsqueeze(0))) 211 | return stacked.to(self.device) 212 | 213 | def _broadcast_elements_(self, batch, repeat_no): 214 | total_tensors = [] 215 | for i in range(0, batch.size()[0]): 216 | total_tensors += [torch.stack(repeat_no * [batch[i]])] 217 | 218 | return torch.stack(total_tensors) 219 | 220 | def __call__(self, img, audio, fs=None, aligned=False): 221 | if isinstance(img, str): # if we have a path then grab the image 222 | frm = Image.open(img) 223 | frm.thumbnail((400, 400)) 224 | frame = np.array(frm) 225 | else: 226 | frame = img 227 | 228 | if not aligned: 229 | frame = self.preprocess_img(frame) 230 | 231 | if isinstance(audio, str): # if we have a path then grab the audio clip 232 | info = mediainfo(audio) 233 | fs = int(info['sample_rate']) 234 | audio = np.array(AudioSegment.from_file(audio, 
info['format_name']).set_channels(1).get_array_of_samples()) 235 | 236 | if info['sample_fmt'] in self.conversion_dict: 237 | audio = audio.astype(self.conversion_dict[info['sample_fmt']]) 238 | else: 239 | if max(audio) > np.iinfo(np.int16).max: 240 | audio = audio.astype(np.int32) 241 | else: 242 | audio = audio.astype(np.int16) 243 | 244 | if fs is None: 245 | raise AttributeError("Audio provided without specifying the rate. Specify rate or use audio file!") 246 | 247 | if audio.ndim > 1 and audio.shape[1] > 1: 248 | audio = audio[:, 0] 249 | 250 | max_value = np.iinfo(audio.dtype).max 251 | if fs != self.audio_rate: 252 | seq_length = audio.shape[0] 253 | speech = torch.from_numpy( 254 | signal.resample(audio, int(seq_length * self.audio_rate / float(fs))) / float(max_value)).float() 255 | speech = speech.view(-1, 1) 256 | else: 257 | audio = torch.from_numpy(audio / float(max_value)).float() 258 | speech = audio.view(-1, 1) 259 | 260 | frame = self.img_transform(frame).to(self.device) 261 | 262 | cutting_stride = int(self.audio_rate / float(self.video_rate)) 263 | audio_seq_padding = self.audio_feat_samples - cutting_stride 264 | 265 | # Create new sequences of the audio windows 266 | audio_feat_seq = self._cut_sequence_(speech, cutting_stride, audio_seq_padding) 267 | frame = frame.unsqueeze(0) 268 | audio_feat_seq = audio_feat_seq.unsqueeze(0) 269 | audio_feat_seq_length = audio_feat_seq.size()[1] 270 | 271 | z = self.encoder(audio_feat_seq, [audio_feat_seq_length]) # Encoding for the motion 272 | noise = torch.FloatTensor(1, audio_feat_seq_length, self.aux_latent).normal_(0, 0.33).to(self.device) 273 | z_id, skips = self.encoder_id(frame, retain_intermediate=True) 274 | skip_connections = [] 275 | for skip_variable in skips: 276 | skip_connections.append(self._broadcast_elements_(skip_variable, z.size()[1])) 277 | skip_connections.reverse() 278 | 279 | z_id = self._broadcast_elements_(z_id, z.size()[1]) 280 | gen_video = self.generator(z, c=z_id, aux=noise, skip=skip_connections) 281 | 282 | returned_audio = ((2 ** 15) * speech.detach().cpu().numpy()).astype(np.int16) 283 | gen_video = 125 * gen_video.squeeze().detach().cpu().numpy() + 125 284 | return gen_video, returned_audio 285 | -------------------------------------------------------------------------------- /sda/utils.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | 4 | def prime_factors(number): 5 | factor = 2 6 | factors = [] 7 | while factor * factor <= number: 8 | if number % factor: 9 | factor += 1 10 | else: 11 | number //= factor 12 | factors.append(int(factor)) 13 | if number > 1: 14 | factors.append(int(number)) 15 | return factors 16 | 17 | 18 | def calculate_padding(kernel_size, stride=1, in_size=0): 19 | out_size = ceil(float(in_size) / float(stride)) 20 | return int((out_size - 1) * stride + kernel_size - in_size) 21 | 22 | 23 | def calculate_output_size(in_size, kernel_size, stride, padding): 24 | return int((in_size + padding - kernel_size) / stride) + 1 25 | 26 | 27 | def is_power2(num): 28 | return num != 0 and ((num & (num - 1)) == 0) 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='sda', 4 | version='0.2', 5 | description='Produces speech-driven faces', 6 | packages=['sda'], 7 | package_dir={'sda': 'sda'}, 8 | package_data={'sda': ['data/*.dat']}, 9 | 
install_requires=[ 10 | 'numpy', 11 | 'scipy', 12 | 'scikit-video', 13 | 'scikit-image', 14 | 'ffmpeg-python', 15 | 'torch', 16 | 'face-alignment', 17 | 'torchvision', 18 | 'pydub', 19 | ], 20 | zip_safe=False) 21 | --------------------------------------------------------------------------------