├── requirements.txt ├── config_grid.py ├── config_koeba.py ├── config_obama.py ├── modules ├── audio_encoder.py ├── eyes_encoder.py ├── audio_encoder_bn.py ├── utils.py ├── hourglass.py ├── blocks.py ├── head_predictor.py ├── layers.py └── generator.py ├── model.py ├── test.py ├── .gitignore ├── README.md ├── inference.py └── loader.py /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa==0.9.1 2 | av==8.0.3 3 | -------------------------------------------------------------------------------- /config_grid.py: -------------------------------------------------------------------------------- 1 | from modules.audio_encoder import AudioEncoder 2 | 3 | params = { 4 | 'fps': 25, 5 | 'samplerate': 22050, 6 | 'weight_path': './weight/grid.pt', 7 | 'model_params': { 8 | 'head_predictor': { 9 | 'num_affines': 1, 10 | 'using_scale': False 11 | }, 12 | 'generator': { 13 | 'num_affines': 1, 14 | 'num_residual_mod_blocks': 6, 15 | 'using_gaussian': False 16 | }, 17 | 'audio_encoder': AudioEncoder 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /config_koeba.py: -------------------------------------------------------------------------------- 1 | from modules.audio_encoder_bn import AudioEncoder 2 | ​ 3 | params = { 4 | 'fps': 29.97, 5 | 'samplerate': 22050, 6 | 'weight_path': './weight/koeba.pt', 7 | 'model_params': { 8 | 'head_predictor': { 9 | 'num_affines': 1, 10 | 'using_scale': True 11 | }, 12 | 'generator': { 13 | 'num_affines': 1, 14 | 'num_residual_mod_blocks': 6, 15 | 'using_gaussian': False 16 | }, 17 | 'audio_encoder': AudioEncoder 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /config_obama.py: -------------------------------------------------------------------------------- 1 | from modules.audio_encoder import AudioEncoder 2 | 3 | params = { 4 | 'fps': 29.97, 5 | 'samplerate': 22050, 6 | 'weight_path': './weight/obama.pt', 7 | 'model_params': { 8 | 'head_predictor': { 9 | 'num_affines': 1, 10 | 'using_scale': True 11 | }, 12 | 'generator': { 13 | 'num_affines': 1, 14 | 'num_residual_mod_blocks': 6, 15 | 'using_gaussian': False 16 | }, 17 | 'audio_encoder': AudioEncoder 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /modules/audio_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .layers import conv1d, lstm, linear 5 | 6 | 7 | class AudioEncoder(nn.Module): 8 | def __init__(self): 9 | super(AudioEncoder, self).__init__() 10 | self.encoder = nn.Sequential( 11 | conv1d(192, 512, 3, 1, 1), nn.ReLU(), 12 | conv1d(512, 512, 3, 1, 1), nn.ReLU(), 13 | conv1d(512, 512, 3, 1, 1), nn.ReLU(), 14 | conv1d(512, 512, 3, 1, 1), nn.ReLU()) 15 | 16 | self.rnn = lstm(512, 512) 17 | 18 | self.decoder = nn.Sequential( 19 | linear(512, 256), nn.ReLU(), 20 | linear(256, 256)) 21 | 22 | 23 | def forward(self, x): 24 | x = self.encoder(x) # (n, c, l) 25 | x, _ = self.rnn(x.permute(0, 2, 1)) # (n, l, c) 26 | x = self.decoder(x[:,-1]) # (n, c) 27 | return x 28 | 29 | -------------------------------------------------------------------------------- /modules/eyes_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .layers import linear 5 | from .utils import AntiAliasInterpolation2d 6 | from .blocks import 
down_block2d 7 | 8 | 9 | class EyesEncoder(nn.Module): 10 | def __init__(self): 11 | super(EyesEncoder, self).__init__() 12 | self.down = AntiAliasInterpolation2d(3, 0.5) # 1/2 resolution 13 | 14 | self.encoder = nn.Sequential( 15 | down_block2d(3, 32), 16 | down_block2d(32, 64), 17 | down_block2d(64, 64), 18 | down_block2d(64, 64), 19 | down_block2d(64, 64), 20 | nn.AvgPool2d(kernel_size=(4, 4))) 21 | 22 | self.predictor = nn.Sequential( 23 | linear(64, 32), nn.ReLU(), 24 | linear(32, 8)) 25 | 26 | 27 | def forward(self, x): 28 | # x: (n, c, 32, 64) 29 | x = self.down(x) 30 | x = self.encoder(x) 31 | x = self.predictor(x[...,0,0]) 32 | return x 33 | -------------------------------------------------------------------------------- /modules/audio_encoder_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .layers import conv1d, conv2d, linear 5 | 6 | 7 | class AudioEncoder(nn.Module): 8 | def __init__(self): 9 | super(AudioEncoder, self).__init__() 10 | self.encoder = nn.Sequential( 11 | conv1d(192, 256, 7, 1, 0), nn.BatchNorm1d(256), nn.ReLU(), 12 | conv1d(256, 256, 5, 1, 0), nn.BatchNorm1d(256), nn.ReLU(), 13 | conv1d(256, 256, 5, 1, 0), nn.BatchNorm1d(256), nn.ReLU(), 14 | conv1d(256, 256, 5, 1, 0), nn.BatchNorm1d(256), nn.ReLU(), 15 | conv1d(256, 256, 5, 1, 0), nn.BatchNorm1d(256), nn.ReLU()) 16 | 17 | self.predictor = nn.Sequential( 18 | linear(256*9, 256), nn.ReLU(), 19 | linear(256, 256), nn.ReLU(), 20 | linear(256, 256), nn.ReLU(), 21 | linear(256, 256), nn.ReLU(), 22 | linear(256, 256), nn.ReLU(), 23 | linear(256, 256)) 24 | 25 | 26 | def forward(self, x): 27 | x = self.encoder(x) 28 | x = x.view(x.shape[0], -1) 29 | x = self.predictor(x) 30 | return x 31 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from modules.utils import AntiAliasInterpolation2d 4 | from modules.head_predictor import HeadPredictor 5 | from modules.eyes_encoder import EyesEncoder 6 | from modules.generator import Generator 7 | 8 | 9 | class Model(nn.Module): 10 | def __init__(self, params): 11 | super(Model, self).__init__() 12 | self.down_sampler = AntiAliasInterpolation2d(3, 0.25) # 1/4 resolution 13 | self.head_predictor = HeadPredictor(**params['head_predictor']) 14 | self.eyes_encoder = EyesEncoder() 15 | self.audio_encoder = params['audio_encoder']() 16 | self.generator = Generator(**params['generator']) 17 | 18 | 19 | def forward(self, src, drv, eye, spec): 20 | src_down = self.down_sampler(src) 21 | drv_down = self.down_sampler(drv) 22 | 23 | src_head = self.head_predictor(src_down) 24 | drv_head = self.head_predictor(drv_down) 25 | 26 | drv_eyes = self.eyes_encoder(eye) 27 | 28 | drv_audio = self.audio_encoder(spec) 29 | 30 | generator_out = self.generator(src, src_head, drv_head, drv_eyes, drv_audio) 31 | 32 | return generator_out['prediction'] 33 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import argparse 3 | 4 | from inference import inference 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('-m', '--mode', required=True) 8 | args = parser.parse_args() 9 | 10 | mode = args.mode 11 | 12 | # parsing mode 13 | if mode == 'obama_demo1': 14 | from config_obama import params 
15 | params['mode'] = mode 16 | params['dataset_root'] = './dataset/obama/demo1' 17 | params['save_path'] = 'obama_demo1.mp4' 18 | elif mode == 'obama_demo2': 19 | from config_obama import params 20 | params['mode'] = mode 21 | params['dataset_root'] = './dataset/obama/demo2' 22 | params['save_path'] = 'obama_demo2.mp4' 23 | elif mode == 'grid_demo1': 24 | from config_grid import params 25 | params['mode'] = mode 26 | params['dataset_root'] = './dataset/grid/demo1' 27 | params['save_path'] = 'grid_demo1.mp4' 28 | elif mode == 'grid_demo2': 29 | from config_grid import params 30 | params['mode'] = mode 31 | params['dataset_root'] = './dataset/grid/demo2' 32 | params['save_path'] = 'grid_demo2.mp4' 33 | elif mode == 'koeba_demo1': 34 | from config_koeba import params 35 | params['mode'] = mode 36 | params['dataset_root'] = './dataset/koeba/demo1' 37 | params['save_path'] = 'koeba_demo1.mp4' 38 | elif mode == 'koeba_demo2': 39 | from config_koeba import params 40 | params['mode'] = mode 41 | params['dataset_root'] = './dataset/koeba/demo2' 42 | params['save_path'] = 'koeba_demo2.mp4' 43 | else: 44 | raise Exception('mode: [obama_demo1|obama_demo2|grid_demo1|grid_demo2|koeba_demo1|koeba_demo2]') 45 | 46 | inference(params) 47 | -------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def make_coordinate_grid(h, w): 7 | theta = torch.eye(2,3).unsqueeze(0) 8 | grid = F.affine_grid(theta, (1, 1, h, w), align_corners=False)[0] 9 | return grid 10 | 11 | 12 | class AntiAliasInterpolation2d(nn.Module): 13 | """ 14 | Band-limited downsampling, for better preservation of the input signal. 15 | """ 16 | 17 | def __init__(self, channels, scale): 18 | super(AntiAliasInterpolation2d, self).__init__() 19 | sigma = (1 / scale - 1) / 2 20 | kernel_size = 2 * round(sigma * 4) + 1 21 | self.ka = kernel_size // 2 22 | self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka 23 | 24 | kernel_size = [kernel_size, kernel_size] 25 | sigma = [sigma, sigma] 26 | # The gaussian kernel is the product of the 27 | # gaussian function of each dimension. 28 | kernel = 1 29 | meshgrids = torch.meshgrid( 30 | [ 31 | torch.arange(size, dtype=torch.float32) 32 | for size in kernel_size 33 | ] 34 | ) 35 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 36 | mean = (size - 1) / 2 37 | kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) 38 | 39 | # Make sure sum of values in gaussian kernel equals 1. 
40 | kernel = kernel / torch.sum(kernel) 41 | # Reshape to depthwise convolutional weight 42 | kernel = kernel.view(1, 1, *kernel.size()) 43 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 44 | 45 | self.register_buffer('weight', kernel) 46 | self.groups = channels 47 | self.scale = scale 48 | inv_scale = 1 / scale 49 | self.int_inv_scale = int(inv_scale) 50 | 51 | def forward(self, input): 52 | if self.scale == 1.0: 53 | return input 54 | 55 | out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) 56 | out = F.conv2d(out, weight=self.weight, groups=self.groups) 57 | out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale] 58 | 59 | return out 60 | 61 | -------------------------------------------------------------------------------- /modules/hourglass.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .blocks import down_block2d, up_block2d 6 | 7 | 8 | class Encoder(nn.Module): 9 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): 10 | super(Encoder, self).__init__() 11 | 12 | down_blocks = [] 13 | for i in range(num_blocks): 14 | if i == 0: 15 | in_channels = in_features 16 | else: 17 | in_channels = min(max_features, block_expansion * (2 ** i)) 18 | out_channels = min(max_features, block_expansion * (2 ** (i + 1))) 19 | down_blocks.append(down_block2d(in_channels, out_channels)) 20 | 21 | self.down_blocks = nn.ModuleList(down_blocks) 22 | 23 | def forward(self, x): 24 | outs = [x] 25 | for down_block in self.down_blocks: 26 | outs.append(down_block(outs[-1])) 27 | return outs 28 | 29 | 30 | class Decoder(nn.Module): 31 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): 32 | super(Decoder, self).__init__() 33 | 34 | up_blocks = [] 35 | for i in range(num_blocks)[::-1]: 36 | in_channels = min(max_features, block_expansion * (2 ** (i + 1))) 37 | if i != num_blocks - 1: 38 | in_channels = 2 * in_channels 39 | out_channels = min(max_features, block_expansion * (2 ** i)) 40 | up_blocks.append(up_block2d(in_channels, out_channels)) 41 | 42 | self.up_blocks = nn.ModuleList(up_blocks) 43 | self.out_filters = block_expansion + in_features 44 | 45 | def forward(self, x): 46 | out = x.pop() 47 | for up_block in self.up_blocks: 48 | out = up_block(out) 49 | skip = x.pop() 50 | out = torch.cat([out, skip], dim=1) 51 | return out 52 | 53 | 54 | class Hourglass(nn.Module): 55 | def __init__( 56 | self, 57 | block_expansion=32, 58 | in_features=3, 59 | num_blocks=5, 60 | max_features=1024 61 | ): 62 | super(Hourglass, self).__init__() 63 | self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) 64 | self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) 65 | self.out_filters = self.decoder.out_filters 66 | 67 | def forward(self, x): 68 | return self.decoder(self.encoder(x)) 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 
27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /modules/blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from .layers import linear, conv2d, conv3d, mod_conv2d 5 | 6 | 7 | activations = { 8 | 'relu': nn.ReLU(), 9 | 'leaky_relu': nn.LeakyReLU() 10 | } 11 | 12 | class conv2d_bn_relu(nn.Module): 13 | def __init__( 14 | self, 15 | in_channels, 16 | out_channels, 17 | kernel_size=3, 18 | stride=1, 19 | padding=1, 20 | activation='relu' 21 | ): 22 | super(conv2d_bn_relu, self).__init__() 23 | self.conv = conv2d(in_channels, out_channels, kernel_size, stride, padding) 24 | self.bn = nn.BatchNorm2d(out_channels, affine=True) 25 | self.act = activations[activation] 26 | 27 | def forward(self, x): 28 | return self.act(self.bn(self.conv(x))) 29 | 30 | 31 | class down_block2d(nn.Module): 32 | def __init__(self, in_channels, out_channels): 33 | super(down_block2d, self).__init__() 34 | self.conv = conv2d_bn_relu(in_channels, out_channels) 35 | self.pool = nn.AvgPool2d(kernel_size=(2, 2)) 36 | 37 | def forward(self, x): 38 | return self.pool(self.conv(x)) 39 | 40 | 41 | class res_block2d(nn.Module): 42 | def __init__(self, in_channels): 43 | super(res_block2d, self).__init__() 44 | self.block = nn.Sequential( 45 | nn.BatchNorm2d(in_channels, affine=True), 46 | nn.ReLU(), 47 | conv2d(in_channels, in_channels, 3, 1, 1), 48 | nn.BatchNorm2d(in_channels, affine=True), 49 | 
nn.ReLU(), 50 | conv2d(in_channels, in_channels, 3, 1, 1)) 51 | 52 | def forward(self, x): 53 | return x + self.block(x) 54 | 55 | 56 | class res_mod_block2d(nn.Module): 57 | def __init__(self, in_channels, in_features): 58 | super(res_mod_block2d, self).__init__() 59 | self.block = nn.ModuleList([ 60 | mod_conv2d(in_channels, in_channels, 3, 1, in_features), 61 | mod_conv2d(in_channels, in_channels, 3, 1, in_features)]) 62 | 63 | def forward(self, x, y): 64 | r = x 65 | for i in range(len(self.block)): 66 | r = self.block[i](r, y) 67 | return x + r 68 | 69 | 70 | class up_block2d(nn.Module): 71 | def __init__(self, in_channels, out_channels): 72 | super(up_block2d, self).__init__() 73 | self.conv = conv2d_bn_relu(in_channels, out_channels) 74 | 75 | def forward(self, x): 76 | return self.conv(F.interpolate(x, scale_factor=2)) 77 | 78 | -------------------------------------------------------------------------------- /modules/head_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .hourglass import Hourglass 6 | from .layers import conv2d 7 | from .utils import make_coordinate_grid 8 | 9 | 10 | class HeadPredictor(nn.Module): 11 | def __init__(self, num_affines, using_scale): 12 | super(HeadPredictor, self).__init__() 13 | self.using_scale = using_scale 14 | 15 | self.extractor = Hourglass() 16 | self.predictor = conv2d(self.extractor.out_filters, num_affines, 7, 1, 3) # linear activation 17 | 18 | self.register_buffer('grid', make_coordinate_grid(64, 64)) # (h, w, 2) 19 | self.register_buffer('identity', torch.diag(torch.ones(3))) # (3, 3) 20 | 21 | 22 | def forward(self, x): 23 | x = self.predictor(self.extractor(x)) # (n, 1, h, w) 24 | 25 | # convert feature to heatmap 26 | n, c, h, w = x.shape 27 | x = x.view(n, c, h*w) # flatten spatially 28 | heatmap = F.softmax(x, dim=2) 29 | heatmap = heatmap.view(n, c, h, w) # recover shape: (n, c, h, w) 30 | 31 | # compute statistics of heatmap 32 | mean = (self.grid * heatmap[...,None]).sum(dim=(2, 3)) # (n, c, 2) 33 | deviation = self.grid - mean[:,:,None,None] # (n, c, h, w, 2) 34 | covar = torch.matmul(deviation[...,None], deviation[...,None,:]) # (n, c, h, w, 2, 2) 35 | covar = (covar * heatmap[...,None,None]).sum(dim=(2, 3)) # (n, c, 2, 2) 36 | 37 | # SVD for extract affine from covariance matrix 38 | U, S, _ = torch.svd(covar.cpu()) 39 | affine = U.to(covar.device) # rotation matrix: (n, c, 2, 2) 40 | if self.using_scale: 41 | S = S.to(covar.device) # (n, c, 2) 42 | S = torch.diag_embed(S ** 0.5) # scale matrix: (n, c, 2, 2) 43 | affine = torch.matmul(affine, S) # (n, c, 2, 2) 44 | 45 | # add translation to affine matrix 46 | affine = torch.cat([affine, mean[...,None]], dim=3) # (n, c, 2, 3) 47 | homo_affine = self.identity[None].repeat(n, c, 1, 1) # (n, c, 3, 3) 48 | homo_affine[:,:,:2] = affine # (n, c, 3, 3) 49 | 50 | # convert heatmap to gaussian 51 | covar_inverse = torch.inverse(covar)[:,:,None,None] # (n, c, 1, 1, 2, 2) 52 | under_exp = torch.matmul(deviation[...,None,:], covar_inverse) # (n, c, h, w, 1, 2) 53 | under_exp = torch.matmul(under_exp, deviation[...,None]) # (n, c, h, w, 1, 1) 54 | under_exp = under_exp[...,0,0] # (n, c, h, w) 55 | gaussian = torch.exp(-0.5 * under_exp) # (n, c, h, w) 56 | 57 | outputs = { 58 | 'affine': homo_affine, # (n, c, 3, 3) 59 | 'heatmap': heatmap, # (n, c, h, w) 60 | 'gaussian': gaussian # (n, c, h, w) 61 | } 62 | return outputs 63 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DisCoHead: Audio-and-Video-Driven Talking Head Generation by Disentangled Control of Head Pose and Facial Expressions 2 | 3 | 4 | 5 |

16 | Project Page | KoEBA Dataset
20 |
21 |
22 | ## Requirements
23 |
24 | You can set up the required environment with the following commands:
25 |
26 | ```shell
27 | git clone https://github.com/deepbrainai-research/discohead
28 | cd discohead
29 | conda create -n discohead python=3.7
30 | conda activate discohead
31 | conda install pytorch==1.10.0 torchvision==0.11.1 torchaudio==0.10.0 cudatoolkit=10.2 -c pytorch
32 | pip install -r requirements.txt
33 | ```
34 |
35 | ## Generating Demo Videos
36 |
37 | - Download the pre-trained checkpoints from [Google Drive](https://drive.google.com/file/d/1ki8BsZ3Yg2i5OhHF04ULwtgFg6r5Tsro/view?usp=sharing) and put them into the `weight` folder.
38 | - Download `dataset.zip` from [Google Drive](https://drive.google.com/file/d/1xy9pxgQYrl2Bnee4npq88zdrHlIcX2wf/view?usp=sharing) and unzip it into the `dataset` folder.
39 | - The `DisCoHead` directory should have the following structure.
40 |
41 | ```
42 | DisCoHead/
43 | ├── dataset/
44 | │   ├── grid/
45 | │   │   ├── demo1/
46 | │   │   ├── demo2/
47 | │   ├── koeba/
48 | │   │   ├── demo1/
49 | │   │   ├── demo2/
50 | │   ├── obama/
51 | │   │   ├── demo1/
52 | │   │   ├── demo2/
53 | ├── weight/
54 | │   ├── grid.pt
55 | │   ├── koeba.pt
56 | │   ├── obama.pt
57 | ├── modules/
58 | ‥‥
59 |
60 | ```
61 | - The `--mode` argument specifies which demo video to generate (see the example after this list):
62 | ```shell
63 | python test.py --mode {mode}
64 | ```
65 | - Available modes: `obama_demo1, obama_demo2, grid_demo1, grid_demo2, koeba_demo1, koeba_demo2`
66 |
67 |
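For example, the following command renders the first Obama demo and writes the result to `obama_demo1.mp4` in the repository root (the output path is set in `test.py`):

```shell
python test.py --mode obama_demo1
```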

68 | ## License
69 |
70 |
71 |
72 | Creative Commons License
73 |
74 | This work is licensed under a Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License. You must not use this work for commercial purposes. You must not distribute modified versions of this material. You must give appropriate credit and provide a link to the license.
75 |

76 | 77 | 78 | 79 | ## Citation 80 | 81 | ```plain 82 | @INPROCEEDINGS{10095670, 83 | author={Hwang, Geumbyeol and Hong, Sunwon and Lee, Seunghyun and Park, Sungwoo and Chae, Gyeongsu}, 84 | booktitle={ICASSP 2023 - 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 85 | title={DisCoHead: Audio-and-Video-Driven Talking Head Generation by Disentangled Control of Head Pose and Facial Expressions}, 86 | year={2023}, 87 | volume={}, 88 | number={}, 89 | pages={1-5}, 90 | doi={10.1109/ICASSP49357.2023.10095670}} 91 | ``` 92 | 93 | -------------------------------------------------------------------------------- /modules/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class linear(nn.Module): 7 | def __init__(self, in_features, out_features): 8 | super(linear, self).__init__() 9 | self.linear = nn.Linear(in_features, out_features) 10 | # nn.init.xavier_uniform_(self.linear.weight) 11 | 12 | def forward(self, x): 13 | return self.linear(x) 14 | 15 | 16 | class conv1d(nn.Module): 17 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 18 | super(conv1d, self).__init__() 19 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding) 20 | # nn.init.xavier_uniform_(self.conv.weight) 21 | # nn.init.zeros_(self.conv.bias) 22 | 23 | def forward(self, x): 24 | return self.conv(x) 25 | 26 | 27 | class conv2d(nn.Module): 28 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 29 | super(conv2d, self).__init__() 30 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 31 | # nn.init.xavier_uniform_(self.conv.weight) 32 | # nn.init.zeros_(self.conv.bias) 33 | 34 | def forward(self, x): 35 | return self.conv(x) 36 | 37 | 38 | class conv3d(nn.Module): 39 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 40 | super(conv3d, self).__init__() 41 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding) 42 | # nn.init.xavier_uniform_(self.conv.weight) 43 | # nn.init.zeros_(self.conv.bias) 44 | 45 | def forward(self, x): 46 | return self.conv(x) 47 | 48 | 49 | class lstm(nn.Module): 50 | def __init__(self, input_size, hidden_size): 51 | super(lstm, self).__init__() 52 | self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True) 53 | 54 | # initialize 55 | # kernel: glorot_uniform 56 | # recurrent: orthogonal 57 | # bias: zeros 58 | for name, param in self.lstm.named_parameters(): 59 | if 'weight_ih' in name: 60 | for i in range(4): 61 | o = param.shape[0] // 4 62 | nn.init.xavier_uniform_(param[i*o:(i+1)*o]) 63 | elif 'weight_hh' in name: 64 | for i in range(4): 65 | o = param.shape[0] // 4 66 | nn.init.orthogonal_(param[i*o:(i+1)*o]) 67 | else: 68 | nn.init.zeros_(param) 69 | 70 | def forward(self, x): 71 | return self.lstm(x) 72 | 73 | 74 | class mod_conv2d(nn.Module): 75 | def __init__(self, in_channels, out_channels, kernel_size, padding, in_features): 76 | super(mod_conv2d, self).__init__() 77 | 78 | self.mod = linear(in_features, in_channels) 79 | nn.init.zeros_(self.mod.linear.weight) 80 | nn.init.constant_(self.mod.linear.bias, 1.0) 81 | 82 | self.bn = nn.BatchNorm2d(in_channels, affine=True) 83 | self.relu = nn.ReLU() 84 | 85 | weight = torch.zeros( 86 | (1, out_channels, in_channels, kernel_size, kernel_size), requires_grad=True) 87 | self.weight = nn.Parameter(weight) 88 | 
nn.init.xavier_uniform_(self.weight) 89 | 90 | self.padding = padding 91 | 92 | 93 | def forward(self, x, y): 94 | x = self.bn(x) 95 | x = self.relu(x) 96 | 97 | n, _, h, w = x.shape 98 | _, o, i, k, k = self.weight.shape 99 | 100 | # Modulate 101 | scale = self.mod(y).view(n, 1, i, 1, 1) 102 | weight = self.weight * scale # (n, o, i, k, k) 103 | 104 | # Demodulate 105 | demod = weight.pow(2).sum([2, 3, 4], keepdim=True) 106 | demod = torch.rsqrt(demod + 1e-8) # (n, o, 1, 1, 1) 107 | weight = weight * demod # (n, o, i, k, k) 108 | 109 | x = x.view(1, n * i, h, w) 110 | weight = weight.view(n * o, i, k, k) 111 | x = F.conv2d(x, weight, bias=None, padding=self.padding, groups=n) 112 | x = x.view(n, o, h, w) 113 | 114 | return x 115 | -------------------------------------------------------------------------------- /modules/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .blocks import conv2d_bn_relu, down_block2d, res_mod_block2d, up_block2d 6 | from .layers import conv2d 7 | from .utils import make_coordinate_grid 8 | 9 | 10 | class Generator(nn.Module): 11 | def __init__(self, num_affines, num_residual_mod_blocks, using_gaussian): 12 | super(Generator, self).__init__() 13 | self.using_gaussian = using_gaussian 14 | 15 | num_input_channel = (num_affines + 1) * 3 16 | if self.using_gaussian: 17 | num_input_channel += num_affines 18 | 19 | self.encoder = nn.Sequential( 20 | conv2d_bn_relu(num_input_channel, 64, 7, 1, 3), 21 | down_block2d(64, 128), 22 | down_block2d(128, 256), 23 | down_block2d(256, 512) 24 | ) 25 | self.mask_predictor = conv2d(512, num_affines+1, 5, 1, 2) 26 | 27 | self.feature_predictor = conv2d(512, 512, 5, 1, 2) 28 | self.occlusion_predictor = conv2d(512, 1, 5, 1, 2) 29 | 30 | bottleneck = [] 31 | for i in range(num_residual_mod_blocks): 32 | bottleneck.append(res_mod_block2d(512, 8 + 256)) 33 | self.bottleneck = nn.ModuleList(bottleneck) 34 | 35 | self.decoder = nn.Sequential( 36 | up_block2d(512, 256), 37 | up_block2d(256, 128), 38 | up_block2d(128, 64), 39 | conv2d(64, 3, 7, 1, 3) 40 | ) 41 | 42 | full_grid = make_coordinate_grid(256, 256) 43 | homo_full_grid = torch.cat([full_grid, torch.ones(256, 256, 1)], dim=2) 44 | self.register_buffer('homo_full_grid', homo_full_grid) # (256, 256, 3) 45 | 46 | grid = make_coordinate_grid(32, 32) 47 | homo_grid = torch.cat([grid, torch.ones(32, 32, 1)], dim=2) 48 | self.register_buffer('grid', grid) # (h, w, 2) 49 | self.register_buffer('homo_grid', homo_grid) # (h, w, 3) 50 | 51 | 52 | def forward(self, src, src_head, drv_head, drv_eyes, drv_audio): 53 | affine = torch.matmul(src_head['affine'], torch.inverse(drv_head['affine'])) # (n, c, 3, 3) 54 | affine = affine * torch.sign(affine[:,:,0:1,0:1]) # revert_axis_swap 55 | affine = affine[:,:,None,None] # (n, c, 1, 1, 3, 3) 56 | 57 | affine_motion = torch.matmul(affine, self.homo_full_grid[...,None]) # (n, c, h, w, 3, 1) 58 | affine_motion = affine_motion[...,:2,0] # (n, c, h, w, 2) 59 | n, c, h, w, _ = affine_motion.shape 60 | 61 | stacked_src = src.repeat(c, 1, 1, 1) 62 | flatten_affine_motion = affine_motion.view(n*c, h, w, 2) 63 | transformed_src = F.grid_sample(stacked_src, flatten_affine_motion, align_corners=False) 64 | transformed_src = transformed_src.view(n, c*3, h, w) 65 | 66 | # encoding source and tansformed source 67 | stacked_input = torch.cat([src, transformed_src], dim=1) 68 | if self.using_gaussian: 69 | gaussian = 
drv_head['gaussian'] - src_head['gaussian'] 70 | gaussian = F.interpolate(gaussian, scale_factor=4) 71 | stacked_input = torch.cat([stacked_input, gaussian], dim=1) 72 | x = self.encoder(stacked_input) 73 | 74 | # compute dense motion 75 | mask = F.softmax(self.mask_predictor(x), dim=1) # (n, c+1, h, w) 76 | dense_motion = torch.matmul(affine, self.homo_grid[...,None]) # (n, c, h, w, 3, 1) 77 | dense_motion = dense_motion[...,:2,0] # (n, c, h, w, 2) 78 | identity_motion = self.grid[None, None].repeat(n, 1, 1, 1, 1) # (n, 1, h, w, 2) 79 | dense_motion = torch.cat([dense_motion, identity_motion], dim=1) # (n, c+1, h, w, 2) 80 | optical_flow = (mask[...,None] * dense_motion).sum(dim=1) # (n, h, w, 2) 81 | 82 | # compute deformed source feature 83 | feature = self.feature_predictor(x) # (n, c, h, w) 84 | deformed = F.grid_sample(feature, optical_flow, align_corners=False) # (n, c, h, w) 85 | occlusion = torch.sigmoid(self.occlusion_predictor(x)) # (n, 1, h, w) 86 | feature = occlusion * deformed # (n, c, h, w) 87 | 88 | # inject local motion 89 | local_motion = torch.cat([drv_eyes, drv_audio], dim=1) 90 | for i in range(len(self.bottleneck)): 91 | feature = self.bottleneck[i](feature, local_motion) 92 | 93 | prediction = torch.sigmoid(self.decoder(feature)) 94 | 95 | outputs = { 96 | 'prediction': prediction, # (n, 3, h, w) 97 | 'transformed_src': transformed_src, # (n, c*3, h, w) 98 | 'mask': mask, # (n, c+1, h, w) 99 | 'occlusion': occlusion, # (n, 1, h, w) 100 | 'optical_flow': optical_flow # (n, h, w, 2) 101 | } 102 | return outputs 103 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from loader import * 4 | from torchvision.io import write_video 5 | 6 | 7 | def inference(params): 8 | model = model_loader(params) 9 | 10 | if params['mode'] == 'obama_demo1': 11 | frames, audio = obama_demo1(model, params) 12 | elif params['mode'] == 'obama_demo2': 13 | frames, audio = obama_demo2(model, params) 14 | elif params['mode'] == 'grid_demo1': 15 | frames, audio = grid_demo1(model, params) 16 | elif params['mode'] == 'grid_demo2': 17 | frames, audio = grid_demo2(model, params) 18 | elif params['mode'] == 'koeba_demo1': 19 | frames, audio = koeba_demo1(model, params) 20 | elif params['mode'] == 'koeba_demo2': 21 | frames, audio = koeba_demo2(model, params) 22 | 23 | write_video( 24 | filename=params['save_path'], 25 | video_array=frames, 26 | fps=params['fps'], 27 | video_codec='libx264', 28 | options={'crf': '12'}, 29 | audio_array=audio, 30 | audio_fps=params['samplerate'], 31 | audio_codec='aac') 32 | 33 | 34 | def obama_demo1(model, params): 35 | data = obama_demo1_data_loader(params) 36 | 37 | # fixed samples 38 | src_gpu = data['src'][None].cuda() / 255 39 | masked_src_gpu = data['masked_src'][None].cuda() / 255 40 | silence_feature_gpu = data['silence_feature'][None].cuda() 41 | black = torch.zeros_like(data['src']) 42 | 43 | preds = torch.zeros(data['n_frames'], 3, 256*2, 256*4, dtype=torch.uint8) 44 | for i in range(data['n_frames']): 45 | head_driver_gpu = data['head_driver'][i:i+1].cuda() / 255 46 | masked_driver_gpu = data['masked_driver'][i:i+1].cuda() / 255 47 | audio_feature_gpu = data['audio_features'][i:i+1].cuda() 48 | 49 | # 1. head only 50 | ho = model(src_gpu, head_driver_gpu, masked_src_gpu, silence_feature_gpu) 51 | ho = (ho[0] * 255).type(torch.uint8).cpu() 52 | 53 | # 2. 
head + audio 54 | ha = model(src_gpu, head_driver_gpu, masked_src_gpu, audio_feature_gpu) 55 | ha = (ha[0] * 255).type(torch.uint8).cpu() 56 | 57 | # 3. head + eye + audio 58 | hea = model(src_gpu, head_driver_gpu, masked_driver_gpu, audio_feature_gpu) 59 | hea = (hea[0] * 255).type(torch.uint8).cpu() 60 | 61 | preds[i] = torch.cat([ 62 | torch.cat([black, data['head_driver'][i], data['lip_driver'][i], data['eye_driver'][i]], dim=2), 63 | torch.cat([data['src'], ho, ha, hea], dim=2), 64 | ], dim=1) 65 | 66 | preds = preds.permute(0, 2, 3, 1) 67 | 68 | return preds, data['audio'] 69 | 70 | 71 | def obama_demo2(model, params): 72 | data = obama_demo2_data_loader(params) 73 | 74 | # fixed samples 75 | src_gpu = data['src'][None].cuda() / 255 76 | masked_src_gpu = data['masked_src'][None].cuda() / 255 77 | silence_feature_gpu = data['silence_feature'][None].cuda() 78 | 79 | preds = torch.zeros(data['n_frames'], 3, 256, 256*6, dtype=torch.uint8) 80 | for i in range(data['n_frames']): 81 | driver_gpu = data['driver'][i:i+1].cuda() / 255 82 | masked_driver_gpu = data['masked_driver'][i:i+1].cuda() / 255 83 | audio_feature_gpu = data['audio_features'][i:i+1].cuda() 84 | 85 | # 1. head only 86 | ho = model(src_gpu, driver_gpu, masked_src_gpu, silence_feature_gpu) 87 | ho = (ho[0] * 255).type(torch.uint8).cpu() 88 | 89 | # 2. audio only 90 | ao = model(src_gpu, src_gpu, masked_src_gpu, audio_feature_gpu) 91 | ao = (ao[0] * 255).type(torch.uint8).cpu() 92 | 93 | # 3. eye only 94 | eo = model(src_gpu, src_gpu, masked_driver_gpu, silence_feature_gpu) 95 | eo = (eo[0] * 255).type(torch.uint8).cpu() 96 | 97 | # 4. all 98 | hea = model(src_gpu, driver_gpu, masked_driver_gpu, audio_feature_gpu) 99 | hea = (hea[0] * 255).type(torch.uint8).cpu() 100 | 101 | preds[i] = torch.cat([data['src'], data['driver'][i], ho, ao, eo, hea], dim=2) 102 | 103 | preds = preds.permute(0, 2, 3, 1) 104 | 105 | return preds, data['audio'] 106 | 107 | 108 | def grid_demo1(model, params): 109 | data = grid_demo1_data_loader(params) 110 | 111 | # fixed samples 112 | src1_gpu = data['src1'][None].cuda() / 255 113 | src2_gpu = data['src2'][None].cuda() / 255 114 | src3_gpu = data['src3'][None].cuda() / 255 115 | black = torch.zeros_like(data['src1']) 116 | 117 | preds = torch.zeros(data['n_frames'], 3, 256*2, 256*4, dtype=torch.uint8) 118 | for i in range(data['n_frames']): 119 | driver_gpu = data['driver'][i:i+1].cuda() / 255 120 | masked_driver_gpu = data['masked_driver'][i:i+1].cuda() / 255 121 | audio_feature_gpu = data['audio_features'][i:i+1].cuda() 122 | 123 | # 1. head only 124 | sp1 = model(src1_gpu, driver_gpu, masked_driver_gpu, audio_feature_gpu) 125 | sp1 = (sp1[0] * 255).type(torch.uint8).cpu() 126 | 127 | # 2. head only + audio 128 | sp2 = model(src2_gpu, driver_gpu, masked_driver_gpu, audio_feature_gpu) 129 | sp2 = (sp2[0] * 255).type(torch.uint8).cpu() 130 | 131 | # 3. 
head only + audio + eye 132 | sp3 = model(src3_gpu, driver_gpu, masked_driver_gpu, audio_feature_gpu) 133 | sp3 = (sp3[0] * 255).type(torch.uint8).cpu() 134 | 135 | preds[i] = torch.cat([ 136 | torch.cat([black, data['src1'], data['src2'], data['src3']], dim=2), 137 | torch.cat([data['driver'][i], sp1, sp2, sp3], dim=2), 138 | ], dim=1) 139 | 140 | preds = preds.permute(0, 2, 3, 1) 141 | 142 | return preds, data['audio'] 143 | 144 | 145 | def grid_demo2(model, params): 146 | data = grid_demo2_data_loader(params) 147 | 148 | # fixed samples 149 | src1_gpu = data['src1'][None].cuda() / 255 150 | src2_gpu = data['src2'][None].cuda() / 255 151 | src3_gpu = data['src3'][None].cuda() / 255 152 | 153 | preds = torch.zeros(data['n_frames'], 3, 256*2, 256*4, dtype=torch.uint8) 154 | for i in range(data['n_frames']): 155 | head_driver_gpu = data['head_driver'][i:i+1].cuda() / 255 156 | masked_driver_gpu = data['masked_driver'][i:i+1].cuda() / 255 157 | audio_feature_gpu = data['audio_features'][i:i+1].cuda() 158 | 159 | # 1. head only 160 | sp1 = model(src1_gpu, head_driver_gpu, masked_driver_gpu, audio_feature_gpu) 161 | sp1 = (sp1[0] * 255).type(torch.uint8).cpu() 162 | 163 | # 2. head only + audio 164 | sp2 = model(src2_gpu, head_driver_gpu, masked_driver_gpu, audio_feature_gpu) 165 | sp2 = (sp2[0] * 255).type(torch.uint8).cpu() 166 | 167 | # 3. head only + audio + eye 168 | sp3 = model(src3_gpu, head_driver_gpu, masked_driver_gpu, audio_feature_gpu) 169 | sp3 = (sp3[0] * 255).type(torch.uint8).cpu() 170 | 171 | half = torch.zeros_like(data['src1']) 172 | margin = 140 173 | half[:,:margin] = data['eye_driver'][i][:,:margin] 174 | half[:,margin:] = data['lip_driver'][i][:,margin:] 175 | 176 | preds[i] = torch.cat([ 177 | torch.cat([data['head_driver'][i], data['src1'], data['src2'], data['src3']], dim=2), 178 | torch.cat([half, sp1, sp2, sp3], dim=2), 179 | ], dim=1) 180 | 181 | preds = preds.permute(0, 2, 3, 1) 182 | 183 | return preds, data['audio'] 184 | 185 | 186 | def koeba_demo1(model, params): 187 | data = koeba_demo1_data_loader(params) 188 | 189 | # fixed samples 190 | src1_gpu = data['src1'][None].cuda() / 255 191 | src2_gpu = data['src2'][None].cuda() / 255 192 | src3_gpu = data['src3'][None].cuda() / 255 193 | black = torch.zeros_like(data['src1']) 194 | 195 | preds = torch.zeros(data['n_frames'], 3, 256*2, 256*4, dtype=torch.uint8) 196 | for i in range(data['n_frames']): 197 | driver_gpu = data['driver'][i:i+1].cuda() / 255 198 | masked_driver_gpu = data['masked_driver'][i:i+1].cuda() / 255 199 | audio_feature_gpu = data['audio_features'][i:i+1].cuda() 200 | 201 | # 1. head only 202 | sp1 = model(src1_gpu, driver_gpu, masked_driver_gpu, audio_feature_gpu) 203 | sp1 = (sp1[0] * 255).type(torch.uint8).cpu() 204 | 205 | # 2. head only + audio 206 | sp2 = model(src2_gpu, driver_gpu, masked_driver_gpu, audio_feature_gpu) 207 | sp2 = (sp2[0] * 255).type(torch.uint8).cpu() 208 | 209 | # 3. 
head only + audio + eye 210 | sp3 = model(src3_gpu, driver_gpu, masked_driver_gpu, audio_feature_gpu) 211 | sp3 = (sp3[0] * 255).type(torch.uint8).cpu() 212 | 213 | preds[i] = torch.cat([ 214 | torch.cat([black, data['src1'], data['src2'], data['src3']], dim=2), 215 | torch.cat([data['driver'][i], sp1, sp2, sp3], dim=2), 216 | ], dim=1) 217 | 218 | preds = preds.permute(0, 2, 3, 1) 219 | 220 | return preds, data['audio'] 221 | 222 | 223 | def koeba_demo2(model, params): 224 | data = koeba_demo2_data_loader(params) 225 | 226 | # fixed samples 227 | src1_gpu = data['src1'][None].cuda() / 255 228 | src2_gpu = data['src2'][None].cuda() / 255 229 | src3_gpu = data['src3'][None].cuda() / 255 230 | 231 | preds = torch.zeros(data['n_frames'], 3, 256*2, 256*4, dtype=torch.uint8) 232 | for i in range(data['n_frames']): 233 | head_driver_gpu = data['head_driver'][i:i+1].cuda() / 255 234 | masked_driver_gpu = data['masked_driver'][i:i+1].cuda() / 255 235 | audio_feature_gpu = data['audio_features'][i:i+1].cuda() 236 | 237 | # 1. head only 238 | sp1 = model(src1_gpu, head_driver_gpu, masked_driver_gpu, audio_feature_gpu) 239 | sp1 = (sp1[0] * 255).type(torch.uint8).cpu() 240 | 241 | # 2. head only + audio 242 | sp2 = model(src2_gpu, head_driver_gpu, masked_driver_gpu, audio_feature_gpu) 243 | sp2 = (sp2[0] * 255).type(torch.uint8).cpu() 244 | 245 | # 3. head only + audio + eye 246 | sp3 = model(src3_gpu, head_driver_gpu, masked_driver_gpu, audio_feature_gpu) 247 | sp3 = (sp3[0] * 255).type(torch.uint8).cpu() 248 | 249 | half = torch.zeros_like(data['src1']) 250 | margin = 128 251 | half[:,:margin] = data['eye_driver'][i][:,:margin] 252 | half[:,margin:] = data['lip_driver'][i][:,margin:] 253 | 254 | preds[i] = torch.cat([ 255 | torch.cat([data['head_driver'][i], data['src1'], data['src2'], data['src3']], dim=2), 256 | torch.cat([half, sp1, sp2, sp3], dim=2), 257 | ], dim=1) 258 | 259 | preds = preds.permute(0, 2, 3, 1) 260 | 261 | return preds, data['audio'] 262 | -------------------------------------------------------------------------------- /loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import librosa 3 | import numpy as np 4 | 5 | from glob import glob 6 | from os.path import join 7 | from torchvision.io import read_image 8 | from torchvision.io.image import ImageReadMode 9 | 10 | from model import Model 11 | 12 | 13 | def model_loader(config_dict): 14 | with torch.no_grad(): 15 | model = Model(config_dict['model_params']) 16 | checkpoint = torch.load(config_dict['weight_path']) 17 | model.load_state_dict(checkpoint['state_dict'], strict=False) 18 | model.eval().cuda() 19 | return model 20 | 21 | 22 | def obama_demo1_data_loader(params): 23 | dataset_root = params['dataset_root'] 24 | 25 | # load audio 26 | audio, _ = librosa.load(join(dataset_root, 'audio.wav')) 27 | 28 | # load features 29 | audio_features = np.load(join(dataset_root, 'audio_features.npy')) 30 | silence_feature = np.load(join(dataset_root, 'silence_feature.npy')) 31 | 32 | # parsing image paths 33 | src_path = join(dataset_root, 'src.jpg') 34 | masked_src_path = join(dataset_root, 'masked_src.png') 35 | head_driver_paths = sorted(glob(join(dataset_root, 'head_driver/*.jpg'))) 36 | eye_driver_paths = sorted(glob(join(dataset_root, 'eye_driver/*.jpg'))) 37 | lip_driver_paths = sorted(glob(join(dataset_root, 'lip_driver/*.jpg'))) 38 | masked_driver_paths = sorted(glob(join(dataset_root, 'masked_driver/*.png'))) 39 | 40 | # load images 41 | n_frames = 
len(audio_features) 42 | 43 | src = read_image(src_path, ImageReadMode.RGB) 44 | masked_src = read_image(masked_src_path, ImageReadMode.RGB) 45 | 46 | head_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 47 | eye_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 48 | lip_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 49 | masked_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 50 | 51 | for i in range(n_frames): 52 | head_driver[i] = read_image(head_driver_paths[i], ImageReadMode.RGB) 53 | eye_driver[i] = read_image(eye_driver_paths[i], ImageReadMode.RGB) 54 | lip_driver[i] = read_image(lip_driver_paths[i], ImageReadMode.RGB) 55 | masked_driver[i] = read_image(masked_driver_paths[i], ImageReadMode.RGB) 56 | 57 | data = { 58 | 'n_frames': n_frames, 59 | 'audio': torch.from_numpy(audio).view(1,-1), # (1, L) 60 | 'audio_features': torch.from_numpy(audio_features), # (N, C, L) 61 | 'silence_feature': torch.from_numpy(silence_feature), # (C, L) 62 | 'src': src, 63 | 'masked_src': masked_src, 64 | 'head_driver': head_driver, 65 | 'eye_driver': eye_driver, 66 | 'lip_driver': lip_driver, 67 | 'masked_driver': masked_driver 68 | } 69 | 70 | return data 71 | 72 | 73 | def obama_demo2_data_loader(params): 74 | dataset_root = params['dataset_root'] 75 | 76 | # load audio 77 | audio, _ = librosa.load(join(dataset_root, 'audio.wav')) 78 | 79 | # load features 80 | audio_features = np.load(join(dataset_root, 'audio_features.npy')) 81 | silence_feature = np.load(join(dataset_root, 'silence_feature.npy')) 82 | 83 | # parsing image paths 84 | src_path = join(dataset_root, 'src.jpg') 85 | masked_src_path = join(dataset_root, 'masked_src.png') 86 | driver_paths = sorted(glob(join(dataset_root, 'driver/*.jpg'))) 87 | masked_driver_paths = sorted(glob(join(dataset_root, 'masked_driver/*.png'))) 88 | 89 | # load images 90 | n_frames = len(audio_features) 91 | 92 | src = read_image(src_path, ImageReadMode.RGB) 93 | masked_src = read_image(masked_src_path, ImageReadMode.RGB) 94 | 95 | driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 96 | masked_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 97 | 98 | for i in range(n_frames): 99 | driver[i] = read_image(driver_paths[i], ImageReadMode.RGB) 100 | masked_driver[i] = read_image(masked_driver_paths[i], ImageReadMode.RGB) 101 | 102 | data = { 103 | 'n_frames': n_frames, 104 | 'audio': torch.from_numpy(audio).view(1,-1), # (1, L) 105 | 'audio_features': torch.from_numpy(audio_features), # (N, C, L) 106 | 'silence_feature': torch.from_numpy(silence_feature), # (C, L) 107 | 'src': src, 108 | 'masked_src': masked_src, 109 | 'driver': driver, 110 | 'masked_driver': masked_driver 111 | } 112 | 113 | return data 114 | 115 | 116 | def grid_demo1_data_loader(params): 117 | dataset_root = params['dataset_root'] 118 | 119 | # load audio 120 | audio, _ = librosa.load(join(dataset_root, 'audio.wav')) 121 | 122 | # load features 123 | audio_features = np.load(join(dataset_root, 'audio_features.npy')) 124 | 125 | # parsing image paths 126 | src1_path = join(dataset_root, 'src1.jpg') 127 | src2_path = join(dataset_root, 'src2.jpg') 128 | src3_path = join(dataset_root, 'src3.jpg') 129 | 130 | driver_paths = sorted(glob(join(dataset_root, 'driver/*.jpg'))) 131 | masked_driver_paths = sorted(glob(join(dataset_root, 'masked_driver/*.png'))) 132 | 133 | # load images 134 | n_frames = len(audio_features) 135 | 136 | src1 = read_image(src1_path, ImageReadMode.RGB) 137 | src2 = read_image(src2_path, 
ImageReadMode.RGB) 138 | src3 = read_image(src3_path, ImageReadMode.RGB) 139 | 140 | driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 141 | masked_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 142 | 143 | for i in range(n_frames): 144 | driver[i] = read_image(driver_paths[i], ImageReadMode.RGB) 145 | masked_driver[i] = read_image(masked_driver_paths[i], ImageReadMode.RGB) 146 | 147 | data = { 148 | 'n_frames': n_frames, 149 | 'audio': torch.from_numpy(audio).view(1,-1), # (1, L) 150 | 'audio_features': torch.from_numpy(audio_features), # (N, C, L) 151 | 'src1': src1, 152 | 'src2': src2, 153 | 'src3': src3, 154 | 'driver': driver, 155 | 'masked_driver': masked_driver 156 | } 157 | 158 | return data 159 | 160 | 161 | def grid_demo2_data_loader(params): 162 | dataset_root = params['dataset_root'] 163 | 164 | # load audio 165 | audio, _ = librosa.load(join(dataset_root, 'audio.wav')) 166 | 167 | # load features 168 | audio_features = np.load(join(dataset_root, 'audio_features.npy')) 169 | 170 | # parsing image paths 171 | src1_path = join(dataset_root, 'src1.jpg') 172 | src2_path = join(dataset_root, 'src2.jpg') 173 | src3_path = join(dataset_root, 'src3.jpg') 174 | 175 | head_driver_paths = sorted(glob(join(dataset_root, 'head_driver/*.jpg'))) 176 | eye_driver_paths = sorted(glob(join(dataset_root, 'eye_driver/*.jpg'))) 177 | lip_driver_paths = sorted(glob(join(dataset_root, 'lip_driver/*.jpg'))) 178 | masked_driver_paths = sorted(glob(join(dataset_root, 'masked_driver/*.png'))) 179 | 180 | # load images 181 | n_frames = len(audio_features) 182 | 183 | src1 = read_image(src1_path, ImageReadMode.RGB) 184 | src2 = read_image(src2_path, ImageReadMode.RGB) 185 | src3 = read_image(src3_path, ImageReadMode.RGB) 186 | 187 | head_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 188 | eye_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 189 | lip_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 190 | masked_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 191 | 192 | for i in range(n_frames): 193 | head_driver[i] = read_image(head_driver_paths[i], ImageReadMode.RGB) 194 | eye_driver[i] = read_image(eye_driver_paths[i], ImageReadMode.RGB) 195 | lip_driver[i] = read_image(lip_driver_paths[i], ImageReadMode.RGB) 196 | masked_driver[i] = read_image(masked_driver_paths[i], ImageReadMode.RGB) 197 | 198 | data = { 199 | 'n_frames': n_frames, 200 | 'audio': torch.from_numpy(audio).view(1,-1), # (1, L) 201 | 'audio_features': torch.from_numpy(audio_features), # (N, C, L) 202 | 'src1': src1, 203 | 'src2': src2, 204 | 'src3': src3, 205 | 'head_driver': head_driver, 206 | 'eye_driver': eye_driver, 207 | 'lip_driver': lip_driver, 208 | 'masked_driver': masked_driver 209 | } 210 | 211 | return data 212 | 213 | 214 | def koeba_demo1_data_loader(params): 215 | dataset_root = params['dataset_root'] 216 | 217 | # load audio 218 | audio, _ = librosa.load(join(dataset_root, 'audio.wav')) 219 | 220 | # load features 221 | audio_features = np.load(join(dataset_root, 'audio_features.npy')) 222 | 223 | # parsing image paths 224 | src1_path = join(dataset_root, 'src1.jpg') 225 | src2_path = join(dataset_root, 'src2.jpg') 226 | src3_path = join(dataset_root, 'src3.jpg') 227 | 228 | driver_paths = sorted(glob(join(dataset_root, 'driver/*.png'))) 229 | masked_driver_paths = sorted(glob(join(dataset_root, 'masked_driver/*.png'))) 230 | 231 | # load images 232 | n_frames = len(audio_features) 233 | 234 | src1 = read_image(src1_path, 
ImageReadMode.RGB) 235 | src2 = read_image(src2_path, ImageReadMode.RGB) 236 | src3 = read_image(src3_path, ImageReadMode.RGB) 237 | 238 | driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 239 | masked_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 240 | 241 | for i in range(n_frames): 242 | driver[i] = read_image(driver_paths[i], ImageReadMode.RGB) 243 | masked_driver[i] = read_image(masked_driver_paths[i], ImageReadMode.RGB) 244 | 245 | data = { 246 | 'n_frames': n_frames, 247 | 'audio': torch.from_numpy(audio).view(1,-1), # (1, L) 248 | 'audio_features': torch.from_numpy(audio_features), # (N, C, L) 249 | 'src1': src1, 250 | 'src2': src2, 251 | 'src3': src3, 252 | 'driver': driver, 253 | 'masked_driver': masked_driver 254 | } 255 | 256 | return data 257 | 258 | 259 | def koeba_demo2_data_loader(params): 260 | dataset_root = params['dataset_root'] 261 | 262 | # load audio 263 | audio, _ = librosa.load(join(dataset_root, 'audio.wav')) 264 | 265 | # load features 266 | audio_features = np.load(join(dataset_root, 'audio_features.npy')) 267 | 268 | # parsing image paths 269 | src1_path = join(dataset_root, 'src1.jpg') 270 | src2_path = join(dataset_root, 'src2.jpg') 271 | src3_path = join(dataset_root, 'src3.jpg') 272 | 273 | head_driver_paths = sorted(glob(join(dataset_root, 'head_driver/*.png'))) 274 | eye_driver_paths = sorted(glob(join(dataset_root, 'eye_driver/*.png'))) 275 | lip_driver_paths = sorted(glob(join(dataset_root, 'lip_driver/*.png'))) 276 | masked_driver_paths = sorted(glob(join(dataset_root, 'masked_driver/*.png'))) 277 | 278 | # load images 279 | n_frames = len(audio_features) 280 | 281 | src1 = read_image(src1_path, ImageReadMode.RGB) 282 | src2 = read_image(src2_path, ImageReadMode.RGB) 283 | src3 = read_image(src3_path, ImageReadMode.RGB) 284 | 285 | head_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 286 | eye_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 287 | lip_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 288 | masked_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8) 289 | 290 | for i in range(n_frames): 291 | head_driver[i] = read_image(head_driver_paths[i], ImageReadMode.RGB) 292 | eye_driver[i] = read_image(eye_driver_paths[i], ImageReadMode.RGB) 293 | lip_driver[i] = read_image(lip_driver_paths[i], ImageReadMode.RGB) 294 | masked_driver[i] = read_image(masked_driver_paths[i], ImageReadMode.RGB) 295 | 296 | data = { 297 | 'n_frames': n_frames, 298 | 'audio': torch.from_numpy(audio).view(1,-1), # (1, L) 299 | 'audio_features': torch.from_numpy(audio_features), # (N, C, L) 300 | 'src1': src1, 301 | 'src2': src2, 302 | 'src3': src3, 303 | 'head_driver': head_driver, 304 | 'eye_driver': eye_driver, 305 | 'lip_driver': lip_driver, 306 | 'masked_driver': masked_driver 307 | } 308 | 309 | return data 310 | --------------------------------------------------------------------------------
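As a reference for scripted use, the demos can also be launched directly from Python instead of going through `test.py`'s argument parsing. The minimal sketch below simply mirrors what `test.py` does for the `obama_demo1` mode and assumes the `dataset/` and `weight/` layout described in the README.

```python
# Minimal sketch: call inference() directly, mirroring test.py's 'obama_demo1' branch.
# Assumes ./dataset/obama/demo1/ and ./weight/obama.pt exist as described in the README.
from config_obama import params
from inference import inference

params['mode'] = 'obama_demo1'
params['dataset_root'] = './dataset/obama/demo1'
params['save_path'] = 'obama_demo1.mp4'

inference(params)  # writes obama_demo1.mp4 (libx264 video + AAC audio)
```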