├── requirements.txt
├── config_grid.py
├── config_koeba.py
├── config_obama.py
├── modules
│   ├── audio_encoder.py
│   ├── eyes_encoder.py
│   ├── audio_encoder_bn.py
│   ├── utils.py
│   ├── hourglass.py
│   ├── blocks.py
│   ├── head_predictor.py
│   ├── layers.py
│   └── generator.py
├── model.py
├── test.py
├── .gitignore
├── README.md
├── inference.py
└── loader.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | librosa==0.9.1
2 | av==8.0.3
3 |
--------------------------------------------------------------------------------
/config_grid.py:
--------------------------------------------------------------------------------
1 | from modules.audio_encoder import AudioEncoder
2 |
3 | params = {
4 | 'fps': 25,
5 | 'samplerate': 22050,
6 | 'weight_path': './weight/grid.pt',
7 | 'model_params': {
8 | 'head_predictor': {
9 | 'num_affines': 1,
10 | 'using_scale': False
11 | },
12 | 'generator': {
13 | 'num_affines': 1,
14 | 'num_residual_mod_blocks': 6,
15 | 'using_gaussian': False
16 | },
17 | 'audio_encoder': AudioEncoder
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/config_koeba.py:
--------------------------------------------------------------------------------
1 | from modules.audio_encoder_bn import AudioEncoder
2 |
3 | params = {
4 | 'fps': 29.97,
5 | 'samplerate': 22050,
6 | 'weight_path': './weight/koeba.pt',
7 | 'model_params': {
8 | 'head_predictor': {
9 | 'num_affines': 1,
10 | 'using_scale': True
11 | },
12 | 'generator': {
13 | 'num_affines': 1,
14 | 'num_residual_mod_blocks': 6,
15 | 'using_gaussian': False
16 | },
17 | 'audio_encoder': AudioEncoder
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/config_obama.py:
--------------------------------------------------------------------------------
1 | from modules.audio_encoder import AudioEncoder
2 |
3 | params = {
4 | 'fps': 29.97,
5 | 'samplerate': 22050,
6 | 'weight_path': './weight/obama.pt',
7 | 'model_params': {
8 | 'head_predictor': {
9 | 'num_affines': 1,
10 | 'using_scale': True
11 | },
12 | 'generator': {
13 | 'num_affines': 1,
14 | 'num_residual_mod_blocks': 6,
15 | 'using_gaussian': False
16 | },
17 | 'audio_encoder': AudioEncoder
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/modules/audio_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .layers import conv1d, lstm, linear
5 |
6 |
7 | class AudioEncoder(nn.Module):
8 | def __init__(self):
9 | super(AudioEncoder, self).__init__()
10 | self.encoder = nn.Sequential(
11 | conv1d(192, 512, 3, 1, 1), nn.ReLU(),
12 | conv1d(512, 512, 3, 1, 1), nn.ReLU(),
13 | conv1d(512, 512, 3, 1, 1), nn.ReLU(),
14 | conv1d(512, 512, 3, 1, 1), nn.ReLU())
15 |
16 | self.rnn = lstm(512, 512)
17 |
18 | self.decoder = nn.Sequential(
19 | linear(512, 256), nn.ReLU(),
20 | linear(256, 256))
21 |
22 |
23 | def forward(self, x):
24 | x = self.encoder(x) # (n, c, l)
25 | x, _ = self.rnn(x.permute(0, 2, 1)) # (n, l, c)
26 | x = self.decoder(x[:,-1]) # (n, c)
27 | return x
28 |
29 |
--------------------------------------------------------------------------------
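A minimal shape sketch (not part of the repository) for the conv+LSTM audio encoder above: the 192-channel input layout matches the `conv1d(192, ...)` stem, and the 31-frame window length is an arbitrary choice for illustration, since these convolutions are padded and preserve length.

```python
import torch
from modules.audio_encoder import AudioEncoder

encoder = AudioEncoder().eval()
x = torch.rand(2, 192, 31)   # (batch, feature channels, frames); window length is arbitrary here
with torch.no_grad():
    z = encoder(x)           # padded convs keep length, LSTM runs over time, last step is decoded
print(z.shape)               # torch.Size([2, 256])
```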
/modules/eyes_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .layers import linear
5 | from .utils import AntiAliasInterpolation2d
6 | from .blocks import down_block2d
7 |
8 |
9 | class EyesEncoder(nn.Module):
10 | def __init__(self):
11 | super(EyesEncoder, self).__init__()
12 | self.down = AntiAliasInterpolation2d(3, 0.5) # 1/2 resolution
13 |
14 | self.encoder = nn.Sequential(
15 | down_block2d(3, 32),
16 | down_block2d(32, 64),
17 | down_block2d(64, 64),
18 | down_block2d(64, 64),
19 | down_block2d(64, 64),
20 | nn.AvgPool2d(kernel_size=(4, 4)))
21 |
22 | self.predictor = nn.Sequential(
23 | linear(64, 32), nn.ReLU(),
24 | linear(32, 8))
25 |
26 |
27 | def forward(self, x):
28 | # x: (n, 3, 256, 256) masked eye-region frame at inference (see masked_driver in loader.py)
29 | x = self.down(x)
30 | x = self.encoder(x)
31 | x = self.predictor(x[...,0,0])
32 | return x
33 |
--------------------------------------------------------------------------------
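A shape sketch (not part of the repository), assuming the input is the 256x256 masked eye-region frame that inference.py passes in; at that resolution the five down blocks reduce the feature map to 4x4 before the final average pool.

```python
import torch
from modules.eyes_encoder import EyesEncoder

encoder = EyesEncoder().eval()
with torch.no_grad():
    # 256 -> 128 (anti-aliased) -> 64/32/16/8/4 (down blocks) -> 1x1 (avg pool) -> 8-dim code
    code = encoder(torch.rand(1, 3, 256, 256))
print(code.shape)  # torch.Size([1, 8])
```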
/modules/audio_encoder_bn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .layers import conv1d, conv2d, linear
5 |
6 |
7 | class AudioEncoder(nn.Module):
8 | def __init__(self):
9 | super(AudioEncoder, self).__init__()
10 | self.encoder = nn.Sequential(
11 | conv1d(192, 256, 7, 1, 0), nn.BatchNorm1d(256), nn.ReLU(),
12 | conv1d(256, 256, 5, 1, 0), nn.BatchNorm1d(256), nn.ReLU(),
13 | conv1d(256, 256, 5, 1, 0), nn.BatchNorm1d(256), nn.ReLU(),
14 | conv1d(256, 256, 5, 1, 0), nn.BatchNorm1d(256), nn.ReLU(),
15 | conv1d(256, 256, 5, 1, 0), nn.BatchNorm1d(256), nn.ReLU())
16 |
17 | self.predictor = nn.Sequential(
18 | linear(256*9, 256), nn.ReLU(),
19 | linear(256, 256), nn.ReLU(),
20 | linear(256, 256), nn.ReLU(),
21 | linear(256, 256), nn.ReLU(),
22 | linear(256, 256), nn.ReLU(),
23 | linear(256, 256))
24 |
25 |
26 | def forward(self, x):
27 | x = self.encoder(x)
28 | x = x.view(x.shape[0], -1)
29 | x = self.predictor(x)
30 | return x
31 |
--------------------------------------------------------------------------------
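Unlike the padded encoder earlier, these convolutions are unpadded, so the time axis shrinks by 6 + 4*4 = 22 frames, and the flatten into `linear(256*9, 256)` pins the input window to exactly 9 + 22 = 31 frames. A quick check (not part of the repository):

```python
import torch
from modules.audio_encoder_bn import AudioEncoder

encoder = AudioEncoder().eval()
with torch.no_grad():
    z = encoder(torch.rand(2, 192, 31))  # (2, 192, 31) -> (2, 256, 9) -> flatten -> (2, 256)
print(z.shape)                           # torch.Size([2, 256])
```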
/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from modules.utils import AntiAliasInterpolation2d
4 | from modules.head_predictor import HeadPredictor
5 | from modules.eyes_encoder import EyesEncoder
6 | from modules.generator import Generator
7 |
8 |
9 | class Model(nn.Module):
10 | def __init__(self, params):
11 | super(Model, self).__init__()
12 | self.down_sampler = AntiAliasInterpolation2d(3, 0.25) # 1/4 resolution
13 | self.head_predictor = HeadPredictor(**params['head_predictor'])
14 | self.eyes_encoder = EyesEncoder()
15 | self.audio_encoder = params['audio_encoder']()
16 | self.generator = Generator(**params['generator'])
17 |
18 |
19 | def forward(self, src, drv, eye, spec):
20 | src_down = self.down_sampler(src)
21 | drv_down = self.down_sampler(drv)
22 |
23 | src_head = self.head_predictor(src_down)
24 | drv_head = self.head_predictor(drv_down)
25 |
26 | drv_eyes = self.eyes_encoder(eye)
27 |
28 | drv_audio = self.audio_encoder(spec)
29 |
30 | generator_out = self.generator(src, src_head, drv_head, drv_eyes, drv_audio)
31 |
32 | return generator_out['prediction']
33 |
--------------------------------------------------------------------------------
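An end-to-end smoke test (not part of the repository) showing how a config dict wires into `Model` and what tensor shapes the forward pass expects. The 256x256 frame size and 31-frame audio window are assumptions taken from loader.py and the audio encoders, and random weights are used instead of the released checkpoint.

```python
import torch
from config_obama import params
from model import Model

# params['model_params']['audio_encoder'] is a class, not an instance; Model() instantiates it.
model = Model(params['model_params']).eval()
with torch.no_grad():
    src  = torch.rand(1, 3, 256, 256)  # source identity frame, values in [0, 1]
    drv  = torch.rand(1, 3, 256, 256)  # head-pose driving frame
    eye  = torch.rand(1, 3, 256, 256)  # masked eye-region driving frame
    spec = torch.rand(1, 192, 31)      # audio feature window (31-frame length assumed)
    pred = model(src, drv, eye, spec)
print(pred.shape)                      # torch.Size([1, 3, 256, 256])
```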
/test.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 | import argparse
3 |
4 | from inference import inference
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument('-m', '--mode', required=True)
8 | args = parser.parse_args()
9 |
10 | mode = args.mode
11 |
12 | # parsing mode
13 | if mode == 'obama_demo1':
14 | from config_obama import params
15 | params['mode'] = mode
16 | params['dataset_root'] = './dataset/obama/demo1'
17 | params['save_path'] = 'obama_demo1.mp4'
18 | elif mode == 'obama_demo2':
19 | from config_obama import params
20 | params['mode'] = mode
21 | params['dataset_root'] = './dataset/obama/demo2'
22 | params['save_path'] = 'obama_demo2.mp4'
23 | elif mode == 'grid_demo1':
24 | from config_grid import params
25 | params['mode'] = mode
26 | params['dataset_root'] = './dataset/grid/demo1'
27 | params['save_path'] = 'grid_demo1.mp4'
28 | elif mode == 'grid_demo2':
29 | from config_grid import params
30 | params['mode'] = mode
31 | params['dataset_root'] = './dataset/grid/demo2'
32 | params['save_path'] = 'grid_demo2.mp4'
33 | elif mode == 'koeba_demo1':
34 | from config_koeba import params
35 | params['mode'] = mode
36 | params['dataset_root'] = './dataset/koeba/demo1'
37 | params['save_path'] = 'koeba_demo1.mp4'
38 | elif mode == 'koeba_demo2':
39 | from config_koeba import params
40 | params['mode'] = mode
41 | params['dataset_root'] = './dataset/koeba/demo2'
42 | params['save_path'] = 'koeba_demo2.mp4'
43 | else:
44 | raise Exception('mode: [obama_demo1|obama_demo2|grid_demo1|grid_demo2|koeba_demo1|koeba_demo2]')
45 |
46 | inference(params)
47 |
--------------------------------------------------------------------------------
/modules/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def make_coordinate_grid(h, w):
7 | theta = torch.eye(2,3).unsqueeze(0)
8 | grid = F.affine_grid(theta, (1, 1, h, w), align_corners=False)[0]
9 | return grid
10 |
11 |
12 | class AntiAliasInterpolation2d(nn.Module):
13 | """
14 | Band-limited downsampling, for better preservation of the input signal.
15 | """
16 |
17 | def __init__(self, channels, scale):
18 | super(AntiAliasInterpolation2d, self).__init__()
19 | sigma = (1 / scale - 1) / 2
20 | kernel_size = 2 * round(sigma * 4) + 1
21 | self.ka = kernel_size // 2
22 | self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
23 |
24 | kernel_size = [kernel_size, kernel_size]
25 | sigma = [sigma, sigma]
26 | # The gaussian kernel is the product of the
27 | # gaussian function of each dimension.
28 | kernel = 1
29 | meshgrids = torch.meshgrid(
30 | [
31 | torch.arange(size, dtype=torch.float32)
32 | for size in kernel_size
33 | ]
34 | )
35 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
36 | mean = (size - 1) / 2
37 | kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
38 |
39 | # Make sure sum of values in gaussian kernel equals 1.
40 | kernel = kernel / torch.sum(kernel)
41 | # Reshape to depthwise convolutional weight
42 | kernel = kernel.view(1, 1, *kernel.size())
43 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
44 |
45 | self.register_buffer('weight', kernel)
46 | self.groups = channels
47 | self.scale = scale
48 | inv_scale = 1 / scale
49 | self.int_inv_scale = int(inv_scale)
50 |
51 | def forward(self, input):
52 | if self.scale == 1.0:
53 | return input
54 |
55 | out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
56 | out = F.conv2d(out, weight=self.weight, groups=self.groups)
57 | out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
58 |
59 | return out
60 |
61 |
--------------------------------------------------------------------------------
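A small usage sketch (not in the repository) of the two helpers above, using the 0.25 scale that model.py applies to 256x256 frames.

```python
import torch
from modules.utils import AntiAliasInterpolation2d, make_coordinate_grid

down = AntiAliasInterpolation2d(channels=3, scale=0.25)  # gaussian blur + stride-4 subsampling
print(down(torch.rand(1, 3, 256, 256)).shape)            # torch.Size([1, 3, 64, 64])

grid = make_coordinate_grid(64, 64)                      # (64, 64, 2) x/y coordinates in (-1, 1),
print(grid.shape)                                        # F.affine_grid convention (align_corners=False)
```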
/modules/hourglass.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .blocks import down_block2d, up_block2d
6 |
7 |
8 | class Encoder(nn.Module):
9 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
10 | super(Encoder, self).__init__()
11 |
12 | down_blocks = []
13 | for i in range(num_blocks):
14 | if i == 0:
15 | in_channels = in_features
16 | else:
17 | in_channels = min(max_features, block_expansion * (2 ** i))
18 | out_channels = min(max_features, block_expansion * (2 ** (i + 1)))
19 | down_blocks.append(down_block2d(in_channels, out_channels))
20 |
21 | self.down_blocks = nn.ModuleList(down_blocks)
22 |
23 | def forward(self, x):
24 | outs = [x]
25 | for down_block in self.down_blocks:
26 | outs.append(down_block(outs[-1]))
27 | return outs
28 |
29 |
30 | class Decoder(nn.Module):
31 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
32 | super(Decoder, self).__init__()
33 |
34 | up_blocks = []
35 | for i in range(num_blocks)[::-1]:
36 | in_channels = min(max_features, block_expansion * (2 ** (i + 1)))
37 | if i != num_blocks - 1:
38 | in_channels = 2 * in_channels
39 | out_channels = min(max_features, block_expansion * (2 ** i))
40 | up_blocks.append(up_block2d(in_channels, out_channels))
41 |
42 | self.up_blocks = nn.ModuleList(up_blocks)
43 | self.out_filters = block_expansion + in_features
44 |
45 | def forward(self, x):
46 | out = x.pop()
47 | for up_block in self.up_blocks:
48 | out = up_block(out)
49 | skip = x.pop()
50 | out = torch.cat([out, skip], dim=1)
51 | return out
52 |
53 |
54 | class Hourglass(nn.Module):
55 | def __init__(
56 | self,
57 | block_expansion=32,
58 | in_features=3,
59 | num_blocks=5,
60 | max_features=1024
61 | ):
62 | super(Hourglass, self).__init__()
63 | self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
64 | self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
65 | self.out_filters = self.decoder.out_filters
66 |
67 | def forward(self, x):
68 | return self.decoder(self.encoder(x))
69 |
--------------------------------------------------------------------------------
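A shape sketch (not part of the repository): because the decoder concatenates the original input as the final skip, the hourglass keeps the input resolution and its output has `out_filters = block_expansion + in_features` = 35 channels with the defaults, which is what HeadPredictor relies on.

```python
import torch
from modules.hourglass import Hourglass

hg = Hourglass().eval()          # defaults: block_expansion=32, in_features=3, num_blocks=5
with torch.no_grad():
    y = hg(torch.rand(1, 3, 64, 64))
print(y.shape, hg.out_filters)   # torch.Size([1, 35, 64, 64]) 35
```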
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/modules/blocks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from .layers import linear, conv2d, conv3d, mod_conv2d
5 |
6 |
7 | activations = {
8 | 'relu': nn.ReLU(),
9 | 'leaky_relu': nn.LeakyReLU()
10 | }
11 |
12 | class conv2d_bn_relu(nn.Module):
13 | def __init__(
14 | self,
15 | in_channels,
16 | out_channels,
17 | kernel_size=3,
18 | stride=1,
19 | padding=1,
20 | activation='relu'
21 | ):
22 | super(conv2d_bn_relu, self).__init__()
23 | self.conv = conv2d(in_channels, out_channels, kernel_size, stride, padding)
24 | self.bn = nn.BatchNorm2d(out_channels, affine=True)
25 | self.act = activations[activation]
26 |
27 | def forward(self, x):
28 | return self.act(self.bn(self.conv(x)))
29 |
30 |
31 | class down_block2d(nn.Module):
32 | def __init__(self, in_channels, out_channels):
33 | super(down_block2d, self).__init__()
34 | self.conv = conv2d_bn_relu(in_channels, out_channels)
35 | self.pool = nn.AvgPool2d(kernel_size=(2, 2))
36 |
37 | def forward(self, x):
38 | return self.pool(self.conv(x))
39 |
40 |
41 | class res_block2d(nn.Module):
42 | def __init__(self, in_channels):
43 | super(res_block2d, self).__init__()
44 | self.block = nn.Sequential(
45 | nn.BatchNorm2d(in_channels, affine=True),
46 | nn.ReLU(),
47 | conv2d(in_channels, in_channels, 3, 1, 1),
48 | nn.BatchNorm2d(in_channels, affine=True),
49 | nn.ReLU(),
50 | conv2d(in_channels, in_channels, 3, 1, 1))
51 |
52 | def forward(self, x):
53 | return x + self.block(x)
54 |
55 |
56 | class res_mod_block2d(nn.Module):
57 | def __init__(self, in_channels, in_features):
58 | super(res_mod_block2d, self).__init__()
59 | self.block = nn.ModuleList([
60 | mod_conv2d(in_channels, in_channels, 3, 1, in_features),
61 | mod_conv2d(in_channels, in_channels, 3, 1, in_features)])
62 |
63 | def forward(self, x, y):
64 | r = x
65 | for i in range(len(self.block)):
66 | r = self.block[i](r, y)
67 | return x + r
68 |
69 |
70 | class up_block2d(nn.Module):
71 | def __init__(self, in_channels, out_channels):
72 | super(up_block2d, self).__init__()
73 | self.conv = conv2d_bn_relu(in_channels, out_channels)
74 |
75 | def forward(self, x):
76 | return self.conv(F.interpolate(x, scale_factor=2))
77 |
78 |
--------------------------------------------------------------------------------
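A minimal sketch (not in the repository) of the modulated residual block that the generator stacks in its bottleneck; the 512-channel feature map and the 8 + 256 conditioning size mirror generator.py.

```python
import torch
from modules.blocks import res_mod_block2d

block = res_mod_block2d(in_channels=512, in_features=8 + 256).eval()
with torch.no_grad():
    feat = torch.rand(1, 512, 32, 32)   # generator bottleneck feature map
    cond = torch.rand(1, 264)           # concatenated eyes code (8) + audio embedding (256)
    out = block(feat, cond)             # conv weights are modulated per sample by cond
print(out.shape)                        # torch.Size([1, 512, 32, 32])
```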
/modules/head_predictor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .hourglass import Hourglass
6 | from .layers import conv2d
7 | from .utils import make_coordinate_grid
8 |
9 |
10 | class HeadPredictor(nn.Module):
11 | def __init__(self, num_affines, using_scale):
12 | super(HeadPredictor, self).__init__()
13 | self.using_scale = using_scale
14 |
15 | self.extractor = Hourglass()
16 | self.predictor = conv2d(self.extractor.out_filters, num_affines, 7, 1, 3) # linear activation
17 |
18 | self.register_buffer('grid', make_coordinate_grid(64, 64)) # (h, w, 2)
19 | self.register_buffer('identity', torch.diag(torch.ones(3))) # (3, 3)
20 |
21 |
22 | def forward(self, x):
23 | x = self.predictor(self.extractor(x)) # (n, 1, h, w)
24 |
25 | # convert feature to heatmap
26 | n, c, h, w = x.shape
27 | x = x.view(n, c, h*w) # flatten spatially
28 | heatmap = F.softmax(x, dim=2)
29 | heatmap = heatmap.view(n, c, h, w) # recover shape: (n, c, h, w)
30 |
31 | # compute statistics of heatmap
32 | mean = (self.grid * heatmap[...,None]).sum(dim=(2, 3)) # (n, c, 2)
33 | deviation = self.grid - mean[:,:,None,None] # (n, c, h, w, 2)
34 | covar = torch.matmul(deviation[...,None], deviation[...,None,:]) # (n, c, h, w, 2, 2)
35 | covar = (covar * heatmap[...,None,None]).sum(dim=(2, 3)) # (n, c, 2, 2)
36 |
37 | # SVD to extract the affine transform from the covariance matrix
38 | U, S, _ = torch.svd(covar.cpu())
39 | affine = U.to(covar.device) # rotation matrix: (n, c, 2, 2)
40 | if self.using_scale:
41 | S = S.to(covar.device) # (n, c, 2)
42 | S = torch.diag_embed(S ** 0.5) # scale matrix: (n, c, 2, 2)
43 | affine = torch.matmul(affine, S) # (n, c, 2, 2)
44 |
45 | # add translation to affine matrix
46 | affine = torch.cat([affine, mean[...,None]], dim=3) # (n, c, 2, 3)
47 | homo_affine = self.identity[None].repeat(n, c, 1, 1) # (n, c, 3, 3)
48 | homo_affine[:,:,:2] = affine # (n, c, 3, 3)
49 |
50 | # convert heatmap to gaussian
51 | covar_inverse = torch.inverse(covar)[:,:,None,None] # (n, c, 1, 1, 2, 2)
52 | under_exp = torch.matmul(deviation[...,None,:], covar_inverse) # (n, c, h, w, 1, 2)
53 | under_exp = torch.matmul(under_exp, deviation[...,None]) # (n, c, h, w, 1, 1)
54 | under_exp = under_exp[...,0,0] # (n, c, h, w)
55 | gaussian = torch.exp(-0.5 * under_exp) # (n, c, h, w)
56 |
57 | outputs = {
58 | 'affine': homo_affine, # (n, c, 3, 3)
59 | 'heatmap': heatmap, # (n, c, h, w)
60 | 'gaussian': gaussian # (n, c, h, w)
61 | }
62 | return outputs
63 |
--------------------------------------------------------------------------------
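The head pose estimator above softmaxes the hourglass output into a heatmap, takes its mean as the translation, and factors its covariance by SVD into a rotation (optionally scaled by the square roots of the singular values) to form the 3x3 homogeneous affine. A shape sketch with random weights (not part of the repository), fed at the 1/4 resolution used by model.py:

```python
import torch
from modules.head_predictor import HeadPredictor

hp = HeadPredictor(num_affines=1, using_scale=True).eval()
with torch.no_grad():
    out = hp(torch.rand(1, 3, 64, 64))
print(out['affine'].shape)                           # torch.Size([1, 1, 3, 3]) homogeneous 2D affine
print(out['heatmap'].shape, out['gaussian'].shape)   # (1, 1, 64, 64) each
```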
/README.md:
--------------------------------------------------------------------------------
1 | # DisCoHead: Audio-and-Video-Driven Talking Head Generation by Disentangled Control of Head Pose and Facial Expressions
2 |
3 |
4 | Project Page | KoEBA Dataset
5 |
22 | ## Requirements
23 |
24 | You can install the required environment using the commands below:
25 |
26 | ```shell
27 | git clone https://github.com/deepbrainai-research/discohead
28 | cd discohead
29 | conda create -n discohead python=3.7
30 | conda activate discohead
31 | conda install pytorch==1.10.0 torchvision==0.11.1 torchaudio==0.10.0 cudatoolkit=10.2 -c pytorch
32 | pip install -r requirements.txt
33 | ```
34 |
35 | ## Generating Demo Videos
36 |
37 | - Download the pre-trained checkpoints from [google drive](https://drive.google.com/file/d/1ki8BsZ3Yg2i5OhHF04ULwtgFg6r5Tsro/view?usp=sharing) and put them into the `weight` folder.
38 | - Download `dataset.zip` from [google drive](https://drive.google.com/file/d/1xy9pxgQYrl2Bnee4npq88zdrHlIcX2wf/view?usp=sharing) and unzip into `dataset`.
39 | - The `DisCoHead` directory should have the following structure.
40 |
41 | ```
42 | DisCoHead/
43 | ├── dataset/
44 | │ ├── grid/
45 | │ │ ├── demo1/
46 | │ │ ├── demo2/
47 | │ ├── koeba/
48 | │ │ ├── demo1/
49 | │ │ ├── demo2/
50 | │ ├── obama/
51 | │ │ ├── demo1/
52 | │ │ ├── demo2/
53 | ├── weight/
54 | │ ├── grid.pt
55 | │ ├── koeba.pt
56 | │ ├── obama.pt
57 | ├── modules/
58 | ‥‥
59 |
60 | ```
61 | - The `--mode` argument is used to specify which demo video you want to generate:
62 | ```shell
63 | python test.py --mode {mode}
64 | ```
65 | - Available modes: `obama_demo1, obama_demo2, grid_demo1, grid_demo2, koeba_demo1, koeba_demo2`
66 |
67 |
68 | ## License
69 |
70 | This work is licensed under a Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License. You may not use this work for commercial purposes, you may not distribute modified versions of it, and you must give appropriate credit and provide a link to the license.
71 |
79 | ## Citation
80 |
81 | ```plain
82 | @INPROCEEDINGS{10095670,
83 | author={Hwang, Geumbyeol and Hong, Sunwon and Lee, Seunghyun and Park, Sungwoo and Chae, Gyeongsu},
84 | booktitle={ICASSP 2023 - 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
85 | title={DisCoHead: Audio-and-Video-Driven Talking Head Generation by Disentangled Control of Head Pose and Facial Expressions},
86 | year={2023},
87 | volume={},
88 | number={},
89 | pages={1-5},
90 | doi={10.1109/ICASSP49357.2023.10095670}}
91 | ```
92 |
93 |
--------------------------------------------------------------------------------
/modules/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class linear(nn.Module):
7 | def __init__(self, in_features, out_features):
8 | super(linear, self).__init__()
9 | self.linear = nn.Linear(in_features, out_features)
10 | # nn.init.xavier_uniform_(self.linear.weight)
11 |
12 | def forward(self, x):
13 | return self.linear(x)
14 |
15 |
16 | class conv1d(nn.Module):
17 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
18 | super(conv1d, self).__init__()
19 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
20 | # nn.init.xavier_uniform_(self.conv.weight)
21 | # nn.init.zeros_(self.conv.bias)
22 |
23 | def forward(self, x):
24 | return self.conv(x)
25 |
26 |
27 | class conv2d(nn.Module):
28 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
29 | super(conv2d, self).__init__()
30 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
31 | # nn.init.xavier_uniform_(self.conv.weight)
32 | # nn.init.zeros_(self.conv.bias)
33 |
34 | def forward(self, x):
35 | return self.conv(x)
36 |
37 |
38 | class conv3d(nn.Module):
39 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
40 | super(conv3d, self).__init__()
41 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
42 | # nn.init.xavier_uniform_(self.conv.weight)
43 | # nn.init.zeros_(self.conv.bias)
44 |
45 | def forward(self, x):
46 | return self.conv(x)
47 |
48 |
49 | class lstm(nn.Module):
50 | def __init__(self, input_size, hidden_size):
51 | super(lstm, self).__init__()
52 | self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
53 |
54 | # initialize
55 | # kernel: glorot_uniform
56 | # recurrent: orthogonal
57 | # bias: zeros
58 | for name, param in self.lstm.named_parameters():
59 | if 'weight_ih' in name:
60 | for i in range(4):
61 | o = param.shape[0] // 4
62 | nn.init.xavier_uniform_(param[i*o:(i+1)*o])
63 | elif 'weight_hh' in name:
64 | for i in range(4):
65 | o = param.shape[0] // 4
66 | nn.init.orthogonal_(param[i*o:(i+1)*o])
67 | else:
68 | nn.init.zeros_(param)
69 |
70 | def forward(self, x):
71 | return self.lstm(x)
72 |
73 |
74 | class mod_conv2d(nn.Module):
75 | def __init__(self, in_channels, out_channels, kernel_size, padding, in_features):
76 | super(mod_conv2d, self).__init__()
77 |
78 | self.mod = linear(in_features, in_channels)
79 | nn.init.zeros_(self.mod.linear.weight)
80 | nn.init.constant_(self.mod.linear.bias, 1.0)
81 |
82 | self.bn = nn.BatchNorm2d(in_channels, affine=True)
83 | self.relu = nn.ReLU()
84 |
85 | weight = torch.zeros(
86 | (1, out_channels, in_channels, kernel_size, kernel_size), requires_grad=True)
87 | self.weight = nn.Parameter(weight)
88 | nn.init.xavier_uniform_(self.weight)
89 |
90 | self.padding = padding
91 |
92 |
93 | def forward(self, x, y):
94 | x = self.bn(x)
95 | x = self.relu(x)
96 |
97 | n, _, h, w = x.shape
98 | _, o, i, k, k = self.weight.shape
99 |
100 | # Modulate
101 | scale = self.mod(y).view(n, 1, i, 1, 1)
102 | weight = self.weight * scale # (n, o, i, k, k)
103 |
104 | # Demodulate
105 | demod = weight.pow(2).sum([2, 3, 4], keepdim=True)
106 | demod = torch.rsqrt(demod + 1e-8) # (n, o, 1, 1, 1)
107 | weight = weight * demod # (n, o, i, k, k)
108 |
109 | x = x.view(1, n * i, h, w)
110 | weight = weight.view(n * o, i, k, k)
111 | x = F.conv2d(x, weight, bias=None, padding=self.padding, groups=n)
112 | x = x.view(n, o, h, w)
113 |
114 | return x
115 |
--------------------------------------------------------------------------------
/modules/generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .blocks import conv2d_bn_relu, down_block2d, res_mod_block2d, up_block2d
6 | from .layers import conv2d
7 | from .utils import make_coordinate_grid
8 |
9 |
10 | class Generator(nn.Module):
11 | def __init__(self, num_affines, num_residual_mod_blocks, using_gaussian):
12 | super(Generator, self).__init__()
13 | self.using_gaussian = using_gaussian
14 |
15 | num_input_channel = (num_affines + 1) * 3
16 | if self.using_gaussian:
17 | num_input_channel += num_affines
18 |
19 | self.encoder = nn.Sequential(
20 | conv2d_bn_relu(num_input_channel, 64, 7, 1, 3),
21 | down_block2d(64, 128),
22 | down_block2d(128, 256),
23 | down_block2d(256, 512)
24 | )
25 | self.mask_predictor = conv2d(512, num_affines+1, 5, 1, 2)
26 |
27 | self.feature_predictor = conv2d(512, 512, 5, 1, 2)
28 | self.occlusion_predictor = conv2d(512, 1, 5, 1, 2)
29 |
30 | bottleneck = []
31 | for i in range(num_residual_mod_blocks):
32 | bottleneck.append(res_mod_block2d(512, 8 + 256))
33 | self.bottleneck = nn.ModuleList(bottleneck)
34 |
35 | self.decoder = nn.Sequential(
36 | up_block2d(512, 256),
37 | up_block2d(256, 128),
38 | up_block2d(128, 64),
39 | conv2d(64, 3, 7, 1, 3)
40 | )
41 |
42 | full_grid = make_coordinate_grid(256, 256)
43 | homo_full_grid = torch.cat([full_grid, torch.ones(256, 256, 1)], dim=2)
44 | self.register_buffer('homo_full_grid', homo_full_grid) # (256, 256, 3)
45 |
46 | grid = make_coordinate_grid(32, 32)
47 | homo_grid = torch.cat([grid, torch.ones(32, 32, 1)], dim=2)
48 | self.register_buffer('grid', grid) # (h, w, 2)
49 | self.register_buffer('homo_grid', homo_grid) # (h, w, 3)
50 |
51 |
52 | def forward(self, src, src_head, drv_head, drv_eyes, drv_audio):
53 | affine = torch.matmul(src_head['affine'], torch.inverse(drv_head['affine'])) # (n, c, 3, 3)
54 | affine = affine * torch.sign(affine[:,:,0:1,0:1]) # revert_axis_swap
55 | affine = affine[:,:,None,None] # (n, c, 1, 1, 3, 3)
56 |
57 | affine_motion = torch.matmul(affine, self.homo_full_grid[...,None]) # (n, c, h, w, 3, 1)
58 | affine_motion = affine_motion[...,:2,0] # (n, c, h, w, 2)
59 | n, c, h, w, _ = affine_motion.shape
60 |
61 | stacked_src = src.repeat(c, 1, 1, 1)
62 | flatten_affine_motion = affine_motion.view(n*c, h, w, 2)
63 | transformed_src = F.grid_sample(stacked_src, flatten_affine_motion, align_corners=False)
64 | transformed_src = transformed_src.view(n, c*3, h, w)
65 |
66 | # encode source and transformed source
67 | stacked_input = torch.cat([src, transformed_src], dim=1)
68 | if self.using_gaussian:
69 | gaussian = drv_head['gaussian'] - src_head['gaussian']
70 | gaussian = F.interpolate(gaussian, scale_factor=4)
71 | stacked_input = torch.cat([stacked_input, gaussian], dim=1)
72 | x = self.encoder(stacked_input)
73 |
74 | # compute dense motion
75 | mask = F.softmax(self.mask_predictor(x), dim=1) # (n, c+1, h, w)
76 | dense_motion = torch.matmul(affine, self.homo_grid[...,None]) # (n, c, h, w, 3, 1)
77 | dense_motion = dense_motion[...,:2,0] # (n, c, h, w, 2)
78 | identity_motion = self.grid[None, None].repeat(n, 1, 1, 1, 1) # (n, 1, h, w, 2)
79 | dense_motion = torch.cat([dense_motion, identity_motion], dim=1) # (n, c+1, h, w, 2)
80 | optical_flow = (mask[...,None] * dense_motion).sum(dim=1) # (n, h, w, 2)
81 |
82 | # compute deformed source feature
83 | feature = self.feature_predictor(x) # (n, c, h, w)
84 | deformed = F.grid_sample(feature, optical_flow, align_corners=False) # (n, c, h, w)
85 | occlusion = torch.sigmoid(self.occlusion_predictor(x)) # (n, 1, h, w)
86 | feature = occlusion * deformed # (n, c, h, w)
87 |
88 | # inject local motion
89 | local_motion = torch.cat([drv_eyes, drv_audio], dim=1)
90 | for i in range(len(self.bottleneck)):
91 | feature = self.bottleneck[i](feature, local_motion)
92 |
93 | prediction = torch.sigmoid(self.decoder(feature))
94 |
95 | outputs = {
96 | 'prediction': prediction, # (n, 3, h, w)
97 | 'transformed_src': transformed_src, # (n, c*3, h, w)
98 | 'mask': mask, # (n, c+1, h, w)
99 | 'occlusion': occlusion, # (n, 1, h, w)
100 | 'optical_flow': optical_flow # (n, h, w, 2)
101 | }
102 | return outputs
103 |
--------------------------------------------------------------------------------
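The generator warps the source by the relative affine (src_head times the inverse of drv_head), blends mask-weighted dense motions into an occlusion-gated flow at 1/8 resolution, and injects the eyes+audio code through the modulated bottleneck. A standalone smoke test with identity poses (not part of the repository; the head dict only needs 'affine' when using_gaussian is False):

```python
import torch
from modules.generator import Generator

gen = Generator(num_affines=1, num_residual_mod_blocks=6, using_gaussian=False).eval()
head = {'affine': torch.eye(3).view(1, 1, 3, 3)}   # identity head pose for source and driver
with torch.no_grad():
    out = gen(torch.rand(1, 3, 256, 256), head, head,
              torch.rand(1, 8), torch.rand(1, 256))
print(out['prediction'].shape)     # torch.Size([1, 3, 256, 256])
print(out['optical_flow'].shape)   # torch.Size([1, 32, 32, 2]) flow at 1/8 resolution
```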
/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from loader import *
4 | from torchvision.io import write_video
5 |
6 |
7 | def inference(params):
8 | model = model_loader(params)
9 |
10 | if params['mode'] == 'obama_demo1':
11 | frames, audio = obama_demo1(model, params)
12 | elif params['mode'] == 'obama_demo2':
13 | frames, audio = obama_demo2(model, params)
14 | elif params['mode'] == 'grid_demo1':
15 | frames, audio = grid_demo1(model, params)
16 | elif params['mode'] == 'grid_demo2':
17 | frames, audio = grid_demo2(model, params)
18 | elif params['mode'] == 'koeba_demo1':
19 | frames, audio = koeba_demo1(model, params)
20 | elif params['mode'] == 'koeba_demo2':
21 | frames, audio = koeba_demo2(model, params)
22 |
23 | write_video(
24 | filename=params['save_path'],
25 | video_array=frames,
26 | fps=params['fps'],
27 | video_codec='libx264',
28 | options={'crf': '12'},
29 | audio_array=audio,
30 | audio_fps=params['samplerate'],
31 | audio_codec='aac')
32 |
33 |
34 | def obama_demo1(model, params):
35 | data = obama_demo1_data_loader(params)
36 |
37 | # fixed samples
38 | src_gpu = data['src'][None].cuda() / 255
39 | masked_src_gpu = data['masked_src'][None].cuda() / 255
40 | silence_feature_gpu = data['silence_feature'][None].cuda()
41 | black = torch.zeros_like(data['src'])
42 |
43 | preds = torch.zeros(data['n_frames'], 3, 256*2, 256*4, dtype=torch.uint8)
44 | for i in range(data['n_frames']):
45 | head_driver_gpu = data['head_driver'][i:i+1].cuda() / 255
46 | masked_driver_gpu = data['masked_driver'][i:i+1].cuda() / 255
47 | audio_feature_gpu = data['audio_features'][i:i+1].cuda()
48 |
49 | # 1. head only
50 | ho = model(src_gpu, head_driver_gpu, masked_src_gpu, silence_feature_gpu)
51 | ho = (ho[0] * 255).type(torch.uint8).cpu()
52 |
53 | # 2. head + audio
54 | ha = model(src_gpu, head_driver_gpu, masked_src_gpu, audio_feature_gpu)
55 | ha = (ha[0] * 255).type(torch.uint8).cpu()
56 |
57 | # 3. head + eye + audio
58 | hea = model(src_gpu, head_driver_gpu, masked_driver_gpu, audio_feature_gpu)
59 | hea = (hea[0] * 255).type(torch.uint8).cpu()
60 |
61 | preds[i] = torch.cat([
62 | torch.cat([black, data['head_driver'][i], data['lip_driver'][i], data['eye_driver'][i]], dim=2),
63 | torch.cat([data['src'], ho, ha, hea], dim=2),
64 | ], dim=1)
65 |
66 | preds = preds.permute(0, 2, 3, 1)
67 |
68 | return preds, data['audio']
69 |
70 |
71 | def obama_demo2(model, params):
72 | data = obama_demo2_data_loader(params)
73 |
74 | # fixed samples
75 | src_gpu = data['src'][None].cuda() / 255
76 | masked_src_gpu = data['masked_src'][None].cuda() / 255
77 | silence_feature_gpu = data['silence_feature'][None].cuda()
78 |
79 | preds = torch.zeros(data['n_frames'], 3, 256, 256*6, dtype=torch.uint8)
80 | for i in range(data['n_frames']):
81 | driver_gpu = data['driver'][i:i+1].cuda() / 255
82 | masked_driver_gpu = data['masked_driver'][i:i+1].cuda() / 255
83 | audio_feature_gpu = data['audio_features'][i:i+1].cuda()
84 |
85 | # 1. head only
86 | ho = model(src_gpu, driver_gpu, masked_src_gpu, silence_feature_gpu)
87 | ho = (ho[0] * 255).type(torch.uint8).cpu()
88 |
89 | # 2. audio only
90 | ao = model(src_gpu, src_gpu, masked_src_gpu, audio_feature_gpu)
91 | ao = (ao[0] * 255).type(torch.uint8).cpu()
92 |
93 | # 3. eye only
94 | eo = model(src_gpu, src_gpu, masked_driver_gpu, silence_feature_gpu)
95 | eo = (eo[0] * 255).type(torch.uint8).cpu()
96 |
97 | # 4. all
98 | hea = model(src_gpu, driver_gpu, masked_driver_gpu, audio_feature_gpu)
99 | hea = (hea[0] * 255).type(torch.uint8).cpu()
100 |
101 | preds[i] = torch.cat([data['src'], data['driver'][i], ho, ao, eo, hea], dim=2)
102 |
103 | preds = preds.permute(0, 2, 3, 1)
104 |
105 | return preds, data['audio']
106 |
107 |
108 | def grid_demo1(model, params):
109 | data = grid_demo1_data_loader(params)
110 |
111 | # fixed samples
112 | src1_gpu = data['src1'][None].cuda() / 255
113 | src2_gpu = data['src2'][None].cuda() / 255
114 | src3_gpu = data['src3'][None].cuda() / 255
115 | black = torch.zeros_like(data['src1'])
116 |
117 | preds = torch.zeros(data['n_frames'], 3, 256*2, 256*4, dtype=torch.uint8)
118 | for i in range(data['n_frames']):
119 | driver_gpu = data['driver'][i:i+1].cuda() / 255
120 | masked_driver_gpu = data['masked_driver'][i:i+1].cuda() / 255
121 | audio_feature_gpu = data['audio_features'][i:i+1].cuda()
122 |
123 | # 1. drive source 1 with head + eyes + audio
124 | sp1 = model(src1_gpu, driver_gpu, masked_driver_gpu, audio_feature_gpu)
125 | sp1 = (sp1[0] * 255).type(torch.uint8).cpu()
126 |
127 | # 2. drive source 2 with head + eyes + audio
128 | sp2 = model(src2_gpu, driver_gpu, masked_driver_gpu, audio_feature_gpu)
129 | sp2 = (sp2[0] * 255).type(torch.uint8).cpu()
130 |
131 | # 3. drive source 3 with head + eyes + audio
132 | sp3 = model(src3_gpu, driver_gpu, masked_driver_gpu, audio_feature_gpu)
133 | sp3 = (sp3[0] * 255).type(torch.uint8).cpu()
134 |
135 | preds[i] = torch.cat([
136 | torch.cat([black, data['src1'], data['src2'], data['src3']], dim=2),
137 | torch.cat([data['driver'][i], sp1, sp2, sp3], dim=2),
138 | ], dim=1)
139 |
140 | preds = preds.permute(0, 2, 3, 1)
141 |
142 | return preds, data['audio']
143 |
144 |
145 | def grid_demo2(model, params):
146 | data = grid_demo2_data_loader(params)
147 |
148 | # fixed samples
149 | src1_gpu = data['src1'][None].cuda() / 255
150 | src2_gpu = data['src2'][None].cuda() / 255
151 | src3_gpu = data['src3'][None].cuda() / 255
152 |
153 | preds = torch.zeros(data['n_frames'], 3, 256*2, 256*4, dtype=torch.uint8)
154 | for i in range(data['n_frames']):
155 | head_driver_gpu = data['head_driver'][i:i+1].cuda() / 255
156 | masked_driver_gpu = data['masked_driver'][i:i+1].cuda() / 255
157 | audio_feature_gpu = data['audio_features'][i:i+1].cuda()
158 |
159 | # 1. drive source 1 with head + eyes + audio
160 | sp1 = model(src1_gpu, head_driver_gpu, masked_driver_gpu, audio_feature_gpu)
161 | sp1 = (sp1[0] * 255).type(torch.uint8).cpu()
162 |
163 | # 2. drive source 2 with head + eyes + audio
164 | sp2 = model(src2_gpu, head_driver_gpu, masked_driver_gpu, audio_feature_gpu)
165 | sp2 = (sp2[0] * 255).type(torch.uint8).cpu()
166 |
167 | # 3. drive source 3 with head + eyes + audio
168 | sp3 = model(src3_gpu, head_driver_gpu, masked_driver_gpu, audio_feature_gpu)
169 | sp3 = (sp3[0] * 255).type(torch.uint8).cpu()
170 |
171 | half = torch.zeros_like(data['src1'])
172 | margin = 140
173 | half[:,:margin] = data['eye_driver'][i][:,:margin]
174 | half[:,margin:] = data['lip_driver'][i][:,margin:]
175 |
176 | preds[i] = torch.cat([
177 | torch.cat([data['head_driver'][i], data['src1'], data['src2'], data['src3']], dim=2),
178 | torch.cat([half, sp1, sp2, sp3], dim=2),
179 | ], dim=1)
180 |
181 | preds = preds.permute(0, 2, 3, 1)
182 |
183 | return preds, data['audio']
184 |
185 |
186 | def koeba_demo1(model, params):
187 | data = koeba_demo1_data_loader(params)
188 |
189 | # fixed samples
190 | src1_gpu = data['src1'][None].cuda() / 255
191 | src2_gpu = data['src2'][None].cuda() / 255
192 | src3_gpu = data['src3'][None].cuda() / 255
193 | black = torch.zeros_like(data['src1'])
194 |
195 | preds = torch.zeros(data['n_frames'], 3, 256*2, 256*4, dtype=torch.uint8)
196 | for i in range(data['n_frames']):
197 | driver_gpu = data['driver'][i:i+1].cuda() / 255
198 | masked_driver_gpu = data['masked_driver'][i:i+1].cuda() / 255
199 | audio_feature_gpu = data['audio_features'][i:i+1].cuda()
200 |
201 | # 1. drive source 1 with head + eyes + audio
202 | sp1 = model(src1_gpu, driver_gpu, masked_driver_gpu, audio_feature_gpu)
203 | sp1 = (sp1[0] * 255).type(torch.uint8).cpu()
204 |
205 | # 2. drive source 2 with head + eyes + audio
206 | sp2 = model(src2_gpu, driver_gpu, masked_driver_gpu, audio_feature_gpu)
207 | sp2 = (sp2[0] * 255).type(torch.uint8).cpu()
208 |
209 | # 3. drive source 3 with head + eyes + audio
210 | sp3 = model(src3_gpu, driver_gpu, masked_driver_gpu, audio_feature_gpu)
211 | sp3 = (sp3[0] * 255).type(torch.uint8).cpu()
212 |
213 | preds[i] = torch.cat([
214 | torch.cat([black, data['src1'], data['src2'], data['src3']], dim=2),
215 | torch.cat([data['driver'][i], sp1, sp2, sp3], dim=2),
216 | ], dim=1)
217 |
218 | preds = preds.permute(0, 2, 3, 1)
219 |
220 | return preds, data['audio']
221 |
222 |
223 | def koeba_demo2(model, params):
224 | data = koeba_demo2_data_loader(params)
225 |
226 | # fixed samples
227 | src1_gpu = data['src1'][None].cuda() / 255
228 | src2_gpu = data['src2'][None].cuda() / 255
229 | src3_gpu = data['src3'][None].cuda() / 255
230 |
231 | preds = torch.zeros(data['n_frames'], 3, 256*2, 256*4, dtype=torch.uint8)
232 | for i in range(data['n_frames']):
233 | head_driver_gpu = data['head_driver'][i:i+1].cuda() / 255
234 | masked_driver_gpu = data['masked_driver'][i:i+1].cuda() / 255
235 | audio_feature_gpu = data['audio_features'][i:i+1].cuda()
236 |
237 | # 1. drive source 1 with head + eyes + audio
238 | sp1 = model(src1_gpu, head_driver_gpu, masked_driver_gpu, audio_feature_gpu)
239 | sp1 = (sp1[0] * 255).type(torch.uint8).cpu()
240 |
241 | # 2. drive source 2 with head + eyes + audio
242 | sp2 = model(src2_gpu, head_driver_gpu, masked_driver_gpu, audio_feature_gpu)
243 | sp2 = (sp2[0] * 255).type(torch.uint8).cpu()
244 |
245 | # 3. drive source 3 with head + eyes + audio
246 | sp3 = model(src3_gpu, head_driver_gpu, masked_driver_gpu, audio_feature_gpu)
247 | sp3 = (sp3[0] * 255).type(torch.uint8).cpu()
248 |
249 | half = torch.zeros_like(data['src1'])
250 | margin = 128
251 | half[:,:margin] = data['eye_driver'][i][:,:margin]
252 | half[:,margin:] = data['lip_driver'][i][:,margin:]
253 |
254 | preds[i] = torch.cat([
255 | torch.cat([data['head_driver'][i], data['src1'], data['src2'], data['src3']], dim=2),
256 | torch.cat([half, sp1, sp2, sp3], dim=2),
257 | ], dim=1)
258 |
259 | preds = preds.permute(0, 2, 3, 1)
260 |
261 | return preds, data['audio']
262 |
--------------------------------------------------------------------------------
/loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import librosa
3 | import numpy as np
4 |
5 | from glob import glob
6 | from os.path import join
7 | from torchvision.io import read_image
8 | from torchvision.io.image import ImageReadMode
9 |
10 | from model import Model
11 |
12 |
13 | def model_loader(config_dict):
14 | with torch.no_grad():
15 | model = Model(config_dict['model_params'])
16 | checkpoint = torch.load(config_dict['weight_path'])
17 | model.load_state_dict(checkpoint['state_dict'], strict=False)
18 | model.eval().cuda()
19 | return model
20 |
21 |
22 | def obama_demo1_data_loader(params):
23 | dataset_root = params['dataset_root']
24 |
25 | # load audio
26 | audio, _ = librosa.load(join(dataset_root, 'audio.wav'))
27 |
28 | # load features
29 | audio_features = np.load(join(dataset_root, 'audio_features.npy'))
30 | silence_feature = np.load(join(dataset_root, 'silence_feature.npy'))
31 |
32 | # parsing image paths
33 | src_path = join(dataset_root, 'src.jpg')
34 | masked_src_path = join(dataset_root, 'masked_src.png')
35 | head_driver_paths = sorted(glob(join(dataset_root, 'head_driver/*.jpg')))
36 | eye_driver_paths = sorted(glob(join(dataset_root, 'eye_driver/*.jpg')))
37 | lip_driver_paths = sorted(glob(join(dataset_root, 'lip_driver/*.jpg')))
38 | masked_driver_paths = sorted(glob(join(dataset_root, 'masked_driver/*.png')))
39 |
40 | # load images
41 | n_frames = len(audio_features)
42 |
43 | src = read_image(src_path, ImageReadMode.RGB)
44 | masked_src = read_image(masked_src_path, ImageReadMode.RGB)
45 |
46 | head_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
47 | eye_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
48 | lip_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
49 | masked_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
50 |
51 | for i in range(n_frames):
52 | head_driver[i] = read_image(head_driver_paths[i], ImageReadMode.RGB)
53 | eye_driver[i] = read_image(eye_driver_paths[i], ImageReadMode.RGB)
54 | lip_driver[i] = read_image(lip_driver_paths[i], ImageReadMode.RGB)
55 | masked_driver[i] = read_image(masked_driver_paths[i], ImageReadMode.RGB)
56 |
57 | data = {
58 | 'n_frames': n_frames,
59 | 'audio': torch.from_numpy(audio).view(1,-1), # (1, L)
60 | 'audio_features': torch.from_numpy(audio_features), # (N, C, L)
61 | 'silence_feature': torch.from_numpy(silence_feature), # (C, L)
62 | 'src': src,
63 | 'masked_src': masked_src,
64 | 'head_driver': head_driver,
65 | 'eye_driver': eye_driver,
66 | 'lip_driver': lip_driver,
67 | 'masked_driver': masked_driver
68 | }
69 |
70 | return data
71 |
72 |
73 | def obama_demo2_data_loader(params):
74 | dataset_root = params['dataset_root']
75 |
76 | # load audio
77 | audio, _ = librosa.load(join(dataset_root, 'audio.wav'))
78 |
79 | # load features
80 | audio_features = np.load(join(dataset_root, 'audio_features.npy'))
81 | silence_feature = np.load(join(dataset_root, 'silence_feature.npy'))
82 |
83 | # parsing image paths
84 | src_path = join(dataset_root, 'src.jpg')
85 | masked_src_path = join(dataset_root, 'masked_src.png')
86 | driver_paths = sorted(glob(join(dataset_root, 'driver/*.jpg')))
87 | masked_driver_paths = sorted(glob(join(dataset_root, 'masked_driver/*.png')))
88 |
89 | # load images
90 | n_frames = len(audio_features)
91 |
92 | src = read_image(src_path, ImageReadMode.RGB)
93 | masked_src = read_image(masked_src_path, ImageReadMode.RGB)
94 |
95 | driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
96 | masked_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
97 |
98 | for i in range(n_frames):
99 | driver[i] = read_image(driver_paths[i], ImageReadMode.RGB)
100 | masked_driver[i] = read_image(masked_driver_paths[i], ImageReadMode.RGB)
101 |
102 | data = {
103 | 'n_frames': n_frames,
104 | 'audio': torch.from_numpy(audio).view(1,-1), # (1, L)
105 | 'audio_features': torch.from_numpy(audio_features), # (N, C, L)
106 | 'silence_feature': torch.from_numpy(silence_feature), # (C, L)
107 | 'src': src,
108 | 'masked_src': masked_src,
109 | 'driver': driver,
110 | 'masked_driver': masked_driver
111 | }
112 |
113 | return data
114 |
115 |
116 | def grid_demo1_data_loader(params):
117 | dataset_root = params['dataset_root']
118 |
119 | # load audio
120 | audio, _ = librosa.load(join(dataset_root, 'audio.wav'))
121 |
122 | # load features
123 | audio_features = np.load(join(dataset_root, 'audio_features.npy'))
124 |
125 | # parsing image paths
126 | src1_path = join(dataset_root, 'src1.jpg')
127 | src2_path = join(dataset_root, 'src2.jpg')
128 | src3_path = join(dataset_root, 'src3.jpg')
129 |
130 | driver_paths = sorted(glob(join(dataset_root, 'driver/*.jpg')))
131 | masked_driver_paths = sorted(glob(join(dataset_root, 'masked_driver/*.png')))
132 |
133 | # load images
134 | n_frames = len(audio_features)
135 |
136 | src1 = read_image(src1_path, ImageReadMode.RGB)
137 | src2 = read_image(src2_path, ImageReadMode.RGB)
138 | src3 = read_image(src3_path, ImageReadMode.RGB)
139 |
140 | driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
141 | masked_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
142 |
143 | for i in range(n_frames):
144 | driver[i] = read_image(driver_paths[i], ImageReadMode.RGB)
145 | masked_driver[i] = read_image(masked_driver_paths[i], ImageReadMode.RGB)
146 |
147 | data = {
148 | 'n_frames': n_frames,
149 | 'audio': torch.from_numpy(audio).view(1,-1), # (1, L)
150 | 'audio_features': torch.from_numpy(audio_features), # (N, C, L)
151 | 'src1': src1,
152 | 'src2': src2,
153 | 'src3': src3,
154 | 'driver': driver,
155 | 'masked_driver': masked_driver
156 | }
157 |
158 | return data
159 |
160 |
161 | def grid_demo2_data_loader(params):
162 | dataset_root = params['dataset_root']
163 |
164 | # load audio
165 | audio, _ = librosa.load(join(dataset_root, 'audio.wav'))
166 |
167 | # load features
168 | audio_features = np.load(join(dataset_root, 'audio_features.npy'))
169 |
170 | # parsing image paths
171 | src1_path = join(dataset_root, 'src1.jpg')
172 | src2_path = join(dataset_root, 'src2.jpg')
173 | src3_path = join(dataset_root, 'src3.jpg')
174 |
175 | head_driver_paths = sorted(glob(join(dataset_root, 'head_driver/*.jpg')))
176 | eye_driver_paths = sorted(glob(join(dataset_root, 'eye_driver/*.jpg')))
177 | lip_driver_paths = sorted(glob(join(dataset_root, 'lip_driver/*.jpg')))
178 | masked_driver_paths = sorted(glob(join(dataset_root, 'masked_driver/*.png')))
179 |
180 | # load images
181 | n_frames = len(audio_features)
182 |
183 | src1 = read_image(src1_path, ImageReadMode.RGB)
184 | src2 = read_image(src2_path, ImageReadMode.RGB)
185 | src3 = read_image(src3_path, ImageReadMode.RGB)
186 |
187 | head_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
188 | eye_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
189 | lip_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
190 | masked_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
191 |
192 | for i in range(n_frames):
193 | head_driver[i] = read_image(head_driver_paths[i], ImageReadMode.RGB)
194 | eye_driver[i] = read_image(eye_driver_paths[i], ImageReadMode.RGB)
195 | lip_driver[i] = read_image(lip_driver_paths[i], ImageReadMode.RGB)
196 | masked_driver[i] = read_image(masked_driver_paths[i], ImageReadMode.RGB)
197 |
198 | data = {
199 | 'n_frames': n_frames,
200 | 'audio': torch.from_numpy(audio).view(1,-1), # (1, L)
201 | 'audio_features': torch.from_numpy(audio_features), # (N, C, L)
202 | 'src1': src1,
203 | 'src2': src2,
204 | 'src3': src3,
205 | 'head_driver': head_driver,
206 | 'eye_driver': eye_driver,
207 | 'lip_driver': lip_driver,
208 | 'masked_driver': masked_driver
209 | }
210 |
211 | return data
212 |
213 |
214 | def koeba_demo1_data_loader(params):
215 | dataset_root = params['dataset_root']
216 |
217 | # load audio
218 | audio, _ = librosa.load(join(dataset_root, 'audio.wav'))
219 |
220 | # load features
221 | audio_features = np.load(join(dataset_root, 'audio_features.npy'))
222 |
223 | # parsing image paths
224 | src1_path = join(dataset_root, 'src1.jpg')
225 | src2_path = join(dataset_root, 'src2.jpg')
226 | src3_path = join(dataset_root, 'src3.jpg')
227 |
228 | driver_paths = sorted(glob(join(dataset_root, 'driver/*.png')))
229 | masked_driver_paths = sorted(glob(join(dataset_root, 'masked_driver/*.png')))
230 |
231 | # load images
232 | n_frames = len(audio_features)
233 |
234 | src1 = read_image(src1_path, ImageReadMode.RGB)
235 | src2 = read_image(src2_path, ImageReadMode.RGB)
236 | src3 = read_image(src3_path, ImageReadMode.RGB)
237 |
238 | driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
239 | masked_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
240 |
241 | for i in range(n_frames):
242 | driver[i] = read_image(driver_paths[i], ImageReadMode.RGB)
243 | masked_driver[i] = read_image(masked_driver_paths[i], ImageReadMode.RGB)
244 |
245 | data = {
246 | 'n_frames': n_frames,
247 | 'audio': torch.from_numpy(audio).view(1,-1), # (1, L)
248 | 'audio_features': torch.from_numpy(audio_features), # (N, C, L)
249 | 'src1': src1,
250 | 'src2': src2,
251 | 'src3': src3,
252 | 'driver': driver,
253 | 'masked_driver': masked_driver
254 | }
255 |
256 | return data
257 |
258 |
259 | def koeba_demo2_data_loader(params):
260 | dataset_root = params['dataset_root']
261 |
262 | # load audio
263 | audio, _ = librosa.load(join(dataset_root, 'audio.wav'))
264 |
265 | # load features
266 | audio_features = np.load(join(dataset_root, 'audio_features.npy'))
267 |
268 | # parsing image paths
269 | src1_path = join(dataset_root, 'src1.jpg')
270 | src2_path = join(dataset_root, 'src2.jpg')
271 | src3_path = join(dataset_root, 'src3.jpg')
272 |
273 | head_driver_paths = sorted(glob(join(dataset_root, 'head_driver/*.png')))
274 | eye_driver_paths = sorted(glob(join(dataset_root, 'eye_driver/*.png')))
275 | lip_driver_paths = sorted(glob(join(dataset_root, 'lip_driver/*.png')))
276 | masked_driver_paths = sorted(glob(join(dataset_root, 'masked_driver/*.png')))
277 |
278 | # load images
279 | n_frames = len(audio_features)
280 |
281 | src1 = read_image(src1_path, ImageReadMode.RGB)
282 | src2 = read_image(src2_path, ImageReadMode.RGB)
283 | src3 = read_image(src3_path, ImageReadMode.RGB)
284 |
285 | head_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
286 | eye_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
287 | lip_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
288 | masked_driver = torch.zeros(n_frames, 3, 256, 256, dtype=torch.uint8)
289 |
290 | for i in range(n_frames):
291 | head_driver[i] = read_image(head_driver_paths[i], ImageReadMode.RGB)
292 | eye_driver[i] = read_image(eye_driver_paths[i], ImageReadMode.RGB)
293 | lip_driver[i] = read_image(lip_driver_paths[i], ImageReadMode.RGB)
294 | masked_driver[i] = read_image(masked_driver_paths[i], ImageReadMode.RGB)
295 |
296 | data = {
297 | 'n_frames': n_frames,
298 | 'audio': torch.from_numpy(audio).view(1,-1), # (1, L)
299 | 'audio_features': torch.from_numpy(audio_features), # (N, C, L)
300 | 'src1': src1,
301 | 'src2': src2,
302 | 'src3': src3,
303 | 'head_driver': head_driver,
304 | 'eye_driver': eye_driver,
305 | 'lip_driver': lip_driver,
306 | 'masked_driver': masked_driver
307 | }
308 |
309 | return data
310 |
--------------------------------------------------------------------------------