├── ASR_model
│ ├── GRID
│ │ ├── data
│ │ │ └── README.md
│ │ ├── src
│ │ │ ├── models
│ │ │ │ ├── classifier.py
│ │ │ │ ├── audio_front.py
│ │ │ │ └── resnet.py
│ │ │ └── data
│ │ │   ├── transforms.py
│ │ │   ├── audio_processing.py
│ │ │   ├── stft.py
│ │ │   └── vid_aud_GRID_test.py
│ │ └── test.py
│ └── LRW
│   ├── data
│   │ ├── README.md
│   │ └── class.txt
│   ├── src
│   │ ├── data
│   │ │ ├── transforms.py
│   │ │ ├── audio_processing.py
│   │ │ ├── vid_aud_lrw_test.py
│   │ │ └── stft.py
│   │ └── models
│   │   ├── classifier.py
│   │   ├── audio_front.py
│   │   └── resnet.py
│   └── test.py
├── data
│ ├── LRS2
│ │ ├── LRS2_crop
│ │ │ └── README.md
│ │ └── README.md
│ ├── LRS3
│ │ └── LRS3_crop
│ │   └── README.md
│ ├── val_4.txt
│ └── test_4.txt
├── img
│ └── Img.PNG
├── src
│ ├── data
│ │ ├── transforms.py
│ │ ├── audio_processing.py
│ │ ├── stft.py
│ │ ├── vid_aud_grid.py
│ │ ├── vid_aud_lrs3.py
│ │ └── vid_aud_lrs2.py
│ └── models
│   ├── audio_front.py
│   ├── visual_front.py
│   ├── resnet.py
│   └── generator.py
├── preprocess
│ ├── Extract_frames.py
│ ├── Extract_audio_LRS.py
│ ├── Preprocess.py
│ └── Ref_face.txt
├── README.md
├── README_LRS.md
├── README_GRID.md
├── test.py
└── test_LRS.py
/ASR_model/GRID/data/README.md:
--------------------------------------------------------------------------------
1 | Put the checkpoints of ASR here.
2 |
--------------------------------------------------------------------------------
/ASR_model/LRW/data/README.md:
--------------------------------------------------------------------------------
1 | Put the checkpoints of ASR here.
2 |
--------------------------------------------------------------------------------
/data/LRS2/LRS2_crop/README.md:
--------------------------------------------------------------------------------
1 | Put the lip coordinate files of LRS2 here.
2 |
--------------------------------------------------------------------------------
/data/LRS3/LRS3_crop/README.md:
--------------------------------------------------------------------------------
1 | Put the lip coordinate files of LRS3 here.
2 |
--------------------------------------------------------------------------------
/img/Img.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ms-dot-k/Visual-Context-Attentional-GAN/HEAD/img/Img.PNG
--------------------------------------------------------------------------------
/data/LRS2/README.md:
--------------------------------------------------------------------------------
1 | Put the data split files of LRS2 here:
2 | -pretrain.txt
3 | -test.txt
4 | -train.txt
5 | -val.txt
6 |
--------------------------------------------------------------------------------
/ASR_model/GRID/src/models/classifier.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | class Backend(nn.Module):
4 | def __init__(self):
5 | super().__init__()
6 | self.gru = nn.GRU(256, 256, 2, bidirectional=True, dropout=0.3)
7 | self.fc = nn.Linear(512, 27 + 1)
8 |
9 | def forward(self, x):
10 | x = x.permute(1, 0, 2).contiguous() # S,B,256
11 | self.gru.flatten_parameters()
12 |
13 | x, _ = self.gru(x) # S,B,512
14 | x = x.permute(1, 0, 2).contiguous()
15 | x = self.fc(x) # B, S, 28
16 | return x
17 |
18 |
--------------------------------------------------------------------------------
/src/data/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision.transforms.functional as F
3 |
4 | class StatefulRandomHorizontalFlip():
5 | def __init__(self, probability=0.5):
6 | self.probability = probability
7 | self.rand = random.random()
8 |
9 | def __call__(self, img):
10 | if self.rand < self.probability:
11 | return F.hflip(img)
12 | return img
13 |
14 | def __repr__(self):
15 | return self.__class__.__name__ + '(probability={})'.format(self.probability)
16 |
17 |
18 | class Crop(object):
19 | def __init__(self, crop):
20 | self.crop = crop
21 |
22 | def __call__(self, img):
23 | return img.crop(self.crop)
24 |
--------------------------------------------------------------------------------
/ASR_model/GRID/src/data/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision.transforms.functional as F
3 |
4 | class StatefulRandomHorizontalFlip():
5 | def __init__(self, probability=0.5):
6 | self.probability = probability
7 | self.rand = random.random()
8 |
9 | def __call__(self, img):
10 | if self.rand < self.probability:
11 | return F.hflip(img)
12 | return img
13 |
14 | def __repr__(self):
15 | return self.__class__.__name__ + '(probability={})'.format(self.probability)
16 |
17 |
18 | class Crop(object):
19 | def __init__(self, crop):
20 | self.crop = crop
21 |
22 | def __call__(self, img):
23 | return img.crop(self.crop)
24 |
--------------------------------------------------------------------------------
/ASR_model/LRW/src/data/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision.transforms.functional as F
3 |
4 | class StatefulRandomHorizontalFlip():
5 | def __init__(self, probability=0.5):
6 | self.probability = probability
7 | self.rand = random.random()
8 |
9 | def __call__(self, img):
10 | if self.rand < self.probability:
11 | return F.hflip(img)
12 | return img
13 |
14 | def __repr__(self):
15 | return self.__class__.__name__ + '(probability={})'.format(self.probability)
16 |
17 |
18 | class Crop(object):
19 | def __init__(self, crop):
20 | self.crop = crop
21 |
22 | def __call__(self, img):
23 | return img.crop(self.crop)
24 |
--------------------------------------------------------------------------------
/ASR_model/LRW/src/models/classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class Backend(nn.Module):
5 | def __init__(self, logits=True):
6 | super().__init__()
7 | self.logits = logits
8 |
9 | self.gru = nn.GRU(512, 512, 2, bidirectional=True, dropout=0.3)
10 | if logits:
11 | self.fc = nn.Linear(1024, 500)
12 |
13 | def forward(self, x):
14 | x = x.permute(1, 0, 2).contiguous() # S,B,512
15 | self.gru.flatten_parameters()
16 |
17 | x, _ = self.gru(x)
18 | x = x.mean(0, keepdim=False)
19 |
20 | if self.logits:
21 | pred = self.fc(x) # B, 500
22 | return pred
23 | else:
24 | return x
25 |
26 |
--------------------------------------------------------------------------------
/ASR_model/LRW/src/models/audio_front.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from src.models.resnet import BasicBlock
3 | import torch
4 |
5 | class Audio_front(nn.Module):
6 | def __init__(self, in_channels=1):
7 | super().__init__()
8 |
9 | self.in_channels = in_channels
10 |
11 | self.frontend = nn.Sequential(
12 | nn.Conv2d(self.in_channels, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
13 | nn.BatchNorm2d(128),
14 | nn.PReLU(128),
15 | nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
16 | nn.BatchNorm2d(256),
17 | nn.PReLU(256)
18 | )
19 |
20 | self.Res_block = nn.Sequential(
21 | BasicBlock(256, 256)
22 | )
23 |
24 | self.Linear = nn.Linear(256 * 20, 512)
25 |
26 | self.dropout = nn.Dropout(0.3)
27 |
28 | def forward(self, x):
29 | x = self.frontend(x) #B, 256, F/4, T/4
30 | x = self.Res_block(x) #B, 256, F/4, T/4
31 | b, c, f, t = x.size()
32 | x = x.view(b, c*f, t).transpose(1, 2).contiguous() #B, T/4, 256 * F/4
33 | x = self.dropout(x)
34 | x = self.Linear(x) #B, T/4, 512
35 | return x
36 |
37 |
--------------------------------------------------------------------------------
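Taken together, the LRW ASR front-end above and the `Backend` classifier shown earlier form a mel-spectrogram-to-word model: `Audio_front` maps an 80-bin mel-spectrogram to a 512-dim frame sequence and `Backend` pools it into 500-way word logits. A minimal shape-check sketch (not part of the repository; it assumes it is run from `ASR_model/LRW` so the `src.*` imports resolve, and feeds a random tensor instead of a real spectrogram):

```python
import torch
from src.models.audio_front import Audio_front
from src.models.classifier import Backend

front = Audio_front(in_channels=1)   # mel front-end: 80 mel bins -> 512-dim features
backend = Backend(logits=True)       # 2-layer BiGRU + 500-way word classifier

mel = torch.randn(2, 1, 80, 116)     # (B, 1, n_mels, T); 116 audio frames matches the assert in vid_aud_lrw_test.py
feat = front(mel)                    # (B, T/4, 512)
pred = backend(feat)                 # (B, 500) word logits
print(feat.shape, pred.shape)        # torch.Size([2, 29, 512]) torch.Size([2, 500])
```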
/ASR_model/GRID/src/models/audio_front.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from src.models.resnet import BasicBlock
3 | import torch
4 |
5 | class Audio_front(nn.Module):
6 | def __init__(self, in_channels=1):
7 | super().__init__()
8 |
9 | self.in_channels = in_channels
10 |
11 | self.frontend = nn.Sequential(
12 | nn.Conv2d(self.in_channels, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2)),
13 | nn.BatchNorm2d(32),
14 | nn.PReLU(32),
15 |
16 | nn.Conv2d(32, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2)),
17 | nn.BatchNorm2d(64),
18 | nn.PReLU(64)
19 | )
20 |
21 | self.Res_block = nn.Sequential(
22 | BasicBlock(64, 64, relu_type='prelu')
23 | )
24 |
25 | self.Linear = nn.Linear(64 * 20, 256)
26 |
27 | self.dropout = nn.Dropout(0.3)
28 |
29 | def forward(self, x):
30 | x = self.frontend(x) #B, 64, F/4, T/4
31 | x = self.Res_block(x) #B, 64, F/4, T/4
32 | b, c, f, t = x.size()
33 | x = x.view(b, c*f, t).transpose(1, 2).contiguous() #B, T/4, 64 * F/4
34 | x = self.dropout(x)
35 | x = self.Linear(x) #B, T/4, 256
36 | return x
37 |
38 |
--------------------------------------------------------------------------------
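The GRID ASR branch is analogous but lighter: this `Audio_front` outputs 256-dim features and the GRID `Backend` shown earlier emits per-frame character logits (27 characters + CTC blank). A minimal shape-check sketch, assuming it is run from `ASR_model/GRID` with a dummy 80-bin mel-spectrogram (300 audio frames correspond to a 75-frame, 3-second GRID clip at 4 audio frames per video frame):

```python
import torch
from src.models.audio_front import Audio_front
from src.models.classifier import Backend

front = Audio_front(in_channels=1)   # mel front-end: 80 mel bins -> 256-dim features
backend = Backend()                  # 2-layer BiGRU + linear head over 27 characters + CTC blank

mel = torch.randn(2, 1, 80, 300)     # (B, 1, n_mels, T)
feat = front(mel)                    # (B, T/4, 256)
logits = backend(feat)               # (B, T/4, 28) character logits, e.g. for CTC decoding
print(logits.shape)                  # torch.Size([2, 75, 28])
```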
/src/models/audio_front.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from src.models.resnet import BasicBlock
3 | import torch
4 |
5 | class Audio_front(nn.Module):
6 | def __init__(self, in_channels=1):
7 | super().__init__()
8 |
9 | self.in_channels = in_channels
10 |
11 | self.frontend = nn.Sequential(
12 | nn.Conv2d(self.in_channels, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
13 | nn.BatchNorm2d(128),
14 | nn.PReLU(128),
15 | nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
16 | nn.BatchNorm2d(256),
17 | nn.PReLU(256)
18 | )
19 |
20 | self.Res_block = nn.Sequential(
21 | BasicBlock(256, 256)
22 | )
23 |
24 | self.Linear = nn.Linear(256 * 20, 512)
25 |
26 | self.dropout = nn.Dropout(0.3)
27 |
28 | def forward(self, x):
29 | x = self.frontend(x) #B, 256, F/4, T/4
30 | x = self.Res_block(x) #B, 256, F/4, T/4
31 | b, c, f, t = x.size()
32 | x = x.view(b, c*f, t).transpose(1, 2).contiguous() #B, T/4, 256 * F/4
33 | x = self.dropout(x)
34 | x = self.Linear(x) #B, T/4, 512
35 | x = x.permute(1, 0, 2).contiguous() # S,B,512
36 | return x
37 |
38 |
--------------------------------------------------------------------------------
/preprocess/Extract_frames.py:
--------------------------------------------------------------------------------
1 | import os, glob, subprocess
2 | import argparse
3 |
4 | def parse_args():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--Grid_dir', type=str, default="Data dir to GRID_corpus")
7 | parser.add_argument("--Output_dir", type=str, default='Output dir Ex) ./GRID_imgs_aud')
8 | args = parser.parse_args()
9 | return args
10 |
11 | args = parse_args()
12 |
13 | vid_files = sorted(glob.glob(os.path.join(args.Grid_dir, '*', 'video', '*.mpg'))) #suppose the directory: Data_dir/subject/video/mpg files
14 | for k, v in enumerate(vid_files):
15 | t, f_name = os.path.split(v)
16 | t, _ = os.path.split(t)
17 | _, sub_name = os.path.split(t)
18 | out_im = os.path.join(args.Output_dir, sub_name, 'video', f_name[:-4])
19 | if len(glob.glob(os.path.join(out_im, '*.png'))) < 75: # Can resume after killed
20 | if not os.path.exists(out_im):
21 | os.makedirs(out_im)
22 | out_aud = os.path.join(args.Output_dir, sub_name, 'audio')
23 | if not os.path.exists(out_aud):
24 | os.makedirs(out_aud)
25 | subprocess.call(f'ffmpeg -y -i {v} -qscale:v 2 -r 25 {out_im}/%02d.png', shell=True)
26 | subprocess.call(f'ffmpeg -y -i {v} -ac 1 -acodec pcm_s16le -ar 16000 {os.path.join(out_aud, f_name[:-4] + ".wav")}', shell=True)
27 | print(f'{k}/{len(vid_files)}')
28 |
--------------------------------------------------------------------------------
/src/models/visual_front.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from src.models.resnet import BasicBlock, ResNet
3 |
4 | class Visual_front(nn.Module):
5 | def __init__(self, in_channels=1):
6 | super().__init__()
7 |
8 | self.in_channels = in_channels
9 |
10 | self.frontend = nn.Sequential(
11 | nn.Conv3d(self.in_channels, 64, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False), #44,44
12 | nn.BatchNorm3d(64),
13 | nn.PReLU(64),
14 | nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) #28,28
15 | )
16 |
17 | self.resnet = ResNet(BasicBlock, [2, 2, 2, 2], relu_type='prelu')
18 | self.dropout = nn.Dropout(0.3)
19 |
20 | self.sentence_encoder = nn.GRU(512, 512, 2, bidirectional=True, dropout=0.3)
21 | self.fc = nn.Linear(1024, 512)
22 |
23 | def forward(self, x):
24 | #B,C,S,H,W
25 | x = self.frontend(x) #B,C,T,H,W
26 | B, C, T, H, W = x.size()
27 | x = x.transpose(1, 2).contiguous().view(B*T, C, H, W)
28 | x = self.resnet(x) # B*T, 512
29 | x = self.dropout(x)
30 | x = x.view(B, T, -1)
31 | phons = x.permute(1, 0, 2).contiguous() # S,B,512
32 |
33 | self.sentence_encoder.flatten_parameters()
34 | sentence, _ = self.sentence_encoder(phons)
35 | sentence = self.fc(sentence).permute(1, 2, 0).contiguous() # B,512,T
36 |
37 | return phons.permute(1, 0, 2), sentence
38 |
--------------------------------------------------------------------------------
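`Visual_front` is the lip encoder used by the generator: it returns per-frame visual features and a BiGRU-encoded sentence-level representation. A minimal shape-check sketch, assuming it is run from the repository root; the 50-frame, 112x112 grayscale input here is only an illustrative placeholder, not necessarily the crop size used in training:

```python
import torch
from src.models.visual_front import Visual_front

vf = Visual_front(in_channels=1)
video = torch.randn(2, 1, 50, 112, 112)  # (B, C, S, H, W): 2 clips of 50 grayscale lip frames
phons, sentence = vf(video)
print(phons.shape)      # torch.Size([2, 50, 512])  per-frame features
print(sentence.shape)   # torch.Size([2, 512, 50])  sentence-level features, channel-first
```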
/README.md:
--------------------------------------------------------------------------------
1 | # Lip to Speech Synthesis with Visual Context Attentional GAN
2 |
3 | This repository contains the PyTorch implementation of the following paper:
4 | > **Lip to Speech Synthesis with Visual Context Attentional GAN**
5 | > Minsu Kim, Joanna Hong, and Yong Man Ro
6 | > \[[Paper](https://proceedings.neurips.cc/paper/2021/file/16437d40c29a1a7b1e78143c9c38f289-Paper.pdf)\] \[[Demo Video](https://kaistackr-my.sharepoint.com/:v:/g/personal/ms_k_kaist_ac_kr/EQp2Zao1ZQFDm9xDVuZubKIB_ns_6gk0L6LB3U5jd4jYKw?e=Qw8ddt)\]
7 |
8 |
9 |
10 | ## Requirements
11 | - python 3.7
12 | - pytorch 1.6 ~ 1.8
13 | - torchvision
14 | - torchaudio
15 | - ffmpeg
16 | - av
17 | - tensorboard
18 | - scikit-image 0.17.0 ~
19 | - opencv-python 3.4 ~
20 | - pillow
21 | - librosa
22 | - pystoi
23 | - pesq
24 | - scipy
25 |
26 | ## GRID
27 | Please refer [here](README_GRID.md) to run the code on the GRID dataset.
28 |
29 | ## LRS2/LRS3
30 | Please refer [here](README_LRS.md) to run the code and models on the LRS2 and LRS3 datasets.
31 |
32 | ## Citation
33 | If you find this work useful in your research, please cite the papers:
34 | ```
35 | @article{kim2021vcagan,
36 | title={Lip to Speech Synthesis with Visual Context Attentional GAN},
37 | author={Kim, Minsu and Hong, Joanna and Ro, Yong Man},
38 | journal={Advances in Neural Information Processing Systems},
39 | volume={34},
40 | year={2021}
41 | }
42 |
43 | @inproceedings{kim2023lip,
44 | title={Lip-to-speech synthesis in the wild with multi-task learning},
45 | author={Kim, Minsu and Hong, Joanna and Ro, Yong Man},
46 | booktitle={ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
47 | pages={1--5},
48 | year={2023},
49 | organization={IEEE}
50 | }
51 | ```
52 |
--------------------------------------------------------------------------------
/preprocess/Extract_audio_LRS.py:
--------------------------------------------------------------------------------
1 | import os, glob
2 | import argparse
3 | from tqdm import tqdm
4 | from joblib import Parallel, delayed
5 |
6 |
7 | def build_file_list(data_path, data_type):
8 | if data_type == 'LRS2':
9 | files = sorted(glob.glob(os.path.join(data_path, 'main', '*', '*.mp4')))
10 | files.extend(glob.glob(os.path.join(data_path, 'pretrain', '*', '*.mp4')))
11 | elif data_type == 'LRS3':
12 | files = sorted(glob.glob(os.path.join(data_path, 'trainval', '*', '*.mp4')))
13 | files.extend(glob.glob(os.path.join(data_path, 'pretrain', '*', '*.mp4')))
14 | files.extend(glob.glob(os.path.join(data_path, 'test', '*', '*.mp4')))
15 | else:
16 | raise NotImplementedError
17 | return [f.replace(data_path + '/', '')[:-4] for f in files]
18 |
19 | def per_file(f, args):
20 | save_path = os.path.join(args.save_path, f)
21 | if os.path.exists(save_path + '.wav'): return
22 | if not os.path.exists(os.path.dirname(save_path)): os.makedirs(os.path.dirname(save_path), exist_ok=True)
23 | vid_name = os.path.join(args.data_path, f + '.mp4')
24 | os.system(
25 | f'ffmpeg -loglevel panic -nostdin -y -i {vid_name} -acodec pcm_s16le -ar 16000 -ac 1 {save_path}.wav')
26 |
27 | def main():
28 | parser = get_parser()
29 | args = parser.parse_args()
30 | file_lists = build_file_list(args.data_path, args.data_type)
31 | Parallel(n_jobs=3)(delayed(per_file)(f, args) for f in tqdm(file_lists))
32 |
33 | def get_parser():
34 | parser = argparse.ArgumentParser(
35 | description="Command-line script for preprocessing."
36 | )
37 | parser.add_argument(
38 | "--data_path", type=str, required=True, help="path for original data"
39 | )
40 | parser.add_argument(
41 | "--save_path", type=str, required=True, help="path for saving"
42 | )
43 | parser.add_argument(
44 | "--data_type", type=str, required=True, help="LRS2 or LRS3"
45 | )
46 | return parser
47 |
48 |
49 | if __name__ == "__main__":
50 | main()
--------------------------------------------------------------------------------
/src/data/audio_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.signal import get_window
4 | import librosa.util as librosa_util
5 |
6 |
7 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
8 | n_fft=800, dtype=np.float32, norm=None):
9 | """
10 | # from librosa 0.6
11 | Compute the sum-square envelope of a window function at a given hop length.
12 | This is used to estimate modulation effects induced by windowing
13 | observations in short-time fourier transforms.
14 | Parameters
15 | ----------
16 | window : string, tuple, number, callable, or list-like
17 | Window specification, as in `get_window`
18 | n_frames : int > 0
19 | The number of analysis frames
20 | hop_length : int > 0
21 | The number of samples to advance between frames
22 | win_length : [optional]
23 | The length of the window function. By default, this matches `n_fft`.
24 | n_fft : int > 0
25 | The length of each analysis frame.
26 | dtype : np.dtype
27 | The data type of the output
28 | Returns
29 | -------
30 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
31 | The sum-squared envelope of the window function
32 | """
33 | if win_length is None:
34 | win_length = n_fft
35 |
36 | n = n_fft + hop_length * (n_frames - 1)
37 | x = np.zeros(n, dtype=dtype)
38 |
39 | # Compute the squared window at the desired length
40 | win_sq = get_window(window, win_length, fftbins=True)
41 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2
42 | win_sq = librosa_util.pad_center(win_sq, n_fft)
43 |
44 | # Fill the envelope
45 | for i in range(n_frames):
46 | sample = i * hop_length
47 | x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
48 | return x
49 |
50 |
51 | def griffin_lim(magnitudes, stft_fn, n_iters=30):
52 | """
53 | PARAMS
54 | ------
55 | magnitudes: spectrogram magnitudes
56 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
57 | """
58 |
59 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
60 | angles = angles.astype(np.float32)
61 | angles = torch.autograd.Variable(torch.from_numpy(angles))
62 | angles = angles.cuda() if magnitudes.is_cuda else angles
63 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
64 |
65 | for i in range(n_iters):
66 | _, angles = stft_fn.transform(signal)
67 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
68 | return signal
69 |
70 |
71 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
72 | """
73 | PARAMS
74 | ------
75 | C: compression factor
76 | """
77 | return torch.log(torch.clamp(x, min=clip_val) * C)
78 |
79 |
80 | def dynamic_range_decompression(x, C=1):
81 | """
82 | PARAMS
83 | ------
84 | C: compression factor used to compress
85 | """
86 | return torch.exp(x) / C
--------------------------------------------------------------------------------
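A quick illustration of the two compression helpers above: `dynamic_range_compression` is a clamped log and `dynamic_range_decompression` is its exact inverse down to the 1e-5 floor. A minimal sketch, assuming it is run from the repository root:

```python
import torch
from src.data.audio_processing import dynamic_range_compression, dynamic_range_decompression

mag = torch.rand(80, 100) * 5                      # dummy mel magnitudes
log_mel = dynamic_range_compression(mag)           # log(clamp(mag, min=1e-5))
recovered = dynamic_range_decompression(log_mel)   # exp(log_mel)
print(torch.allclose(recovered, mag.clamp(min=1e-5), atol=1e-6))  # True
```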
/ASR_model/GRID/src/data/audio_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.signal import get_window
4 | import librosa.util as librosa_util
5 |
6 |
7 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
8 | n_fft=800, dtype=np.float32, norm=None):
9 | """
10 | # from librosa 0.6
11 | Compute the sum-square envelope of a window function at a given hop length.
12 | This is used to estimate modulation effects induced by windowing
13 | observations in short-time fourier transforms.
14 | Parameters
15 | ----------
16 | window : string, tuple, number, callable, or list-like
17 | Window specification, as in `get_window`
18 | n_frames : int > 0
19 | The number of analysis frames
20 | hop_length : int > 0
21 | The number of samples to advance between frames
22 | win_length : [optional]
23 | The length of the window function. By default, this matches `n_fft`.
24 | n_fft : int > 0
25 | The length of each analysis frame.
26 | dtype : np.dtype
27 | The data type of the output
28 | Returns
29 | -------
30 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
31 | The sum-squared envelope of the window function
32 | """
33 | if win_length is None:
34 | win_length = n_fft
35 |
36 | n = n_fft + hop_length * (n_frames - 1)
37 | x = np.zeros(n, dtype=dtype)
38 |
39 | # Compute the squared window at the desired length
40 | win_sq = get_window(window, win_length, fftbins=True)
41 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2
42 | win_sq = librosa_util.pad_center(win_sq, n_fft)
43 |
44 | # Fill the envelope
45 | for i in range(n_frames):
46 | sample = i * hop_length
47 | x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
48 | return x
49 |
50 |
51 | def griffin_lim(magnitudes, stft_fn, n_iters=30):
52 | """
53 | PARAMS
54 | ------
55 | magnitudes: spectrogram magnitudes
56 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
57 | """
58 |
59 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
60 | angles = angles.astype(np.float32)
61 | angles = torch.autograd.Variable(torch.from_numpy(angles))
62 | angles = angles.cuda() if magnitudes.is_cuda else angles
63 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
64 |
65 | for i in range(n_iters):
66 | _, angles = stft_fn.transform(signal)
67 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
68 | return signal
69 |
70 |
71 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
72 | """
73 | PARAMS
74 | ------
75 | C: compression factor
76 | """
77 | return torch.log(torch.clamp(x, min=clip_val) * C)
78 |
79 |
80 | def dynamic_range_decompression(x, C=1):
81 | """
82 | PARAMS
83 | ------
84 | C: compression factor used to compress
85 | """
86 | return torch.exp(x) / C
--------------------------------------------------------------------------------
/ASR_model/LRW/src/data/audio_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.signal import get_window
4 | import librosa.util as librosa_util
5 |
6 |
7 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
8 | n_fft=800, dtype=np.float32, norm=None):
9 | """
10 | # from librosa 0.6
11 | Compute the sum-square envelope of a window function at a given hop length.
12 | This is used to estimate modulation effects induced by windowing
13 | observations in short-time fourier transforms.
14 | Parameters
15 | ----------
16 | window : string, tuple, number, callable, or list-like
17 | Window specification, as in `get_window`
18 | n_frames : int > 0
19 | The number of analysis frames
20 | hop_length : int > 0
21 | The number of samples to advance between frames
22 | win_length : [optional]
23 | The length of the window function. By default, this matches `n_fft`.
24 | n_fft : int > 0
25 | The length of each analysis frame.
26 | dtype : np.dtype
27 | The data type of the output
28 | Returns
29 | -------
30 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
31 | The sum-squared envelope of the window function
32 | """
33 | if win_length is None:
34 | win_length = n_fft
35 |
36 | n = n_fft + hop_length * (n_frames - 1)
37 | x = np.zeros(n, dtype=dtype)
38 |
39 | # Compute the squared window at the desired length
40 | win_sq = get_window(window, win_length, fftbins=True)
41 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2
42 | win_sq = librosa_util.pad_center(win_sq, n_fft)
43 |
44 | # Fill the envelope
45 | for i in range(n_frames):
46 | sample = i * hop_length
47 | x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
48 | return x
49 |
50 |
51 | def griffin_lim(magnitudes, stft_fn, n_iters=30):
52 | """
53 | PARAMS
54 | ------
55 | magnitudes: spectrogram magnitudes
56 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
57 | """
58 |
59 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
60 | angles = angles.astype(np.float32)
61 | angles = torch.autograd.Variable(torch.from_numpy(angles))
62 | angles = angles.cuda() if magnitudes.is_cuda else angles
63 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
64 |
65 | for i in range(n_iters):
66 | _, angles = stft_fn.transform(signal)
67 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
68 | return signal
69 |
70 |
71 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
72 | """
73 | PARAMS
74 | ------
75 | C: compression factor
76 | """
77 | return torch.log(torch.clamp(x, min=clip_val) * C)
78 |
79 |
80 | def dynamic_range_decompression(x, C=1):
81 | """
82 | PARAMS
83 | ------
84 | C: compression factor used to compress
85 | """
86 | return torch.exp(x) / C
--------------------------------------------------------------------------------
/src/models/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 |
4 |
5 | def conv3x3(in_planes, out_planes, stride=1):
6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
7 | padding=1, bias=False)
8 |
9 |
10 | def downsample_basic_block(inplanes, outplanes, stride):
11 | return nn.Sequential(
12 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False),
13 | nn.BatchNorm2d(outplanes),
14 | )
15 |
16 |
17 | def downsample_basic_block_v2(inplanes, outplanes, stride):
18 | return nn.Sequential(
19 | nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False),
20 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False),
21 | nn.BatchNorm2d(outplanes),
22 | )
23 |
24 |
25 | class BasicBlock(nn.Module):
26 | expansion = 1
27 |
28 | def __init__(self, inplanes, planes, stride=1, downsample=None, relu_type='relu'):
29 | super(BasicBlock, self).__init__()
30 |
31 | assert relu_type in ['relu', 'prelu']
32 |
33 | self.conv1 = conv3x3(inplanes, planes, stride)
34 | self.bn1 = nn.BatchNorm2d(planes)
35 |
36 | # type of ReLU is an input option
37 | if relu_type == 'relu':
38 | self.relu1 = nn.ReLU(inplace=True)
39 | self.relu2 = nn.ReLU(inplace=True)
40 | elif relu_type == 'prelu':
41 | self.relu1 = nn.PReLU(num_parameters=planes)
42 | self.relu2 = nn.PReLU(num_parameters=planes)
43 | else:
44 | raise Exception('relu type not implemented')
45 | # --------
46 |
47 | self.conv2 = conv3x3(planes, planes)
48 | self.bn2 = nn.BatchNorm2d(planes)
49 |
50 | self.downsample = downsample
51 | self.stride = stride
52 |
53 | def forward(self, x):
54 | residual = x
55 | out = self.conv1(x)
56 | out = self.bn1(out)
57 | out = self.relu1(out)
58 | out = self.conv2(out)
59 | out = self.bn2(out)
60 | if self.downsample is not None:
61 | residual = self.downsample(x)
62 |
63 | out += residual
64 | out = self.relu2(out)
65 |
66 | return out
67 |
68 |
69 | class ResNet(nn.Module):
70 |
71 | def __init__(self, block, layers, num_classes=1000, relu_type='relu', gamma_zero=False, avg_pool_downsample=False):
72 | self.inplanes = 64
73 | self.relu_type = relu_type
74 | self.gamma_zero = gamma_zero
75 | self.downsample_block = downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block
76 |
77 | super(ResNet, self).__init__()
78 | self.layer1 = self._make_layer(block, 64, layers[0])
79 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
80 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
81 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
82 | self.avgpool = nn.AvgPool2d(4)
83 |
84 | # default init
85 | for m in self.modules():
86 | if isinstance(m, nn.Conv2d):
87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
88 | m.weight.data.normal_(0, math.sqrt(2. / n))
89 | elif isinstance(m, nn.BatchNorm2d):
90 | m.weight.data.fill_(1)
91 | m.bias.data.zero_()
92 | # nn.init.ones_(m.weight)
93 | # nn.init.zeros_(m.bias)
94 |
95 | if self.gamma_zero:
96 | for m in self.modules():
97 | if isinstance(m, BasicBlock):
98 | m.bn2.weight.data.zero_()
99 |
100 | def _make_layer(self, block, planes, blocks, stride=1):
101 |
102 | downsample = None
103 | if stride != 1 or self.inplanes != planes * block.expansion:
104 | downsample = self.downsample_block(inplanes=self.inplanes,
105 | outplanes=planes * block.expansion,
106 | stride=stride)
107 |
108 | layers = []
109 | layers.append(block(self.inplanes, planes, stride, downsample, relu_type=self.relu_type))
110 | self.inplanes = planes * block.expansion
111 | for i in range(1, blocks):
112 | layers.append(block(self.inplanes, planes, relu_type=self.relu_type))
113 |
114 | return nn.Sequential(*layers)
115 |
116 | def forward(self, x):
117 | x = self.layer1(x)
118 | x = self.layer2(x)
119 | x = self.layer3(x)
120 | x = self.layer4(x)
121 | x = self.avgpool(x)
122 | x = x.view(x.size(0), -1)
123 | return x
124 |
--------------------------------------------------------------------------------
/ASR_model/GRID/src/models/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 |
4 |
5 | def conv3x3(in_planes, out_planes, stride=1):
6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
7 | padding=1, bias=False)
8 |
9 | def downsample_basic_block(inplanes, outplanes, stride):
10 | return nn.Sequential(
11 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False),
12 | nn.BatchNorm2d(outplanes),
13 | )
14 |
15 |
16 | def downsample_basic_block_v2(inplanes, outplanes, stride):
17 | return nn.Sequential(
18 | nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False),
19 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False),
20 | nn.BatchNorm2d(outplanes),
21 | )
22 |
23 |
24 | class BasicBlock(nn.Module):
25 | expansion = 1
26 |
27 | def __init__(self, inplanes, planes, stride=1, downsample=None, relu_type='relu'):
28 | super(BasicBlock, self).__init__()
29 |
30 | assert relu_type in ['relu', 'prelu']
31 |
32 | self.conv1 = conv3x3(inplanes, planes, stride)
33 | self.bn1 = nn.BatchNorm2d(planes)
34 |
35 | # type of ReLU is an input option
36 | if relu_type == 'relu':
37 | self.relu1 = nn.ReLU(inplace=True)
38 | self.relu2 = nn.ReLU(inplace=True)
39 | elif relu_type == 'prelu':
40 | self.relu1 = nn.PReLU(num_parameters=planes)
41 | self.relu2 = nn.PReLU(num_parameters=planes)
42 | else:
43 | raise Exception('relu type not implemented')
44 | # --------
45 |
46 | self.conv2 = conv3x3(planes, planes)
47 | self.bn2 = nn.BatchNorm2d(planes)
48 |
49 | self.downsample = downsample
50 | self.stride = stride
51 |
52 | def forward(self, x):
53 | residual = x
54 | out = self.conv1(x)
55 | out = self.bn1(out)
56 | out = self.relu1(out)
57 | out = self.conv2(out)
58 | out = self.bn2(out)
59 | if self.downsample is not None:
60 | residual = self.downsample(x)
61 |
62 | out += residual
63 | out = self.relu2(out)
64 |
65 | return out
66 |
67 |
68 | class ResNet(nn.Module):
69 |
70 | def __init__(self, block, layers, num_classes=1000, relu_type='relu', gamma_zero=False, avg_pool_downsample=False):
71 | self.inplanes = 64
72 | self.relu_type = relu_type
73 | self.gamma_zero = gamma_zero
74 | self.downsample_block = downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block
75 |
76 | super(ResNet, self).__init__()
77 | self.layer1 = self._make_layer(block, 64, layers[0])
78 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
79 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
80 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
81 | self.avgpool = nn.AvgPool2d(3)
82 |
83 | # default init
84 | for m in self.modules():
85 | if isinstance(m, nn.Conv2d):
86 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
87 | m.weight.data.normal_(0, math.sqrt(2. / n))
88 | elif isinstance(m, nn.BatchNorm2d):
89 | m.weight.data.fill_(1)
90 | m.bias.data.zero_()
91 | # nn.init.ones_(m.weight)
92 | # nn.init.zeros_(m.bias)
93 |
94 | if self.gamma_zero:
95 | for m in self.modules():
96 | if isinstance(m, BasicBlock):
97 | m.bn2.weight.data.zero_()
98 |
99 | def _make_layer(self, block, planes, blocks, stride=1):
100 |
101 | downsample = None
102 | if stride != 1 or self.inplanes != planes * block.expansion:
103 | downsample = self.downsample_block(inplanes=self.inplanes,
104 | outplanes=planes * block.expansion,
105 | stride=stride)
106 |
107 | layers = []
108 | layers.append(block(self.inplanes, planes, stride, downsample, relu_type=self.relu_type))
109 | self.inplanes = planes * block.expansion
110 | for i in range(1, blocks):
111 | layers.append(block(self.inplanes, planes, relu_type=self.relu_type))
112 |
113 | return nn.Sequential(*layers)
114 |
115 | def forward(self, x):
116 | x = self.layer1(x)
117 | x = self.layer2(x)
118 | x = self.layer3(x)
119 | x = self.layer4(x)
120 | x = self.avgpool(x)
121 | x = x.view(x.size(0), -1)
122 | return x
123 |
--------------------------------------------------------------------------------
/ASR_model/LRW/src/models/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 |
4 |
5 | def conv3x3(in_planes, out_planes, stride=1):
6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
7 | padding=1, bias=False)
8 |
9 |
10 | def downsample_basic_block(inplanes, outplanes, stride):
11 | return nn.Sequential(
12 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False),
13 | nn.BatchNorm2d(outplanes),
14 | )
15 |
16 |
17 | def downsample_basic_block_v2(inplanes, outplanes, stride):
18 | return nn.Sequential(
19 | nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False),
20 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False),
21 | nn.BatchNorm2d(outplanes),
22 | )
23 |
24 |
25 | class BasicBlock(nn.Module):
26 | expansion = 1
27 |
28 | def __init__(self, inplanes, planes, stride=1, downsample=None, relu_type='relu'):
29 | super(BasicBlock, self).__init__()
30 |
31 | assert relu_type in ['relu', 'prelu']
32 |
33 | self.conv1 = conv3x3(inplanes, planes, stride)
34 | self.bn1 = nn.BatchNorm2d(planes)
35 |
36 | # type of ReLU is an input option
37 | if relu_type == 'relu':
38 | self.relu1 = nn.ReLU(inplace=True)
39 | self.relu2 = nn.ReLU(inplace=True)
40 | elif relu_type == 'prelu':
41 | self.relu1 = nn.PReLU(num_parameters=planes)
42 | self.relu2 = nn.PReLU(num_parameters=planes)
43 | else:
44 | raise Exception('relu type not implemented')
45 | # --------
46 |
47 | self.conv2 = conv3x3(planes, planes)
48 | self.bn2 = nn.BatchNorm2d(planes)
49 |
50 | self.downsample = downsample
51 | self.stride = stride
52 |
53 | def forward(self, x):
54 | residual = x
55 | out = self.conv1(x)
56 | out = self.bn1(out)
57 | out = self.relu1(out)
58 | out = self.conv2(out)
59 | out = self.bn2(out)
60 | if self.downsample is not None:
61 | residual = self.downsample(x)
62 |
63 | out += residual
64 | out = self.relu2(out)
65 |
66 | return out
67 |
68 |
69 | class ResNet(nn.Module):
70 |
71 | def __init__(self, block, layers, num_classes=1000, relu_type='relu', gamma_zero=False, avg_pool_downsample=False):
72 | self.inplanes = 64
73 | self.relu_type = relu_type
74 | self.gamma_zero = gamma_zero
75 | self.downsample_block = downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block
76 |
77 | super(ResNet, self).__init__()
78 | self.layer1 = self._make_layer(block, 64, layers[0])
79 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
80 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
81 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
82 | self.avgpool = nn.AvgPool2d(3)
83 |
84 | # default init
85 | for m in self.modules():
86 | if isinstance(m, nn.Conv2d):
87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
88 | m.weight.data.normal_(0, math.sqrt(2. / n))
89 | elif isinstance(m, nn.BatchNorm2d):
90 | m.weight.data.fill_(1)
91 | m.bias.data.zero_()
92 | # nn.init.ones_(m.weight)
93 | # nn.init.zeros_(m.bias)
94 |
95 | if self.gamma_zero:
96 | for m in self.modules():
97 | if isinstance(m, BasicBlock):
98 | m.bn2.weight.data.zero_()
99 |
100 | def _make_layer(self, block, planes, blocks, stride=1):
101 |
102 | downsample = None
103 | if stride != 1 or self.inplanes != planes * block.expansion:
104 | downsample = self.downsample_block(inplanes=self.inplanes,
105 | outplanes=planes * block.expansion,
106 | stride=stride)
107 |
108 | layers = []
109 | layers.append(block(self.inplanes, planes, stride, downsample, relu_type=self.relu_type))
110 | self.inplanes = planes * block.expansion
111 | for i in range(1, blocks):
112 | layers.append(block(self.inplanes, planes, relu_type=self.relu_type))
113 |
114 | return nn.Sequential(*layers)
115 |
116 | def forward(self, x):
117 | x = self.layer1(x)
118 | x = self.layer2(x)
119 | x = self.layer3(x)
120 | x = self.layer4(x)
121 | x = self.avgpool(x)
122 | x = x.view(x.size(0), -1)
123 | return x
124 |
--------------------------------------------------------------------------------
/data/val_4.txt:
--------------------------------------------------------------------------------
1 | s1/video/srih1s.mp4
2 | s1/video/srih2p.mp4
3 | s1/video/lwbz6p.mp4
4 | s1/video/bwwh5a.mp4
5 | s1/video/lbbq8n.mp4
6 | s1/video/lbbk5s.mp4
7 | s1/video/lwbl6n.mp4
8 | s1/video/srin5s.mp4
9 | s1/video/lwie6p.mp4
10 | s1/video/bwwn8p.mp4
11 | s1/video/pbib9a.mp4
12 | s1/video/sgai6n.mp4
13 | s1/video/bgaa9a.mp4
14 | s1/video/praj2p.mp4
15 | s1/video/bgig9a.mp4
16 | s1/video/bbwg1s.mp4
17 | s1/video/lwaz1s.mp4
18 | s1/video/sgbj2p.mp4
19 | s1/video/bbas1s.mp4
20 | s1/video/pwbd6n.mp4
21 | s1/video/lrwr8n.mp4
22 | s1/video/lwws5s.mp4
23 | s1/video/swao5s.mp4
24 | s1/video/pwix2p.mp4
25 | s1/video/priv4n.mp4
26 | s1/video/bgwu8p.mp4
27 | s1/video/sgav4n.mp4
28 | s1/video/pwaj7s.mp4
29 | s1/video/lgwm8p.mp4
30 | s1/video/pwwk7a.mp4
31 | s1/video/pgwl3a.mp4
32 | s1/video/pwbd7s.mp4
33 | s1/video/pgby6p.mp4
34 | s1/video/lbwr4p.mp4
35 | s1/video/sbwu6p.mp4
36 | s1/video/swbc2p.mp4
37 | s1/video/lwws7a.mp4
38 | s1/video/sgai9a.mp4
39 | s1/video/lwal2n.mp4
40 | s1/video/brwg6n.mp4
41 | s1/video/bbal7s.mp4
42 | s1/video/srbo5a.mp4
43 | s1/video/prip2p.mp4
44 | s1/video/pwwe1s.mp4
45 | s1/video/bbiz1s.mp4
46 | s1/video/srwi2n.mp4
47 | s1/video/bgahzn.mp4
48 | s1/video/lbwk9s.mp4
49 | s1/video/pgiq4p.mp4
50 | s2/video/lbakzs.mp4
51 | s2/video/pwwe1p.mp4
52 | s2/video/brwa1n.mp4
53 | s2/video/lriyzs.mp4
54 | s2/video/bwim4s.mp4
55 | s2/video/pgby6a.mp4
56 | s2/video/lrwr7n.mp4
57 | s2/video/lbwlza.mp4
58 | s2/video/bbir8a.mp4
59 | s2/video/lwbs2a.mp4
60 | s2/video/lgwtzs.mp4
61 | s2/video/pwaq2a.mp4
62 | s2/video/lbad6s.mp4
63 | s2/video/lgir9p.mp4
64 | s2/video/lbij7p.mp4
65 | s2/video/pgayzs.mp4
66 | s2/video/brimza.mp4
67 | s2/video/lrwz1n.mp4
68 | s2/video/lwbs1p.mp4
69 | s2/video/srah3n.mp4
70 | s2/video/sbbt9n.mp4
71 | s2/video/praj2a.mp4
72 | s2/video/lbbk3n.mp4
73 | s2/video/swao5p.mp4
74 | s2/video/bwbt8a.mp4
75 | s2/video/swwi8s.mp4
76 | s2/video/lgil3n.mp4
77 | s2/video/bwbn4a.mp4
78 | s2/video/lbaq5p.mp4
79 | s2/video/bbws7n.mp4
80 | s2/video/bbir5n.mp4
81 | s2/video/swbv2s.mp4
82 | s2/video/srwb9p.mp4
83 | s2/video/sgai7p.mp4
84 | s2/video/sgbjzs.mp4
85 | s2/video/bril7n.mp4
86 | s2/video/srwb7n.mp4
87 | s2/video/brbg4a.mp4
88 | s2/video/bwbn1n.mp4
89 | s2/video/sgbp4s.mp4
90 | s2/video/bgan4s.mp4
91 | s2/video/lbbk5p.mp4
92 | s2/video/pbwj3p.mp4
93 | s2/video/srah6a.mp4
94 | s2/video/pgiq4a.mp4
95 | s2/video/prwd4s.mp4
96 | s2/video/lgas4a.mp4
97 | s2/video/lrak5n.mp4
98 | s2/video/lwws4s.mp4
99 | s2/video/lbby3p.mp4
100 | s4/video/brwazs.mp4
101 | s4/video/lraq7n.mp4
102 | s4/video/sgai4s.mp4
103 | s4/video/srwb5n.mp4
104 | s4/video/lgaz3n.mp4
105 | s4/video/srau1p.mp4
106 | s4/video/pbbv5p.mp4
107 | s4/video/sgbc6a.mp4
108 | s4/video/prap2s.mp4
109 | s4/video/lgir5n.mp4
110 | s4/video/bwwn6a.mp4
111 | s4/video/sgwj1n.mp4
112 | s4/video/lgbl9n.mp4
113 | s4/video/lbip9p.mp4
114 | s4/video/sbig3p.mp4
115 | s4/video/pgix4s.mp4
116 | s4/video/sbbuza.mp4
117 | s4/video/pbbi5n.mp4
118 | s4/video/lbwk6s.mp4
119 | s4/video/pbwp6a.mp4
120 | s4/video/pwbj7n.mp4
121 | s4/video/lwwl9p.mp4
122 | s4/video/bwwn5p.mp4
123 | s4/video/brwt3p.mp4
124 | s4/video/sgwx1p.mp4
125 | s4/video/bras4s.mp4
126 | s4/video/swao4a.mp4
127 | s4/video/srwu8s.mp4
128 | s4/video/pgbq7n.mp4
129 | s4/video/sgih9n.mp4
130 | s4/video/pwwx9n.mp4
131 | s4/video/lwar3n.mp4
132 | s4/video/bgan2s.mp4
133 | s4/video/pbib5p.mp4
134 | s4/video/pwbq4a.mp4
135 | s4/video/lbaq1n.mp4
136 | s4/video/sbim6s.mp4
137 | s4/video/brif4a.mp4
138 | s4/video/pbwxza.mp4
139 | s4/video/lwbr9p.mp4
140 | s4/video/priv3p.mp4
141 | s4/video/prwd3p.mp4
142 | s4/video/lbwe4a.mp4
143 | s4/video/lrae2a.mp4
144 | s4/video/bwam7p.mp4
145 | s4/video/sbat3n.mp4
146 | s4/video/pgwk8s.mp4
147 | s4/video/swbv2a.mp4
148 | s4/video/lbij5p.mp4
149 | s4/video/lbaj9p.mp4
150 | s29/video/bric7s.mp4
151 | s29/video/swiz1s.mp4
152 | s29/video/pbif4p.mp4
153 | s29/video/sgbg3s.mp4
154 | s29/video/pwin1a.mp4
155 | s29/video/lwwj5a.mp4
156 | s29/video/lrwc5a.mp4
157 | s29/video/bbbp6n.mp4
158 | s29/video/srar4n.mp4
159 | s29/video/pratzn.mp4
160 | s29/video/lgax1a.mp4
161 | s29/video/lbauzn.mp4
162 | s29/video/sbwl4p.mp4
163 | s29/video/srwz7s.mp4
164 | s29/video/lgwx6n.mp4
165 | s29/video/sgbt1s.mp4
166 | s29/video/bwwk9s.mp4
167 | s29/video/pwbb1a.mp4
168 | s29/video/lgipzn.mp4
169 | s29/video/pgbu9a.mp4
170 | s29/video/lgic4p.mp4
171 | s29/video/sbix8n.mp4
172 | s29/video/sgwt4n.mp4
173 | s29/video/lbau3a.mp4
174 | s29/video/sbikzn.mp4
175 | s29/video/lrao2n.mp4
176 | s29/video/lgbq1a.mp4
177 | s29/video/bgwl6p.mp4
178 | s29/video/pgwi4p.mp4
179 | s29/video/pwwozn.mp4
180 | s29/video/lwav4p.mp4
181 | s29/video/lrin9s.mp4
182 | s29/video/pwag9s.mp4
183 | s29/video/srbl4n.mp4
184 | s29/video/sbiq7a.mp4
185 | s29/video/lwwp7s.mp4
186 | s29/video/bgik2n.mp4
187 | s29/video/swir8p.mp4
188 | s29/video/pbaz8n.mp4
189 | s29/video/srwm1a.mp4
190 | s29/video/lwbc4n.mp4
191 | s29/video/lrwi9a.mp4
192 | s29/video/sray9s.mp4
193 | s29/video/srae6n.mp4
194 | s29/video/srazzp.mp4
195 | s29/video/bbac6p.mp4
196 | s29/video/pgwo6n.mp4
197 | s29/video/lrbv3a.mp4
198 | s29/video/sbwzzn.mp4
199 | s29/video/bbbj5a.mp4
200 |
--------------------------------------------------------------------------------
/data/test_4.txt:
--------------------------------------------------------------------------------
1 | s1/video/sbag8n.mp4
2 | s1/video/pbwj2n.mp4
3 | s1/video/bwat5a.mp4
4 | s1/video/pwbyzp.mp4
5 | s1/video/lgiz2n.mp4
6 | s1/video/sbwu4n.mp4
7 | s1/video/sbwo3a.mp4
8 | s1/video/brbm6n.mp4
9 | s1/video/swwc5s.mp4
10 | s1/video/bwaa3a.mp4
11 | s1/video/bwig2p.mp4
12 | s1/video/bwaa2p.mp4
13 | s1/video/sgbc6n.mp4
14 | s1/video/briz7s.mp4
15 | s1/video/lrie1a.mp4
16 | s1/video/pwip9a.mp4
17 | s1/video/swau8n.mp4
18 | s1/video/lrbr4n.mp4
19 | s1/video/lgwm9a.mp4
20 | s1/video/lwbl7s.mp4
21 | s1/video/pwidzp.mp4
22 | s1/video/bril9s.mp4
23 | s1/video/swiu5s.mp4
24 | s1/video/swwp2n.mp4
25 | s1/video/lrwszp.mp4
26 | s1/video/bwat4p.mp4
27 | s1/video/lwazzn.mp4
28 | s1/video/pwwq8n.mp4
29 | s1/video/sgav7a.mp4
30 | s1/video/brwg7s.mp4
31 | s1/video/bgbb3a.mp4
32 | s1/video/sgbc9a.mp4
33 | s1/video/swbc3a.mp4
34 | s1/video/sgbx1a.mp4
35 | s1/video/srit9s.mp4
36 | s1/video/bbwm4n.mp4
37 | s1/video/bbaf2n.mp4
38 | s1/video/srbu6n.mp4
39 | s1/video/bwwn9a.mp4
40 | s1/video/srwv1s.mp4
41 | s1/video/prii8p.mp4
42 | s1/video/lgbs6n.mp4
43 | s1/video/lwal3s.mp4
44 | s1/video/bwit1a.mp4
45 | s1/video/pbio6p.mp4
46 | s1/video/pwwq9s.mp4
47 | s1/video/lrbe6n.mp4
48 | s1/video/pwaq3a.mp4
49 | s1/video/lbad6n.mp4
50 | s1/video/sgii3s.mp4
51 | s2/video/pwwezs.mp4
52 | s2/video/lwar7p.mp4
53 | s2/video/pbao7n.mp4
54 | s2/video/lwae9p.mp4
55 | s2/video/bbbz9p.mp4
56 | s2/video/pgby4s.mp4
57 | s2/video/lgws9n.mp4
58 | s2/video/pbwp7p.mp4
59 | s2/video/sgbi9n.mp4
60 | s2/video/lwbszs.mp4
61 | s2/video/priv6a.mp4
62 | s2/video/sbaa5p.mp4
63 | s2/video/prwd5p.mp4
64 | s2/video/lgal7n.mp4
65 | s2/video/swio1p.mp4
66 | s2/video/lbaq3n.mp4
67 | s2/video/pwbq6a.mp4
68 | s2/video/prbc9n.mp4
69 | s2/video/brwt6a.mp4
70 | s2/video/bgaa5n.mp4
71 | s2/video/brwa4a.mp4
72 | s2/video/lgwaza.mp4
73 | s2/video/swah9n.mp4
74 | s2/video/bgag9n.mp4
75 | s2/video/pbih9n.mp4
76 | s2/video/lbby1n.mp4
77 | s2/video/brif6a.mp4
78 | s2/video/pwip7p.mp4
79 | s2/video/lbbk6a.mp4
80 | s2/video/bbal5n.mp4
81 | s2/video/lbij6s.mp4
82 | s2/video/bbie7n.mp4
83 | s2/video/brwg8a.mp4
84 | s2/video/lgaz8a.mp4
85 | s2/video/prip1p.mp4
86 | s2/video/pwwy3p.mp4
87 | s2/video/swao6a.mp4
88 | s2/video/lgal9p.mp4
89 | s2/video/lrwz4a.mp4
90 | s2/video/prbx1n.mp4
91 | s2/video/sban4a.mp4
92 | s2/video/pgad7n.mp4
93 | s2/video/lwwf6s.mp4
94 | s2/video/srwi4a.mp4
95 | s2/video/bbar9n.mp4
96 | s2/video/prwq2s.mp4
97 | s2/video/prio9n.mp4
98 | s2/video/srih1p.mp4
99 | s2/video/priv5p.mp4
100 | s2/video/srwv1p.mp4
101 | s4/video/briszs.mp4
102 | s4/video/bbwm1n.mp4
103 | s4/video/sgwxzs.mp4
104 | s4/video/srig7n.mp4
105 | s4/video/lriyza.mp4
106 | s4/video/braf8a.mp4
107 | s4/video/pwbx8a.mp4
108 | s4/video/pbai1n.mp4
109 | s4/video/bris2a.mp4
110 | s4/video/lbwe3p.mp4
111 | s4/video/bgia1p.mp4
112 | s4/video/bwbm9n.mp4
113 | s4/video/pwij1p.mp4
114 | s4/video/brwt4a.mp4
115 | s4/video/pbbi6s.mp4
116 | s4/video/lrby7p.mp4
117 | s4/video/pwbx7p.mp4
118 | s4/video/sgio6a.mp4
119 | s4/video/sgbp1n.mp4
120 | s4/video/swih4s.mp4
121 | s4/video/bbbf5p.mp4
122 | s4/video/lbwe2s.mp4
123 | s4/video/prap3p.mp4
124 | s4/video/bbir6a.mp4
125 | s4/video/lbidzs.mp4
126 | s4/video/prii6a.mp4
127 | s4/video/bgim9p.mp4
128 | s4/video/brwm8s.mp4
129 | s4/video/pbwp5p.mp4
130 | s4/video/pgaq3n.mp4
131 | s4/video/lwbe9n.mp4
132 | s4/video/briz5p.mp4
133 | s4/video/pgij5n.mp4
134 | s4/video/swbi4a.mp4
135 | s4/video/pwwq5n.mp4
136 | s4/video/lway9p.mp4
137 | s4/video/swbo8a.mp4
138 | s4/video/bwbg5n.mp4
139 | s4/video/bgit2s.mp4
140 | s4/video/pwijzs.mp4
141 | s4/video/lrbr3p.mp4
142 | s4/video/priv4a.mp4
143 | s4/video/pric2a.mp4
144 | s4/video/lbwk7p.mp4
145 | s4/video/pwiv9p.mp4
146 | s4/video/lbwy3n.mp4
147 | s4/video/sgwp6s.mp4
148 | s4/video/lwbz1n.mp4
149 | s4/video/srbb2s.mp4
150 | s4/video/prai8s.mp4
151 | s29/video/sgil8n.mp4
152 | s29/video/lbbozn.mp4
153 | s29/video/lbwi3a.mp4
154 | s29/video/sgam2n.mp4
155 | s29/video/lrwc3s.mp4
156 | s29/video/sbwr7s.mp4
157 | s29/video/prwhzn.mp4
158 | s29/video/lriu4p.mp4
159 | s29/video/lgic5a.mp4
160 | s29/video/bbwj7s.mp4
161 | s29/video/lrwv4n.mp4
162 | s29/video/bbbx3a.mp4
163 | s29/video/lrbi2n.mp4
164 | s29/video/brbkzp.mp4
165 | s29/video/swil4p.mp4
166 | s29/video/prwh2p.mp4
167 | s29/video/brap8n.mp4
168 | s29/video/bbipzp.mp4
169 | s29/video/pwig5s.mp4
170 | s29/video/pbwm8n.mp4
171 | s29/video/brbq4p.mp4
172 | s29/video/bgbl1s.mp4
173 | s29/video/lwio6p.mp4
174 | s29/video/bbap2n.mp4
175 | s29/video/bbai8n.mp4
176 | s29/video/lwbv8p.mp4
177 | s29/video/sgam4p.mp4
178 | s29/video/sbwe8n.mp4
179 | s29/video/briv9s.mp4
180 | s29/video/bwij6n.mp4
181 | s29/video/lwai5s.mp4
182 | s29/video/pwim9s.mp4
183 | s29/video/lbwo6p.mp4
184 | s29/video/srws5a.mp4
185 | s29/video/brap9s.mp4
186 | s29/video/pbbtzp.mp4
187 | s29/video/pgih3a.mp4
188 | s29/video/bgws1a.mp4
189 | s29/video/sbbe6p.mp4
190 | s29/video/lgwd6p.mp4
191 | s29/video/bgblzn.mp4
192 | s29/video/bwiqzn.mp4
193 | s29/video/pbwt2n.mp4
194 | s29/video/brbx8p.mp4
195 | s29/video/sriy5s.mp4
196 | s29/video/bgae5a.mp4
197 | s29/video/sgif6p.mp4
198 | s29/video/prim5a.mp4
199 | s29/video/lgac9a.mp4
200 | s29/video/srwz6n.mp4
201 |
--------------------------------------------------------------------------------
/ASR_model/LRW/src/data/vid_aud_lrw_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torchaudio
8 | import torchvision
9 | from torchvision import transforms
10 | from torch.utils.data import DataLoader, Dataset
11 | from librosa.filters import mel as librosa_mel_fn
12 | from src.data.stft import STFT
13 | from src.data.audio_processing import dynamic_range_compression, dynamic_range_decompression
14 | import glob, math
15 | from scipy import signal
16 | import librosa
17 |
18 | log1e5 = math.log(1e-5)
19 |
20 | class MultiDataset(Dataset):
21 | def __init__(self, lrw, mode, max_v_timesteps=155, augmentations=False, num_mel_bins=80, wav=False):
22 | self.max_v_timesteps = max_v_timesteps
23 | self.augmentations = augmentations if mode == 'train' else False
24 | self.num_mel_bins = num_mel_bins
25 | self.skip_long_samples = True
26 | self.wav = wav
27 | self.file_paths, self.word_list = self.build_file_list(lrw)
28 | self.word2int = {word: index for index, word in self.word_list.items()}
29 | self.stft = TacotronSTFT(filter_length=640, hop_length=160, win_length=640, n_mel_channels=80, sampling_rate=16000, mel_fmin=55., mel_fmax=7600.)
30 |
31 | def build_file_list(self, lrw):
32 | word = {}
33 | # data_dir: spec_mel (or wav) / class / test (train, val) / class_#.npz(or .wav)
34 | if self.wav:
35 | files = sorted(glob.glob(os.path.join(lrw, '*', '*', '*.wav')))
36 | else:
37 | files = sorted(glob.glob(os.path.join(lrw, '*', '*', '*.npz')))
38 |
39 | with open('./data/class.txt', 'r') as f:
40 | lines = f.readlines()
41 | for i, l in enumerate(lines):
42 | word[i] = l.strip().upper()
43 |
44 | return files, word
45 |
46 | def __len__(self):
47 | return len(self.file_paths)
48 |
49 | def __getitem__(self, idx):
50 | file_path = self.file_paths[idx]
51 | content = os.path.split(file_path)[-1].split('_')[0].upper()
52 | target = self.word2int[content]
53 |
54 | if self.wav:
55 | aud, sr = torchaudio.load(file_path)
56 | if round(sr) != 16000:
57 | aud = torch.tensor(librosa.resample(aud.squeeze(0).numpy(), sr, 16000)).unsqueeze(0)
58 |
59 | aud = aud / torch.abs(aud).max() * 0.9
60 | aud = torch.FloatTensor(self.preemphasize(aud.squeeze(0))).unsqueeze(0)
61 | aud = torch.clamp(aud, min=-1, max=1)
62 |
63 | spec = self.stft.mel_spectrogram(aud)
64 |
65 | else:
66 | data = np.load(file_path)
67 | spec = data['mel']
68 | data.close()
69 |
70 | spec = torch.FloatTensor(self.denormalize(spec))
71 |
72 | spec = spec[:, :, :self.max_v_timesteps * 4]
73 | num_a_frames = spec.size(2)
74 | spec = nn.ConstantPad2d((0, self.max_v_timesteps * 4 - num_a_frames, 0, 0), 0.0)(spec)
75 |
76 | assert spec.size(2) == 116
77 | return spec, target
78 |
79 | def preemphasize(self, aud):
80 | aud = signal.lfilter([1, -0.97], [1], aud)
81 | return aud
82 |
83 | def denormalize(self, melspec):
84 | melspec = ((melspec + 1) * (-log1e5 / 2)) + log1e5
85 | return melspec
86 |
87 | class TacotronSTFT(torch.nn.Module):
88 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
89 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
90 | mel_fmax=8000.0):
91 | super(TacotronSTFT, self).__init__()
92 | self.n_mel_channels = n_mel_channels
93 | self.sampling_rate = sampling_rate
94 | self.stft_fn = STFT(filter_length, hop_length, win_length)
95 | mel_basis = librosa_mel_fn(
96 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
97 | mel_basis = torch.from_numpy(mel_basis).float()
98 | self.register_buffer('mel_basis', mel_basis)
99 |
100 | def spectral_normalize(self, magnitudes):
101 | output = dynamic_range_compression(magnitudes)
102 | return output
103 |
104 | def spectral_de_normalize(self, magnitudes):
105 | output = dynamic_range_decompression(magnitudes)
106 | return output
107 |
108 | def mel_spectrogram(self, y):
109 | """Computes mel-spectrograms from a batch of waves
110 | PARAMS
111 | ------
112 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
113 | RETURNS
114 | -------
115 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
116 | """
117 | assert(torch.min(y.data) >= -1)
118 | assert(torch.max(y.data) <= 1)
119 |
120 | magnitudes, phases = self.stft_fn.transform(y)
121 | magnitudes = magnitudes.data
122 | mel_output = torch.matmul(self.mel_basis, magnitudes)
123 | mel_output = self.spectral_normalize(mel_output)
124 | return mel_output
--------------------------------------------------------------------------------
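`TacotronSTFT` above is what the dataset uses to turn 16 kHz waveforms into 80-bin log-mel spectrograms. A minimal sketch of extracting a spectrogram outside the dataset class; it assumes it is run from `ASR_model/LRW`, that `sample_16khz.wav` is a placeholder path, and that an older librosa (< 0.10) is installed, since `librosa.filters.mel` is called with positional arguments in this file:

```python
import torchaudio
from src.data.vid_aud_lrw_test import TacotronSTFT

stft = TacotronSTFT(filter_length=640, hop_length=160, win_length=640,
                    n_mel_channels=80, sampling_rate=16000, mel_fmin=55., mel_fmax=7600.)

wav, sr = torchaudio.load('sample_16khz.wav')  # (1, num_samples), placeholder file
wav = wav / wav.abs().max() * 0.9              # same peak normalisation as MultiDataset
mel = stft.mel_spectrogram(wav)                # (1, 80, num_frames) log-compressed mel-spectrogram
```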
/README_LRS.md:
--------------------------------------------------------------------------------
1 | We provide the training code and trained VCA-GAN models for LRS2 and LRS3.
2 | Their performance is reported in our ICASSP 2023 paper '[Lip-to-Speech Synthesis in the Wild with Multi-task Learning](https://arxiv.org/abs/2302.08841)'.
3 | ### Datasets
4 | #### Download
5 | The LRS2 and LRS3 datasets can be downloaded from the link below.
6 | - https://www.robots.ox.ac.uk/~vgg/data/lip_reading/
7 |
8 | For data preprocessing, download the lip coordinates of LRS2 and LRS3 from the links below.
9 | - [LRS2](https://drive.google.com/file/d/10cnzNRRr-LQbS5kjc393FLvmNxPJ_u1N/view?usp=drive_link)
10 | - [LRS3](https://drive.google.com/file/d/10eAVKBuy7TyslcPdv4xmf5dxSYx4NMrS/view?usp=drive_link)
11 |
12 | Unzip the files and put them in
13 | ```
14 | ./data/LRS2/LRS2_crop/*.txt
15 | ./data/LRS3/LRS3_crop/*.txt
16 | ```
17 |
18 | #### Preprocessing
19 | After downloading the dataset, extract the audio files (.wav) from the videos using `./preprocess/Extract_audio_LRS.py`.
20 |
21 | ```shell
22 | python ./preprocess/Extract_audio_LRS.py \
23 | --data_path 'original_video_path/LRS2-BBC' \
24 | --save_path 'path_to_save/LRS2-BBC_audio' \
25 | --data_type 'LRS2 or LRS3'
26 | ```
27 |
28 | We assume the data directories are structured as follows:
29 | ```
30 | LRS2-BBC
31 | ├── main
32 | | ├── *
33 | | | └── *.mp4
34 | | | └── *.txt
35 |
36 | LRS2-BBC_audio
37 | ├── main
38 | | ├── *
39 | | | └── *.wav
40 | ```
41 |
42 | ```
43 | LRS3-TED
44 | ├── trainval
45 | | ├── *
46 | | | └── *.mp4
47 | | | └── *.txt
48 |
49 | LRS3-TED_audio
50 | ├── trainval
51 | | ├── *
52 | | | └── *.wav
53 | ```
54 |
55 | Moreover, put the train/val/test split files in
56 | ```
57 | ./data/LRS2/*.txt
58 | ./data/LRS3/*.txt
59 | ```
60 |
61 | For LRS2, we use the original splits provided with the dataset.
62 | For LRS3, we use the unseen-speaker split setting of [SVTS](https://arxiv.org/abs/2205.02058); these split files are already placed in the directory.
63 |
64 | ## Training the Model
65 | The `data_name` argument selects which dataset to use (LRS2 or LRS3).
66 | To train the model, run the following command:
67 |
68 | ```shell
69 | # Data Parallel training example using 4 GPUs on LRS2
70 | python train_LRS.py \
71 | --data '/data_dir_as_like/LRS2-BBC' \
72 | --data_name 'LRS2' \
73 | --checkpoint_dir 'enter_the_path_to_save' \
74 | --batch_size 80 \
75 | --epochs 200 \
76 | --dataparallel \
77 | --gpu 0,1,2,3
78 | ```
79 |
80 | ```shell
81 | # 1 GPU training example on LRS3
82 | python train_LRS.py \
83 | --data '/data_dir_as_like/LRS3-TED' \
84 | --data_name 'LRS3' \
85 | --checkpoint_dir 'enter_the_path_to_save' \
86 | --batch_size 80 \
87 | --epochs 200 \
88 | --gpu 0
89 | ```
90 |
91 | Descriptions of training parameters are as follows:
92 | - `--data`: Dataset location (LRS2 or LRS3)
93 | - `--data_name`: Choose to train on LRS2 or LRS3
94 | - `--checkpoint_dir`: directory for saving checkpoints
95 | - `--checkpoint` : saved checkpoint where the training is resumed from
96 | - `--batch_size`: batch size
97 | - `--epochs`: number of epochs
98 | - `--augmentations`: whether to perform augmentation
99 | - `--dataparallel`: Use DataParallel
100 | - `--gpu`: gpu number for training
101 | - `--lr`: learning rate
102 | - `--window_size`: number of frames to be used for training
103 | - Refer to `train_LRS.py` for the other training parameters
104 |
105 | During training, evaluation is performed on a subset of the validation dataset due to the heavy time cost of waveform conversion (Griffin-Lim).
106 | To evaluate the full performance of the trained model, run the test code (refer to the "Testing the Model" section).
107 |
108 | ### Check the training logs
109 | ```shell
110 | tensorboard --logdir='./runs/logs to watch' --host='ip address of the server'
111 | ```
112 | The TensorBoard shows the training and validation losses, evaluation metrics, generated mel-spectrograms, and audio samples.
113 |
114 |
115 | ## Testing the Model
116 | To test the model, run the following command:
117 | ```shell
118 | # test example on LRS2
119 | python test_LRS.py \
120 | --data 'data_directory_path' \
121 | --data_name 'LRS2' \
122 | --checkpoint 'enter_the_checkpoint_path' \
123 | --batch_size 20 \
124 | --save_mel \
125 | --save_wav \
126 | --gpu 0
127 | ```
128 |
129 | Descriptions of testing parameters are as follows:
130 | - `--data`: Dataset location (LRS2 or LRS3)
131 | - `--data_name`: Choose to test on LRS2 or LRS3
132 | - `--checkpoint` : saved checkpoint to be evaluated
133 | - `--batch_size`: batch size
134 | - `--dataparallel`: Use DataParallel
135 | - `--gpu`: gpu number for testing
136 | - `--save_mel`: whether to save the 'mel_spectrogram' and 'spectrogram' in `.npz` format
137 | - `--save_wav`: whether to save the 'waveform' in `.wav` format
138 | - Refer to `test_LRS.py` for the other parameters
139 |
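If you also pass `--save_wav`, the generated waveforms can be compared offline against the ground-truth audio. A minimal sketch using `pystoi` (the same library used by the test scripts), with placeholder file paths:

```python
# Offline STOI/ESTOI between a generated waveform and its ground truth (placeholder paths).
import librosa
from pystoi import stoi

gen, _ = librosa.load('generated.wav', sr=16000)
ref, _ = librosa.load('ground_truth.wav', sr=16000)
n = min(len(gen), len(ref))                      # compare over the overlapping length

print('STOI :', stoi(ref[:n], gen[:n], 16000, extended=False))
print('ESTOI:', stoi(ref[:n], gen[:n], 16000, extended=True))
```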
140 |
141 | ## Pre-trained model checkpoints
142 | We provide pre-trained VCA-GAN models trained on LRS2 and LRS3.
143 | The performances are reported in our ICASSP23 [paper](https://arxiv.org/abs/2302.08841).
144 |
145 | | Dataset | STOI |
146 | |:-------------------:|:--------:|
147 | |LRS2 | [0.407](https://drive.google.com/file/d/11ixych8CrmrrUWsHC35LG0nipFEOZw-W/view?usp=sharing) |
148 | |LRS3 | [0.474](https://drive.google.com/file/d/11jUyjr_hnsaeOzxUDcVywhUTl3lemXT5/view?usp=sharing) |
149 |
--------------------------------------------------------------------------------
/preprocess/Preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import glob
4 | import cv2
5 | import torchvision
6 | import torch
7 | from skimage import transform
8 | import numpy as np
9 | from scipy import signal
10 | from torch.utils.data import DataLoader, Dataset
11 | import torchvision.transforms.functional as F
12 | from torchvision import transforms
13 | import librosa
14 | import soundfile as sf
15 | import argparse
16 |
17 | def parse_args():
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--Data_dir', type=str, default="Data dir of images and audio of GRID")
20 | parser.add_argument('--Landmark', type=str, default="Data dir of GRID Landmark")
21 | parser.add_argument('--FPS', type=int, default=25, help="25 for GRID")
22 | parser.add_argument('--reference', type=str, default='./Ref_face.txt')
23 | parser.add_argument("--Output_dir", type=str, default='Output dir Ex) ./GRID_processed')
24 | args = parser.parse_args()
25 | return args
26 |
27 | args = parse_args()
28 | eps = 1e-8
29 |
30 | class Crop(object):
31 | def __init__(self, crop):
32 | self.crop = crop
33 |
34 | def __call__(self, img):
35 | return img.crop(self.crop)
36 |
37 | ###### Load the reference face landmarks (alignment target) from Ref_face.txt ######
38 | f = open(args.reference, 'r')
39 | lm = f.readlines()[0]
40 | f.close()
41 | lm = lm.split(':')[-1].split('|')[6]
42 | lms = lm.split(',')
43 | temp_lm = []
44 |
45 | for lm in lms:
46 | x, y = lm.split()
47 | temp_lm.append([x, y])
48 | temp_lm = np.array(temp_lm, dtype=float)
49 | refer_lm = temp_lm
50 | ######################################################
51 |
52 |
53 | class Preprocessing(Dataset):
54 | def __init__(self,):
55 | self.file_paths = self.build_file_list()
56 |
57 | def build_file_list(self):
58 | file_list = []
59 | landmarks = sorted(glob.glob(os.path.join(args.Landmark, '*', '*', '*.txt')))
60 | for lm in landmarks:
61 | if not os.path.exists(lm.replace(args.Landmark, args.Output_dir)[:-4] + '.mp4'):
62 | file_list.append(lm)
63 | return file_list
64 |
65 | def __len__(self):
66 | return len(self.file_paths)
67 |
68 | def __getitem__(self, idx):
69 | file_path = self.file_paths[idx]
70 | ims = sorted(glob.glob(os.path.join(file_path.replace(args.Landmark, args.Data_dir)[:-4], '*.png')))
71 | frames = []
72 |
73 | for im in ims:
74 | frames += [cv2.cvtColor(cv2.imread(im), cv2.COLOR_BGR2RGB)]
75 | v = np.stack(frames, 0)
76 |
77 | t, f_name = os.path.split(file_path)
78 | t, m_name = os.path.split(t)
79 | _, s_name = os.path.split(t)
80 | save_path = os.path.join(args.Output_dir, s_name, m_name)
81 | try:
82 | with open(file_path, 'r', encoding='utf-8') as lf:
83 | lms = lf.readlines()[0]
84 | except:
85 | with open(file_path, 'r', encoding='cp949') as lf:
86 | lms = lf.readlines()[0]
87 | lms = lms.split(':')[-1].split('|')
88 |         assert v.shape[0] == len(lms), 'the number of video frames differs from the number of landmark frames'
89 |
90 | aligned_video = []
91 | for i, frame in enumerate(v):
92 | lm = lms[i].split(',')
93 | temp_lm = []
94 | for l in lm:
95 | x, y = l.split()
96 | temp_lm.append([x, y])
97 | temp_lm = np.array(temp_lm, dtype=float) # 98,2
98 |
99 | source_lm = temp_lm
100 |
101 | tform.estimate(source_lm, refer_lm)
102 | mat = tform.params[0:2, :]
103 | aligned_im = cv2.warpAffine(frame, mat, (np.shape(frame)[0], np.shape(frame)[1]))
104 | aligned_video += [aligned_im[:256, :256, :]]
105 |
106 | aligned_video = np.array(aligned_video)
107 |
108 | #### audio preprocessing ####
109 | aud, _ = librosa.load(os.path.join(file_path.replace(args.Landmark, args.Data_dir).replace('video', 'audio')[:-4] + '.wav'), 16000)
110 | fc = 55. # Cut-off frequency of the filter
111 | w = fc / (16000 / 2) # Normalize the frequency
112 | b, a = signal.butter(7, w, 'high')
113 | aud = signal.filtfilt(b, a, aud)
114 |
115 | return torch.tensor(aligned_video), save_path, f_name, torch.tensor(aud.copy())
116 |
117 | Data = Preprocessing()
118 | Data_loader = DataLoader(Data, shuffle=False, batch_size=1, num_workers=3)
119 | tform = transform.SimilarityTransform()
120 |
121 | for kk, data in enumerate(Data_loader):
122 | cropped_video, save_path, f_name, aud = data
123 | aud = aud[0]
124 | cropped_video = cropped_video[0]
125 | save_path = save_path[0]
126 | f_name = f_name[0]
127 | if not os.path.exists(save_path):
128 | os.makedirs(save_path)
129 | if not os.path.exists(save_path.replace('video', 'audio')):
130 | os.makedirs(save_path.replace('video', 'audio'))
131 | torchvision.io.write_video(os.path.join(save_path, f_name[:-4] + '.mp4'), video_array=cropped_video, fps=args.FPS)
132 | sf.write(os.path.join(save_path.replace('video', 'audio'), f_name[:-4] + ".flac"), aud.numpy(), samplerate=16000)
133 | print('##########', kk + 1, ' / ', len(Data_loader), '##########')
134 |
--------------------------------------------------------------------------------
/ASR_model/LRW/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import torch
4 | from torch import nn, optim
5 | from torch.utils.tensorboard import SummaryWriter
6 | import numpy as np
7 | from src.models.audio_front import Audio_front
8 | from src.models.classifier import Backend
9 | import os
10 | from torch.utils.data import DataLoader
11 | from torch.nn import functional as F
12 | from src.data.vid_aud_lrw_test import MultiDataset
13 | from torch.nn import DataParallel as DP
14 | import torch.nn.parallel
15 | import math
16 | from matplotlib import pyplot as plt
17 | import time
18 | import glob
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--data', default="TEST_DIR", help='./../../test/spec_mel')
23 | parser.add_argument('--wav', default=False, action='store_true')
24 | parser.add_argument("--checkpoint_dir", type=str, default='./data')
25 | parser.add_argument("--checkpoint", type=str, default='./data/LRW_acc_0.98464.ckpt')
26 | parser.add_argument("--batch_size", type=int, default=320)
27 | parser.add_argument("--epochs", type=int, default=100)
28 | parser.add_argument("--lr", type=float, default=0.01)
29 | parser.add_argument("--weight_decay", type=float, default=0.00001)
30 | parser.add_argument("--workers", type=int, default=5)
31 | parser.add_argument("--resnet", type=int, default=18)
32 | parser.add_argument("--seed", type=int, default=1)
33 |
34 | parser.add_argument("--max_timesteps", type=int, default=29)
35 |
36 | parser.add_argument("--dataparallel", default=False, action='store_true')
37 | parser.add_argument("--gpu", type=str, default='0')
38 | args = parser.parse_args()
39 | return args
40 |
41 |
42 | def train_net(args):
43 | torch.backends.cudnn.deterministic = False
44 | torch.backends.cudnn.benchmark = True
45 | torch.manual_seed(args.seed)
46 | torch.cuda.manual_seed_all(args.seed)
47 | random.seed(args.seed)
48 | os.environ['OMP_NUM_THREADS'] = '2'
49 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
50 |
51 | a_front = Audio_front(in_channels=1)
52 | a_back = Backend()
53 |
54 | if args.checkpoint is not None:
55 | print(f"Loading checkpoint: {args.checkpoint}")
56 | checkpoint = torch.load(args.checkpoint)
57 | a_front.load_state_dict(checkpoint['a_front_state_dict'])
58 | a_back.load_state_dict(checkpoint['a_back_state_dict'])
59 | del checkpoint
60 |
61 | a_front.cuda()
62 | a_back.cuda()
63 |
64 | if args.dataparallel:
65 | a_front = DP(a_front)
66 | a_back = DP(a_back)
67 |
68 | _ = validate(a_front, a_back, epoch=0, writer=None)
69 |
70 | def validate(a_front, a_back, fast_validate=False, epoch=0, writer=None):
71 | with torch.no_grad():
72 | a_front.eval()
73 | a_back.eval()
74 |
75 | val_data = MultiDataset(
76 | lrw=args.data,
77 | mode='test',
78 | max_v_timesteps=args.max_timesteps,
79 | augmentations=False,
80 | wav=args.wav
81 | )
82 |
83 | dataloader = DataLoader(
84 | val_data,
85 | shuffle=False,
86 | batch_size=args.batch_size * 2,
87 | num_workers=args.workers,
88 | drop_last=False
89 | )
90 |
91 | criterion = nn.CrossEntropyLoss().cuda()
92 | batch_size = dataloader.batch_size
93 | if fast_validate:
94 | samples = min(2 * batch_size, int(len(dataloader.dataset)))
95 | max_batches = 2
96 | else:
97 | samples = int(len(dataloader.dataset))
98 | max_batches = int(len(dataloader))
99 |
100 | val_loss = []
101 | tot_cor, tot_v_cor, tot_a_cor, tot_num = 0, 0, 0, 0
102 |
103 | description = 'Check validation step' if fast_validate else 'Validation'
104 | print(description)
105 | for i, batch in enumerate(dataloader):
106 | if i % 10 == 0:
107 | if not fast_validate:
108 | print("******** Validation : %d / %d ********" % ((i + 1) * batch_size, samples))
109 | a_in, target = batch
110 |
111 | a_feat = a_front(a_in.cuda()) # S,B,51
112 | a_pred = a_back(a_feat)
113 |
114 | loss = criterion(a_pred, target.long().cuda()).cpu().item()
115 | prediction = torch.argmax(a_pred.cpu(), dim=1).numpy()
116 | tot_cor += np.sum(prediction == target.long().numpy())
117 | tot_num += len(prediction)
118 |
119 | batch_size = a_pred.size(0)
120 | val_loss.append(loss)
121 |
122 | if i >= max_batches:
123 | break
124 |
125 | if writer is not None:
126 | writer.add_scalar('Val/loss', np.mean(np.array(val_loss)), epoch)
127 | writer.add_scalar('Val/acc', tot_cor / tot_num, epoch)
128 |
129 | a_front.train()
130 | a_back.train()
131 | print('test_ACC:', tot_cor / tot_num, 'WER:', 1. - tot_cor / tot_num)
132 | if fast_validate:
133 | return {}
134 | else:
135 | return np.mean(np.array(val_loss)), tot_cor / tot_num
136 |
137 | if __name__ == "__main__":
138 | args = parse_args()
139 | train_net(args)
140 |
141 |
--------------------------------------------------------------------------------
/src/data/stft.py:
--------------------------------------------------------------------------------
1 | """
2 | BSD 3-Clause License
3 | Copyright (c) 2017, Prem Seetharaman
4 | All rights reserved.
5 | * Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 | * Redistributions of source code must retain the above copyright notice,
8 | this list of conditions and the following disclaimer.
9 | * Redistributions in binary form must reproduce the above copyright notice, this
10 | list of conditions and the following disclaimer in the
11 | documentation and/or other materials provided with the distribution.
12 | * Neither the name of the copyright holder nor the names of its
13 | contributors may be used to endorse or promote products derived from this
14 | software without specific prior written permission.
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
22 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | """
26 |
27 | import torch
28 | import numpy as np
29 | import torch.nn.functional as F
30 | from torch.autograd import Variable
31 | from scipy.signal import get_window
32 | from librosa.util import pad_center, tiny
33 | from src.data.audio_processing import window_sumsquare
34 |
35 | class STFT(torch.nn.Module):
36 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
37 | def __init__(self, filter_length=800, hop_length=200, win_length=800,
38 | window='hann'):
39 | super(STFT, self).__init__()
40 | self.filter_length = filter_length
41 | self.hop_length = hop_length
42 | self.win_length = win_length
43 | self.window = window
44 | self.forward_transform = None
45 | scale = self.filter_length / self.hop_length
46 | fourier_basis = np.fft.fft(np.eye(self.filter_length))
47 |
48 | cutoff = int((self.filter_length / 2 + 1))
49 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
50 | np.imag(fourier_basis[:cutoff, :])])
51 |
52 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
53 | inverse_basis = torch.FloatTensor(
54 | np.linalg.pinv(scale * fourier_basis).T[:, None, :])
55 |
56 | if window is not None:
57 | assert(filter_length >= win_length)
58 | # get window and zero center pad it to filter_length
59 | fft_window = get_window(window, win_length, fftbins=True)
60 | fft_window = pad_center(fft_window, filter_length)
61 | fft_window = torch.from_numpy(fft_window).float()
62 |
63 | # window the bases
64 | forward_basis *= fft_window
65 | inverse_basis *= fft_window
66 |
67 | self.register_buffer('forward_basis', forward_basis.float())
68 | self.register_buffer('inverse_basis', inverse_basis.float())
69 |
70 | def transform(self, input_data):
71 | num_batches = input_data.size(0)
72 | num_samples = input_data.size(1)
73 |
74 | self.num_samples = num_samples
75 |
76 | # similar to librosa, reflect-pad the input
77 | input_data = input_data.view(num_batches, 1, num_samples)
78 | input_data = F.pad(
79 | input_data.unsqueeze(1),
80 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
81 | mode='reflect')
82 | input_data = input_data.squeeze(1)
83 |
84 | forward_transform = F.conv1d(
85 | input_data,
86 | Variable(self.forward_basis, requires_grad=False),
87 | stride=self.hop_length,
88 | padding=0)
89 |
90 | cutoff = int((self.filter_length / 2) + 1)
91 | real_part = forward_transform[:, :cutoff, :]
92 | imag_part = forward_transform[:, cutoff:, :]
93 |
94 | magnitude = torch.sqrt(real_part**2 + imag_part**2)
95 | phase = torch.autograd.Variable(
96 | torch.atan2(imag_part.data, real_part.data))
97 |
98 | return magnitude, phase
99 |
100 | def inverse(self, magnitude, phase):
101 | recombine_magnitude_phase = torch.cat(
102 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
103 |
104 | inverse_transform = F.conv_transpose1d(
105 | recombine_magnitude_phase,
106 | Variable(self.inverse_basis, requires_grad=False),
107 | stride=self.hop_length,
108 | padding=0)
109 |
110 | if self.window is not None:
111 | window_sum = window_sumsquare(
112 | self.window, magnitude.size(-1), hop_length=self.hop_length,
113 | win_length=self.win_length, n_fft=self.filter_length,
114 | dtype=np.float32)
115 | # remove modulation effects
116 | approx_nonzero_indices = torch.from_numpy(
117 | np.where(window_sum > tiny(window_sum))[0])
118 | window_sum = torch.autograd.Variable(
119 | torch.from_numpy(window_sum), requires_grad=False)
120 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
121 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
122 |
123 | # scale by hop ratio
124 | inverse_transform *= float(self.filter_length) / self.hop_length
125 |
126 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
127 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
128 |
129 | return inverse_transform
130 |
131 | def forward(self, input_data):
132 | self.magnitude, self.phase = self.transform(input_data)
133 | reconstruction = self.inverse(self.magnitude, self.phase)
134 | return reconstruction
--------------------------------------------------------------------------------
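The `STFT` module above operates on batched waveforms of shape (B, T) and returns a magnitude/phase pair that can be inverted back to a waveform. A minimal roundtrip sketch, assuming the 640/160/640 STFT settings used by the GRID pipeline:

```python
# Roundtrip sketch through src/data/stft.py (illustrative, not a repository file).
import torch
from src.data.stft import STFT

stft = STFT(filter_length=640, hop_length=160, win_length=640, window='hann')
wav = torch.randn(2, 16000)              # (B, T) dummy waveform, about 1 s at 16 kHz
magnitude, phase = stft.transform(wav)   # each: (B, filter_length // 2 + 1, frames)
recon = stft.inverse(magnitude, phase)   # approximate reconstruction of the input
print(magnitude.shape, recon.shape)
```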
/ASR_model/GRID/src/data/stft.py:
--------------------------------------------------------------------------------
1 | """
2 | BSD 3-Clause License
3 | Copyright (c) 2017, Prem Seetharaman
4 | All rights reserved.
5 | * Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 | * Redistributions of source code must retain the above copyright notice,
8 | this list of conditions and the following disclaimer.
9 | * Redistributions in binary form must reproduce the above copyright notice, this
10 | list of conditions and the following disclaimer in the
11 | documentation and/or other materials provided with the distribution.
12 | * Neither the name of the copyright holder nor the names of its
13 | contributors may be used to endorse or promote products derived from this
14 | software without specific prior written permission.
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
22 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | """
26 |
27 | import torch
28 | import numpy as np
29 | import torch.nn.functional as F
30 | from torch.autograd import Variable
31 | from scipy.signal import get_window
32 | from librosa.util import pad_center, tiny
33 | from src.data.audio_processing import window_sumsquare
34 |
35 | class STFT(torch.nn.Module):
36 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
37 | def __init__(self, filter_length=800, hop_length=200, win_length=800,
38 | window='hann'):
39 | super(STFT, self).__init__()
40 | self.filter_length = filter_length
41 | self.hop_length = hop_length
42 | self.win_length = win_length
43 | self.window = window
44 | self.forward_transform = None
45 | scale = self.filter_length / self.hop_length
46 | fourier_basis = np.fft.fft(np.eye(self.filter_length))
47 |
48 | cutoff = int((self.filter_length / 2 + 1))
49 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
50 | np.imag(fourier_basis[:cutoff, :])])
51 |
52 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
53 | inverse_basis = torch.FloatTensor(
54 | np.linalg.pinv(scale * fourier_basis).T[:, None, :])
55 |
56 | if window is not None:
57 | assert(filter_length >= win_length)
58 | # get window and zero center pad it to filter_length
59 | fft_window = get_window(window, win_length, fftbins=True)
60 | fft_window = pad_center(fft_window, filter_length)
61 | fft_window = torch.from_numpy(fft_window).float()
62 |
63 | # window the bases
64 | forward_basis *= fft_window
65 | inverse_basis *= fft_window
66 |
67 | self.register_buffer('forward_basis', forward_basis.float())
68 | self.register_buffer('inverse_basis', inverse_basis.float())
69 |
70 | def transform(self, input_data):
71 | num_batches = input_data.size(0)
72 | num_samples = input_data.size(1)
73 |
74 | self.num_samples = num_samples
75 |
76 | # similar to librosa, reflect-pad the input
77 | input_data = input_data.view(num_batches, 1, num_samples)
78 | input_data = F.pad(
79 | input_data.unsqueeze(1),
80 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
81 | mode='reflect')
82 | input_data = input_data.squeeze(1)
83 |
84 | forward_transform = F.conv1d(
85 | input_data,
86 | Variable(self.forward_basis, requires_grad=False),
87 | stride=self.hop_length,
88 | padding=0)
89 |
90 | cutoff = int((self.filter_length / 2) + 1)
91 | real_part = forward_transform[:, :cutoff, :]
92 | imag_part = forward_transform[:, cutoff:, :]
93 |
94 | magnitude = torch.sqrt(real_part**2 + imag_part**2)
95 | phase = torch.autograd.Variable(
96 | torch.atan2(imag_part.data, real_part.data))
97 |
98 | return magnitude, phase
99 |
100 | def inverse(self, magnitude, phase):
101 | recombine_magnitude_phase = torch.cat(
102 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
103 |
104 | inverse_transform = F.conv_transpose1d(
105 | recombine_magnitude_phase,
106 | Variable(self.inverse_basis, requires_grad=False),
107 | stride=self.hop_length,
108 | padding=0)
109 |
110 | if self.window is not None:
111 | window_sum = window_sumsquare(
112 | self.window, magnitude.size(-1), hop_length=self.hop_length,
113 | win_length=self.win_length, n_fft=self.filter_length,
114 | dtype=np.float32)
115 | # remove modulation effects
116 | approx_nonzero_indices = torch.from_numpy(
117 | np.where(window_sum > tiny(window_sum))[0])
118 | window_sum = torch.autograd.Variable(
119 | torch.from_numpy(window_sum), requires_grad=False)
120 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
121 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
122 |
123 | # scale by hop ratio
124 | inverse_transform *= float(self.filter_length) / self.hop_length
125 |
126 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
127 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
128 |
129 | return inverse_transform
130 |
131 | def forward(self, input_data):
132 | self.magnitude, self.phase = self.transform(input_data)
133 | reconstruction = self.inverse(self.magnitude, self.phase)
134 | return reconstruction
--------------------------------------------------------------------------------
/ASR_model/LRW/src/data/stft.py:
--------------------------------------------------------------------------------
1 | """
2 | BSD 3-Clause License
3 | Copyright (c) 2017, Prem Seetharaman
4 | All rights reserved.
5 | * Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 | * Redistributions of source code must retain the above copyright notice,
8 | this list of conditions and the following disclaimer.
9 | * Redistributions in binary form must reproduce the above copyright notice, this
10 | list of conditions and the following disclaimer in the
11 | documentation and/or other materials provided with the distribution.
12 | * Neither the name of the copyright holder nor the names of its
13 | contributors may be used to endorse or promote products derived from this
14 | software without specific prior written permission.
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
22 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | """
26 |
27 | import torch
28 | import numpy as np
29 | import torch.nn.functional as F
30 | from torch.autograd import Variable
31 | from scipy.signal import get_window
32 | from librosa.util import pad_center, tiny
33 | from src.data.audio_processing import window_sumsquare
34 |
35 | class STFT(torch.nn.Module):
36 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
37 | def __init__(self, filter_length=800, hop_length=200, win_length=800,
38 | window='hann'):
39 | super(STFT, self).__init__()
40 | self.filter_length = filter_length
41 | self.hop_length = hop_length
42 | self.win_length = win_length
43 | self.window = window
44 | self.forward_transform = None
45 | scale = self.filter_length / self.hop_length
46 | fourier_basis = np.fft.fft(np.eye(self.filter_length))
47 |
48 | cutoff = int((self.filter_length / 2 + 1))
49 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
50 | np.imag(fourier_basis[:cutoff, :])])
51 |
52 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
53 | inverse_basis = torch.FloatTensor(
54 | np.linalg.pinv(scale * fourier_basis).T[:, None, :])
55 |
56 | if window is not None:
57 | assert(filter_length >= win_length)
58 | # get window and zero center pad it to filter_length
59 | fft_window = get_window(window, win_length, fftbins=True)
60 | fft_window = pad_center(fft_window, filter_length)
61 | fft_window = torch.from_numpy(fft_window).float()
62 |
63 | # window the bases
64 | forward_basis *= fft_window
65 | inverse_basis *= fft_window
66 |
67 | self.register_buffer('forward_basis', forward_basis.float())
68 | self.register_buffer('inverse_basis', inverse_basis.float())
69 |
70 | def transform(self, input_data):
71 | num_batches = input_data.size(0)
72 | num_samples = input_data.size(1)
73 |
74 | self.num_samples = num_samples
75 |
76 | # similar to librosa, reflect-pad the input
77 | input_data = input_data.view(num_batches, 1, num_samples)
78 | input_data = F.pad(
79 | input_data.unsqueeze(1),
80 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
81 | mode='reflect')
82 | input_data = input_data.squeeze(1)
83 |
84 | forward_transform = F.conv1d(
85 | input_data,
86 | Variable(self.forward_basis, requires_grad=False),
87 | stride=self.hop_length,
88 | padding=0)
89 |
90 | cutoff = int((self.filter_length / 2) + 1)
91 | real_part = forward_transform[:, :cutoff, :]
92 | imag_part = forward_transform[:, cutoff:, :]
93 |
94 | magnitude = torch.sqrt(real_part**2 + imag_part**2)
95 | phase = torch.autograd.Variable(
96 | torch.atan2(imag_part.data, real_part.data))
97 |
98 | return magnitude, phase
99 |
100 | def inverse(self, magnitude, phase):
101 | recombine_magnitude_phase = torch.cat(
102 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
103 |
104 | inverse_transform = F.conv_transpose1d(
105 | recombine_magnitude_phase,
106 | Variable(self.inverse_basis, requires_grad=False),
107 | stride=self.hop_length,
108 | padding=0)
109 |
110 | if self.window is not None:
111 | window_sum = window_sumsquare(
112 | self.window, magnitude.size(-1), hop_length=self.hop_length,
113 | win_length=self.win_length, n_fft=self.filter_length,
114 | dtype=np.float32)
115 | # remove modulation effects
116 | approx_nonzero_indices = torch.from_numpy(
117 | np.where(window_sum > tiny(window_sum))[0])
118 | window_sum = torch.autograd.Variable(
119 | torch.from_numpy(window_sum), requires_grad=False)
120 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
121 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
122 |
123 | # scale by hop ratio
124 | inverse_transform *= float(self.filter_length) / self.hop_length
125 |
126 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
127 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
128 |
129 | return inverse_transform
130 |
131 | def forward(self, input_data):
132 | self.magnitude, self.phase = self.transform(input_data)
133 | reconstruction = self.inverse(self.magnitude, self.phase)
134 | return reconstruction
--------------------------------------------------------------------------------
/README_GRID.md:
--------------------------------------------------------------------------------
1 | ### Datasets
2 | #### Download
3 | The GRID dataset (normal-quality video) can be downloaded from the link below.
4 | - http://spandh.dcs.shef.ac.uk/gridcorpus/
5 |
6 | For data preprocessing, download the face landmarks of GRID from the link below.
7 | - https://drive.google.com/file/d/10upLpydfbqCJ7t64h210Xx-35B8qCcZr/view?usp=sharing
8 |
9 | #### Preprocessing
10 | After downloading the dataset, preprocess it with the following scripts in `./preprocess`.
11 | They assume the data directory is structured as follows:
12 | ```
13 | Data_dir
14 | ├── subject
15 | | ├── video
16 | | | └── xxx.mpg
17 | ```
18 |
19 | 1. Extract frames
20 | `Extract_frames.py` extracts images and audio from the videos.
21 | ```shell
22 | python Extract_frames.py --Grid_dir "Data dir of GRID_corpus" --Out_dir "Output dir of images and audio of GRID_corpus"
23 | ```
24 |
25 | 2. Face alignment and audio processing
26 | `Preprocess.py` aligns the faces and generates videos, which enables lip-centered cropping of the videos during training.
27 | ```shell
28 | python Preprocess.py \
29 | --Data_dir "Data dir of extracted images and audio of GRID_corpus" \
30 | --Landmark "Downloaded landmark dir of GRID" \
31 | --Output_dir "Output dir of processed data"
32 | ```
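After this step, `Preprocess.py` writes the aligned videos (`.mp4`) and the high-pass-filtered audio (16 kHz `.flac`) under `--Output_dir`, roughly mirroring the input layout:

```
Output_dir
├── subject
| ├── video
| | └── xxx.mp4
| ├── audio
| | └── xxx.flac
```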
33 |
34 | ## Training the Model
35 | The speaker setting (which subjects are used) can be selected with the `subject` argument; please refer to the examples below.
36 | To train the model, run the following command:
37 |
38 | ```shell
39 | # Data Parallel training example using 4 GPUs for multi-speaker setting in GRID
40 | python train.py \
41 | --grid 'enter_the_processed_data_path' \
42 | --checkpoint_dir 'enter_the_path_to_save' \
43 | --batch_size 88 \
44 | --epochs 500 \
45 | --subject 'overlap' \
46 | --eval_step 720 \
47 | --dataparallel \
48 | --gpu 0,1,2,3
49 | ```
50 |
51 | ```shell
52 | # 1 GPU training example for the unseen-speaker setting in GRID
53 | python train.py \
54 | --grid 'enter_the_processed_data_path' \
55 | --checkpoint_dir 'enter_the_path_to_save' \
56 | --batch_size 22 \
57 | --epochs 500 \
58 | --subject 'unseen' \
59 | --eval_step 1000 \
60 | --gpu 0
61 | ```
62 |
63 | Descriptions of training parameters are as follows:
64 | - `--grid`: Dataset location (grid)
65 | - `--checkpoint_dir`: directory for saving checkpoints
66 | - `--checkpoint` : saved checkpoint where the training is resumed from
67 | - `--batch_size`: batch size
68 | - `--epochs`: number of epochs
69 | - `--augmentations`: whether performing augmentation
70 | - `--dataparallel`: Use DataParallel
71 | - `--subject`: speaker setting; `s#` for speaker-specific training, `overlap` for the multi-speaker setting, `unseen` for the unseen-speaker setting, `four` for four-speaker training
72 | - `--gpu`: gpu number for training
73 | - `--lr`: learning rate
74 | - `--eval_step`: steps for performing evaluation
75 | - `--window_size`: number of frames to be used for training
76 | - Refer to `train.py` for the other training parameters
77 |
78 | During training, evaluation is performed on a subset of the validation dataset due to the heavy time cost of waveform conversion (Griffin-Lim).
79 | To evaluate the full performance of the trained model, run the test code (refer to the "Testing the Model" section).
80 |
81 | ### Check the training logs
82 | ```shell
83 | tensorboard --logdir='./runs/logs to watch' --host='ip address of the server'
84 | ```
85 | The TensorBoard shows the training and validation losses, evaluation metrics, generated mel-spectrograms, and audio samples.
86 |
87 |
88 | ## Testing the Model
89 | To test the model, run the following command:
90 | ```shell
91 | # Dataparallel test example for multi-speaker setting in GRID
92 | python test.py \
93 | --grid 'enter_the_processed_data_path' \
94 | --checkpoint 'enter_the_checkpoint_path' \
95 | --batch_size 100 \
96 | --subject 'overlap' \
97 | --save_mel \
98 | --save_wav \
99 | --dataparallel \
100 | --gpu 0,1
101 | ```
102 |
103 | Descriptions of testing parameters are as follows:
104 | - `--grid`: Dataset location (grid)
105 | - `--checkpoint` : saved checkpoint to be evaluated
106 | - `--batch_size`: batch size
107 | - `--dataparallel`: Use DataParallel
108 | - `--subject`: speaker setting; `s#` for speaker-specific training, `overlap` for the multi-speaker setting, `unseen` for the unseen-speaker setting, `four` for four-speaker training
109 | - `--save_mel`: whether to save the 'mel_spectrogram' and 'spectrogram' in `.npz` format (see the loading example below)
110 | - `--save_wav`: whether to save the 'waveform' in `.wav` format
111 | - `--gpu`: gpu number for testing
112 | - Refer to `test.py` for the other parameters
113 |
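With `--save_mel`, the generated outputs are written as `.npz` files under `./test/spec_mel/`, containing the generated mel-spectrogram and the post-net spectrogram. They can be reloaded as follows (placeholder subject and file name):

```python
# Load one output saved by test.py with --save_mel (placeholder path).
import numpy as np

data = np.load('./test/spec_mel/s1/xxx.npz')
mel, spec = data['mel'], data['spec']    # keys written by test.py
print(mel.shape, spec.shape)
data.close()
```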
114 | ## Test Automatic Speech Recognition (ASR) results on the generated speech: WER
115 | The ground-truth transcriptions of the GRID dataset can be downloaded from the link below.
116 | - https://drive.google.com/file/d/112UES-N0OKjj0xV0hCDO9ZaOyPsvnJzR/view?usp=sharing
117 |
118 | Move to the ASR_model directory:
119 | ```shell
120 | cd ASR_model/GRID
121 | ```
122 |
123 | To evaluate the WER, run the following command:
124 | ```shell
125 | # test example for multi-speaker setting in GRID
126 | python test.py \
127 | --data 'enter_the_generated_data_dir (mel or wav) (ex. ./../../test/spec_mel)' \
128 | --gtpath 'enter_the_downloaded_transcription_path' \
129 | --subject 'overlap' \
130 | --gpu 0
131 | ```
132 |
133 | Descriptions of evaluation parameters are as follows:
134 | - `--data`: Data for evaluation (wav or mel in `.npz` format)
135 | - `--wav` : whether the data is waveform or not
136 | - `--batch_size`: batch size
137 | - `--subject`: speaker setting; `s#` for speaker-specific training, `overlap` for the multi-speaker setting, `unseen` for the unseen-speaker setting, `four` for four-speaker training
138 | - `--gpu`: gpu number for evaluation
139 | - Refer to `./ASR_model/GRID/test.py` for the other parameters
140 |
141 |
142 | ### Pre-trained ASR model checkpoint
143 | Below are the pre-trained ASR models used to evaluate the generated speech.
144 | WER shows the original performance of each model on ground-truth audio.
145 |
146 | | Setting | WER |
147 | |:-------------------:|:--------:|
148 | |GRID (constrained-speaker) | [0.83 %](https://drive.google.com/file/d/11OyjBnfLU7M3qt98udIkUiGZ6frHLwP1/view?usp=sharing) |
149 | |GRID (multi-speaker) | [0.37 %](https://drive.google.com/file/d/113XmPNjlY7c1fSjUgf5IzJi1us7qVeX1/view?usp=sharing) |
150 | |GRID (unseen-speaker) | [1.67 %](https://drive.google.com/file/d/11CAO_z_eUWcdGWPuVcCYy1mpEG2YghAT/view?usp=sharing) |
151 | |LRW | [1.57 %](https://drive.google.com/file/d/11erwY_Tf69OBfBSdZFwikmBXLI_G2YKb/view?usp=sharing) |
152 |
153 | Put the checkpoints in `./ASR_model/GRID/data` for GRID, and in `./ASR_model/LRW/data` for LRW.
154 | The LRW checkpoint was originally trained with torchvision-based audio loading and has been modified to work with the librosa library.
155 |
--------------------------------------------------------------------------------
/ASR_model/LRW/data/class.txt:
--------------------------------------------------------------------------------
1 | ABOUT
2 | ABSOLUTELY
3 | ABUSE
4 | ACCESS
5 | ACCORDING
6 | ACCUSED
7 | ACROSS
8 | ACTION
9 | ACTUALLY
10 | AFFAIRS
11 | AFFECTED
12 | AFRICA
13 | AFTER
14 | AFTERNOON
15 | AGAIN
16 | AGAINST
17 | AGREE
18 | AGREEMENT
19 | AHEAD
20 | ALLEGATIONS
21 | ALLOW
22 | ALLOWED
23 | ALMOST
24 | ALREADY
25 | ALWAYS
26 | AMERICA
27 | AMERICAN
28 | AMONG
29 | AMOUNT
30 | ANNOUNCED
31 | ANOTHER
32 | ANSWER
33 | ANYTHING
34 | AREAS
35 | AROUND
36 | ARRESTED
37 | ASKED
38 | ASKING
39 | ATTACK
40 | ATTACKS
41 | AUTHORITIES
42 | BANKS
43 | BECAUSE
44 | BECOME
45 | BEFORE
46 | BEHIND
47 | BEING
48 | BELIEVE
49 | BENEFIT
50 | BENEFITS
51 | BETTER
52 | BETWEEN
53 | BIGGEST
54 | BILLION
55 | BLACK
56 | BORDER
57 | BRING
58 | BRITAIN
59 | BRITISH
60 | BROUGHT
61 | BUDGET
62 | BUILD
63 | BUILDING
64 | BUSINESS
65 | BUSINESSES
66 | CALLED
67 | CAMERON
68 | CAMPAIGN
69 | CANCER
70 | CANNOT
71 | CAPITAL
72 | CASES
73 | CENTRAL
74 | CERTAINLY
75 | CHALLENGE
76 | CHANCE
77 | CHANGE
78 | CHANGES
79 | CHARGE
80 | CHARGES
81 | CHIEF
82 | CHILD
83 | CHILDREN
84 | CHINA
85 | CLAIMS
86 | CLEAR
87 | CLOSE
88 | CLOUD
89 | COMES
90 | COMING
91 | COMMUNITY
92 | COMPANIES
93 | COMPANY
94 | CONCERNS
95 | CONFERENCE
96 | CONFLICT
97 | CONSERVATIVE
98 | CONTINUE
99 | CONTROL
100 | COULD
101 | COUNCIL
102 | COUNTRIES
103 | COUNTRY
104 | COUPLE
105 | COURSE
106 | COURT
107 | CRIME
108 | CRISIS
109 | CURRENT
110 | CUSTOMERS
111 | DAVID
112 | DEATH
113 | DEBATE
114 | DECIDED
115 | DECISION
116 | DEFICIT
117 | DEGREES
118 | DESCRIBED
119 | DESPITE
120 | DETAILS
121 | DIFFERENCE
122 | DIFFERENT
123 | DIFFICULT
124 | DOING
125 | DURING
126 | EARLY
127 | EASTERN
128 | ECONOMIC
129 | ECONOMY
130 | EDITOR
131 | EDUCATION
132 | ELECTION
133 | EMERGENCY
134 | ENERGY
135 | ENGLAND
136 | ENOUGH
137 | EUROPE
138 | EUROPEAN
139 | EVENING
140 | EVENTS
141 | EVERY
142 | EVERYBODY
143 | EVERYONE
144 | EVERYTHING
145 | EVIDENCE
146 | EXACTLY
147 | EXAMPLE
148 | EXPECT
149 | EXPECTED
150 | EXTRA
151 | FACING
152 | FAMILIES
153 | FAMILY
154 | FIGHT
155 | FIGHTING
156 | FIGURES
157 | FINAL
158 | FINANCIAL
159 | FIRST
160 | FOCUS
161 | FOLLOWING
162 | FOOTBALL
163 | FORCE
164 | FORCES
165 | FOREIGN
166 | FORMER
167 | FORWARD
168 | FOUND
169 | FRANCE
170 | FRENCH
171 | FRIDAY
172 | FRONT
173 | FURTHER
174 | FUTURE
175 | GAMES
176 | GENERAL
177 | GEORGE
178 | GERMANY
179 | GETTING
180 | GIVEN
181 | GIVING
182 | GLOBAL
183 | GOING
184 | GOVERNMENT
185 | GREAT
186 | GREECE
187 | GROUND
188 | GROUP
189 | GROWING
190 | GROWTH
191 | GUILTY
192 | HAPPEN
193 | HAPPENED
194 | HAPPENING
195 | HAVING
196 | HEALTH
197 | HEARD
198 | HEART
199 | HEAVY
200 | HIGHER
201 | HISTORY
202 | HOMES
203 | HOSPITAL
204 | HOURS
205 | HOUSE
206 | HOUSING
207 | HUMAN
208 | HUNDREDS
209 | IMMIGRATION
210 | IMPACT
211 | IMPORTANT
212 | INCREASE
213 | INDEPENDENT
214 | INDUSTRY
215 | INFLATION
216 | INFORMATION
217 | INQUIRY
218 | INSIDE
219 | INTEREST
220 | INVESTMENT
221 | INVOLVED
222 | IRELAND
223 | ISLAMIC
224 | ISSUE
225 | ISSUES
226 | ITSELF
227 | JAMES
228 | JUDGE
229 | JUSTICE
230 | KILLED
231 | KNOWN
232 | LABOUR
233 | LARGE
234 | LATER
235 | LATEST
236 | LEADER
237 | LEADERS
238 | LEADERSHIP
239 | LEAST
240 | LEAVE
241 | LEGAL
242 | LEVEL
243 | LEVELS
244 | LIKELY
245 | LITTLE
246 | LIVES
247 | LIVING
248 | LOCAL
249 | LONDON
250 | LONGER
251 | LOOKING
252 | MAJOR
253 | MAJORITY
254 | MAKES
255 | MAKING
256 | MANCHESTER
257 | MARKET
258 | MASSIVE
259 | MATTER
260 | MAYBE
261 | MEANS
262 | MEASURES
263 | MEDIA
264 | MEDICAL
265 | MEETING
266 | MEMBER
267 | MEMBERS
268 | MESSAGE
269 | MIDDLE
270 | MIGHT
271 | MIGRANTS
272 | MILITARY
273 | MILLION
274 | MILLIONS
275 | MINISTER
276 | MINISTERS
277 | MINUTES
278 | MISSING
279 | MOMENT
280 | MONEY
281 | MONTH
282 | MONTHS
283 | MORNING
284 | MOVING
285 | MURDER
286 | NATIONAL
287 | NEEDS
288 | NEVER
289 | NIGHT
290 | NORTH
291 | NORTHERN
292 | NOTHING
293 | NUMBER
294 | NUMBERS
295 | OBAMA
296 | OFFICE
297 | OFFICERS
298 | OFFICIALS
299 | OFTEN
300 | OPERATION
301 | OPPOSITION
302 | ORDER
303 | OTHER
304 | OTHERS
305 | OUTSIDE
306 | PARENTS
307 | PARLIAMENT
308 | PARTIES
309 | PARTS
310 | PARTY
311 | PATIENTS
312 | PAYING
313 | PEOPLE
314 | PERHAPS
315 | PERIOD
316 | PERSON
317 | PERSONAL
318 | PHONE
319 | PLACE
320 | PLACES
321 | PLANS
322 | POINT
323 | POLICE
324 | POLICY
325 | POLITICAL
326 | POLITICIANS
327 | POLITICS
328 | POSITION
329 | POSSIBLE
330 | POTENTIAL
331 | POWER
332 | POWERS
333 | PRESIDENT
334 | PRESS
335 | PRESSURE
336 | PRETTY
337 | PRICE
338 | PRICES
339 | PRIME
340 | PRISON
341 | PRIVATE
342 | PROBABLY
343 | PROBLEM
344 | PROBLEMS
345 | PROCESS
346 | PROTECT
347 | PROVIDE
348 | PUBLIC
349 | QUESTION
350 | QUESTIONS
351 | QUITE
352 | RATES
353 | RATHER
354 | REALLY
355 | REASON
356 | RECENT
357 | RECORD
358 | REFERENDUM
359 | REMEMBER
360 | REPORT
361 | REPORTS
362 | RESPONSE
363 | RESULT
364 | RETURN
365 | RIGHT
366 | RIGHTS
367 | RULES
368 | RUNNING
369 | RUSSIA
370 | RUSSIAN
371 | SAYING
372 | SCHOOL
373 | SCHOOLS
374 | SCOTLAND
375 | SCOTTISH
376 | SECOND
377 | SECRETARY
378 | SECTOR
379 | SECURITY
380 | SEEMS
381 | SENIOR
382 | SENSE
383 | SERIES
384 | SERIOUS
385 | SERVICE
386 | SERVICES
387 | SEVEN
388 | SEVERAL
389 | SHORT
390 | SHOULD
391 | SIDES
392 | SIGNIFICANT
393 | SIMPLY
394 | SINCE
395 | SINGLE
396 | SITUATION
397 | SMALL
398 | SOCIAL
399 | SOCIETY
400 | SOMEONE
401 | SOMETHING
402 | SOUTH
403 | SOUTHERN
404 | SPEAKING
405 | SPECIAL
406 | SPEECH
407 | SPEND
408 | SPENDING
409 | SPENT
410 | STAFF
411 | STAGE
412 | STAND
413 | START
414 | STARTED
415 | STATE
416 | STATEMENT
417 | STATES
418 | STILL
419 | STORY
420 | STREET
421 | STRONG
422 | SUNDAY
423 | SUNSHINE
424 | SUPPORT
425 | SYRIA
426 | SYRIAN
427 | SYSTEM
428 | TAKEN
429 | TAKING
430 | TALKING
431 | TALKS
432 | TEMPERATURES
433 | TERMS
434 | THEIR
435 | THEMSELVES
436 | THERE
437 | THESE
438 | THING
439 | THINGS
440 | THINK
441 | THIRD
442 | THOSE
443 | THOUGHT
444 | THOUSANDS
445 | THREAT
446 | THREE
447 | THROUGH
448 | TIMES
449 | TODAY
450 | TOGETHER
451 | TOMORROW
452 | TONIGHT
453 | TOWARDS
454 | TRADE
455 | TRIAL
456 | TRUST
457 | TRYING
458 | UNDER
459 | UNDERSTAND
460 | UNION
461 | UNITED
462 | UNTIL
463 | USING
464 | VICTIMS
465 | VIOLENCE
466 | VOTERS
467 | WAITING
468 | WALES
469 | WANTED
470 | WANTS
471 | WARNING
472 | WATCHING
473 | WATER
474 | WEAPONS
475 | WEATHER
476 | WEEKEND
477 | WEEKS
478 | WELCOME
479 | WELFARE
480 | WESTERN
481 | WESTMINSTER
482 | WHERE
483 | WHETHER
484 | WHICH
485 | WHILE
486 | WHOLE
487 | WINDS
488 | WITHIN
489 | WITHOUT
490 | WOMEN
491 | WORDS
492 | WORKERS
493 | WORKING
494 | WORLD
495 | WORST
496 | WOULD
497 | WRONG
498 | YEARS
499 | YESTERDAY
500 | YOUNG
501 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import torch
4 | from torch import nn, optim
5 | from torch.utils.tensorboard import SummaryWriter
6 | import numpy as np
7 | from src.models.visual_front import Visual_front
8 | from src.models.generator import Decoder, Discriminator, gan_loss, sync_Discriminator, Postnet
9 | import os
10 | from torch.utils.data import DataLoader
11 | from torch.nn import functional as F
12 | from src.data.vid_aud_grid import MultiDataset
13 | from torch.nn import DataParallel as DP
14 | import torch.nn.parallel
15 | import time
16 | import glob
17 | from torch.autograd import grad
18 | import soundfile as sf
19 | from pesq import pesq
20 | from pystoi import stoi
21 | from matplotlib import pyplot as plt
22 | import copy, librosa
23 |
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('--grid', default="Data_dir")
28 | parser.add_argument("--checkpoint_dir", type=str, default='./data/checkpoints/GRID')
29 | parser.add_argument("--checkpoint", type=str, default=None)
30 | parser.add_argument("--batch_size", type=int, default=100)
31 | parser.add_argument("--epochs", type=int, default=1000)
32 | parser.add_argument("--lr", type=float, default=0.0001)
33 | parser.add_argument("--weight_decay", type=float, default=0.00001)
34 | parser.add_argument("--workers", type=int, default=10)
35 | parser.add_argument("--seed", type=int, default=1)
36 |
37 | parser.add_argument("--subject", type=str, default='overlap')
38 |
39 | parser.add_argument("--start_epoch", type=int, default=0)
40 | parser.add_argument("--augmentations", default=True)
41 |
42 | parser.add_argument("--window_size", type=int, default=40)
43 | parser.add_argument("--max_timesteps", type=int, default=75)
44 | parser.add_argument("--temp", type=float, default=1.0)
45 |
46 | parser.add_argument("--dataparallel", default=False, action='store_true')
47 | parser.add_argument("--gpu", type=str, default='0,1')
48 |
49 | parser.add_argument("--save_mel", default=False, action='store_true')
50 | parser.add_argument("--save_wav", default=False, action='store_true')
51 |
52 | args = parser.parse_args()
53 | return args
54 |
55 |
56 | def train_net(args):
57 | torch.backends.cudnn.deterministic = False
58 | torch.backends.cudnn.benchmark = True
59 | torch.manual_seed(args.seed)
60 | torch.cuda.manual_seed_all(args.seed)
61 | random.seed(args.seed)
62 | os.environ['OMP_NUM_THREADS'] = '2'
63 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
64 |
65 | v_front = Visual_front(in_channels=1)
66 | gen = Decoder()
67 | post = Postnet()
68 |
69 | print(f"Loading checkpoint: {args.checkpoint}")
70 | checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage.cuda())
71 |
72 | v_front.load_state_dict(checkpoint['v_front_state_dict'])
73 | gen.load_state_dict(checkpoint['gen_state_dict'])
74 | post.load_state_dict(checkpoint['post_state_dict'])
75 | del checkpoint
76 |
77 | v_front.cuda()
78 | gen.cuda()
79 | post.cuda()
80 |
81 | if args.dataparallel:
82 | v_front = DP(v_front)
83 | gen = DP(gen)
84 | post = DP(post)
85 |
86 | _ = test(v_front, gen, post)
87 |
88 | def test(v_front, gen, post, fast_validate=False):
89 | with torch.no_grad():
90 | v_front.eval()
91 | gen.eval()
92 | post.eval()
93 |
94 | val_data = MultiDataset(
95 | grid=args.grid,
96 | mode='test',
97 | subject=args.subject,
98 | window_size=args.window_size,
99 | max_v_timesteps=args.max_timesteps,
100 | augmentations=False,
101 | fast_validate=fast_validate
102 | )
103 |
104 | dataloader = DataLoader(
105 | val_data,
106 | shuffle=False,
107 | batch_size=args.batch_size * 2,
108 | num_workers=args.workers,
109 | drop_last=False
110 | )
111 |
112 | stft = copy.deepcopy(val_data.stft).cuda()
113 | stoi_spec_list = []
114 | estoi_spec_list = []
115 | pesq_spec_list = []
116 | batch_size = dataloader.batch_size
117 | if fast_validate:
118 | samples = min(2 * batch_size, int(len(dataloader.dataset)))
119 | max_batches = 2
120 | else:
121 | samples = int(len(dataloader.dataset))
122 | max_batches = int(len(dataloader))
123 |
124 | description = 'Check validation step' if fast_validate else 'Validation'
125 | print(description)
126 | for i, batch in enumerate(dataloader):
127 | if i % 10 == 0:
128 | if not fast_validate:
129 | print("******** Validation : %d / %d ********" % ((i + 1) * batch_size, samples))
130 | mel, spec, vid, vid_len, wav_tr, mel_len, f_name = batch
131 | vid = vid.cuda()
132 | phon, sent = v_front(vid) # S,B,512
133 | g1, g2, g3 = gen(sent, phon, vid_len)
134 | g3_temp = g3.clone()
135 |
136 | vid = vid.flip(4)
137 | phon, sent = v_front(vid) # S,B,512
138 | g1, g2, g3 = gen(sent, phon, vid_len)
139 |
140 | g3 = (g3_temp + g3) / 2.
141 | gs = post(g3)
142 |
143 | wav_spec = val_data.inverse_spec(gs[:, :, :, :mel_len[0]].detach(), stft)
144 |
145 | for b in range(g3.size(0)):
146 | stoi_spec_list.append(stoi(wav_tr[b][:len(wav_spec[b])], wav_spec[b], 16000, extended=False))
147 | estoi_spec_list.append(stoi(wav_tr[b][:len(wav_spec[b])], wav_spec[b], 16000, extended=True))
148 | pesq_spec_list.append(pesq(8000, librosa.resample(wav_tr[b][:len(wav_spec[b])].numpy(), 16000, 8000), librosa.resample(wav_spec[b], 16000, 8000), 'nb'))
149 |
150 | sub_name, _, file_name = f_name[b].split('/')
151 | if not os.path.exists(f'./test/spec_mel/{sub_name}'):
152 | os.makedirs(f'./test/spec_mel/{sub_name}')
153 | np.savez(f'./test/spec_mel/{sub_name}/{file_name}.npz',
154 | mel=g3[b, :, :, :mel_len[b]].detach().cpu().numpy(),
155 | spec=gs[b, :, :, :mel_len[b]].detach().cpu().numpy())
156 |
157 | if not os.path.exists(f'./test/wav/{sub_name}'):
158 | os.makedirs(f'./test/wav/{sub_name}')
159 | sf.write(f'./test/wav/{sub_name}/{file_name}.wav', wav_spec[b], 16000, subtype='PCM_16')
160 |
161 | if i >= max_batches:
162 | break
163 |
164 | print('STOI: ', np.mean(stoi_spec_list))
165 | print('ESTOI: ', np.mean(estoi_spec_list))
166 | print('PESQ: ', np.mean(pesq_spec_list))
167 | with open(f'./test/metric.txt', 'w') as f:
168 |             f.write(f'STOI : {np.mean(stoi_spec_list)}\n')
169 |             f.write(f'ESTOI : {np.mean(estoi_spec_list)}\n')
170 |             f.write(f'PESQ : {np.mean(pesq_spec_list)}\n')
171 |
172 | if __name__ == "__main__":
173 | args = parse_args()
174 | train_net(args)
175 |
176 |
--------------------------------------------------------------------------------
/ASR_model/GRID/src/data/vid_aud_GRID_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torchaudio
8 | import torchvision
9 | from torchvision import transforms
10 | from torch.utils.data import DataLoader, Dataset
11 | import random
12 | from librosa.filters import mel as librosa_mel_fn
13 | from src.data.audio_processing import dynamic_range_compression, dynamic_range_decompression
14 | from src.data.stft import STFT
15 | import math, glob, librosa  # librosa is used for resampling in __getitem__
16 | from scipy import signal
17 |
18 | log1e5 = math.log(1e-5)
19 |
20 | letters = ['_', ' ', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T',
21 | 'U', 'V', 'W', 'X', 'Y', 'Z']
22 |
23 |
24 | class MultiDataset(Dataset):
25 | def __init__(self, grid, mode, gtpath, subject, max_v_timesteps=155, max_text_len=150, augmentations=False, num_mel_bins=80, wav=False):
26 | self.wav = wav
27 | self.gtpath = gtpath
28 | self.max_v_timesteps = max_v_timesteps
29 | self.max_text_len = max_text_len
30 | self.augmentations = augmentations if mode == 'train' else False
31 | self.file_paths = self.build_file_list(grid, subject)
32 | self.int2char = dict(enumerate(letters))
33 | self.char2int = {char: index for index, char in self.int2char.items()}
34 | self.stft = TacotronSTFT(filter_length=640, hop_length=160, win_length=640, n_mel_channels=num_mel_bins, sampling_rate=16000, mel_fmin=55., mel_fmax=7500.)
35 |
36 | def build_file_list(self, grid, subject):
37 | check_list = []
38 | if subject == 'overlap':
39 | with open('./../../data/overlap_val.txt', 'r') as f:
40 | lines = f.readlines()
41 | for l in lines:
42 | file = l.strip().replace('mpg_6000/', '') + '.mp4'
43 | check_list.append(os.path.join(grid, file))
44 | elif subject == 'unseen':
45 | with open('./../../data/unseen_splits.txt', 'r') as f:
46 | lines = f.readlines()
47 | for l in lines:
48 | if 'test' in l.strip():
49 | _, sub, fname = l.strip().split('/')
50 | file = f'{sub}/video/{fname}.mp4'
51 | if os.path.exists(os.path.join(grid, file)):
52 | check_list.append(os.path.join(grid, file))
53 | else:
54 | with open('./../../data/test_4.txt', 'r') as f:
55 | lines = f.readlines()
56 | for l in lines:
57 | file = l.strip()
58 | if subject == 'four':
59 | check_list.append(os.path.join(grid, file))
60 | elif file.split('/')[0] == subject:
61 | check_list.append(os.path.join(grid, file))
62 |
63 | if self.wav:
64 | file_list = sorted(glob.glob(os.path.join(grid, '*', '*.wav')))
65 | else:
66 | file_list = sorted(glob.glob(os.path.join(grid, '*', '*.npz')))
67 |
68 | assert len(check_list) == len(file_list), 'The data for testing is not full'
69 | return file_list
70 |
71 | def __len__(self):
72 | return len(self.file_paths)
73 |
74 | def build_content(self, content):
75 | words = []
76 | with open(content, 'r') as f:
77 | lines = f.readlines()
78 | for l in lines:
79 | word = l.strip().split()[2]
80 | if not word in ['SIL', 'SP', 'sil', 'sp']:
81 | words.append(word)
82 | return words
83 |
84 | def __getitem__(self, idx):
85 | file_path = self.file_paths[idx]
86 | t, f_name = os.path.split(file_path)
87 | _, sub = os.path.split(t)
88 |
89 | words = self.build_content(os.path.join(self.gtpath, sub.split('_')[0], 'align', f_name.split('.')[0] + '.align'))
90 | content = ' '.join(words).upper()
91 |
92 | if self.wav:
93 | aud, sr = torchaudio.load(file_path)
94 |
95 | if round(sr) != 16000:
96 | aud = torch.tensor(librosa.resample(aud.squeeze(0).numpy(), sr, 16000)).unsqueeze(0)
97 |
98 | aud = aud / torch.abs(aud).max() * 0.9
99 | aud = torch.FloatTensor(self.preemphasize(aud.squeeze(0))).unsqueeze(0)
100 | aud = torch.clamp(aud, min=-1, max=1)
101 |
102 | spec = self.stft.mel_spectrogram(aud)
103 | num_a_frames = spec.size(2)
104 |
105 | else:
106 | data = np.load(file_path)
107 | mel = data['mel']
108 | data.close()
109 |
110 | spec = torch.FloatTensor(self.denormalize(mel))
111 |
112 | num_a_frames = spec.size(2)
113 |
114 | target, txt_len = self.encode(content)
115 |
116 | spec = nn.ConstantPad2d((0, self.max_v_timesteps * 4 - num_a_frames, 0, 0), 0.0)(spec)
117 |
118 | return spec, target, num_a_frames, txt_len
119 |
120 | def encode(self, content):
121 | encoded = [self.char2int[c] for c in content]
122 | if len(encoded) > self.max_text_len:
123 | print(f"Max output length too short. Required {len(encoded)}")
124 | encoded = encoded[:self.max_text_len]
125 | num_txt = len(encoded)
126 | encoded += [self.char2int['_'] for _ in range(self.max_text_len - len(encoded))]
127 | return torch.Tensor(encoded), num_txt
128 |
129 | def preemphasize(self, aud):
130 | aud = signal.lfilter([1, -0.97], [1], aud)
131 | return aud
132 |
133 | def denormalize(self, melspec):
134 | melspec = ((melspec + 1) * (-log1e5 / 2)) + log1e5
135 | return melspec
136 |
137 | class TacotronSTFT(torch.nn.Module):
138 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
139 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
140 | mel_fmax=8000.0):
141 | super(TacotronSTFT, self).__init__()
142 | self.n_mel_channels = n_mel_channels
143 | self.sampling_rate = sampling_rate
144 | self.stft_fn = STFT(filter_length, hop_length, win_length)
145 | mel_basis = librosa_mel_fn(
146 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
147 | mel_basis = torch.from_numpy(mel_basis).float()
148 | self.register_buffer('mel_basis', mel_basis)
149 |
150 | def spectral_normalize(self, magnitudes):
151 | output = dynamic_range_compression(magnitudes)
152 | return output
153 |
154 | def spectral_de_normalize(self, magnitudes):
155 | output = dynamic_range_decompression(magnitudes)
156 | return output
157 |
158 | def mel_spectrogram(self, y):
159 | """Computes mel-spectrograms from a batch of waves
160 | PARAMS
161 | ------
162 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
163 | RETURNS
164 | -------
165 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
166 | """
167 | assert(torch.min(y.data) >= -1)
168 | assert(torch.max(y.data) <= 1)
169 |
170 | magnitudes, phases = self.stft_fn.transform(y)
171 | magnitudes = magnitudes.data
172 | mel_output = torch.matmul(self.mel_basis, magnitudes)
173 | mel_output = self.spectral_normalize(mel_output)
174 | return mel_output
175 |
--------------------------------------------------------------------------------
/ASR_model/GRID/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import torch
4 | from torch import nn, optim
5 | from torch.utils.tensorboard import SummaryWriter
6 | import numpy as np
7 | from src.models.audio_front import Audio_front
8 | from src.models.classifier import Backend
9 | import os
10 | from torch.utils.data import DataLoader
11 | from torch.nn import functional as F
12 | from src.data.vid_aud_GRID_test import MultiDataset
13 | from torch.nn import DataParallel as DP
14 | import torch.nn.parallel
15 | import math
16 | import editdistance
17 | import re
18 | from matplotlib import pyplot as plt
19 | import time
20 | import glob
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser()
24 |     parser.add_argument('--data', default="TEST_DIR", help='directory of generated test data, e.g., ./../../test/spec_mel')
25 |     parser.add_argument('--wav', default=False, action='store_true', help='set if inputs are waveforms; otherwise mel-spectrogram (.npz) files are expected')
26 | parser.add_argument('--gtpath', default="GT_path", help='GT transcription path')
27 | parser.add_argument('--model', default="GRID_CTC")
28 | parser.add_argument("--checkpoint_dir", type=str, default='./data')
29 | parser.add_argument("--checkpoint", type=str, default='./data/GRID_4_wer_0.00833_cer_0.00252.ckpt')
30 | parser.add_argument("--batch_size", type=int, default=160)
31 | parser.add_argument("--epochs", type=int, default=150)
32 | parser.add_argument("--lr", type=float, default=0.001)
33 | parser.add_argument("--weight_decay", type=float, default=0.00001)
34 | parser.add_argument("--workers", type=int, default=4)
35 | parser.add_argument("--resnet", type=int, default=18)
36 | parser.add_argument("--seed", type=int, default=1)
37 |
38 |     parser.add_argument("--subject", default='overlap', help="one of: overlap, unseen, s1, s2, s4, s29, four")
39 |
40 | parser.add_argument("--max_timesteps", type=int, default=75)
41 | parser.add_argument("--max_text_len", type=int, default=75)
42 |
43 | parser.add_argument("--dataparallel", default=False, action='store_true')
44 | parser.add_argument("--gpu", type=str, default='0')
45 | args = parser.parse_args()
46 | return args
47 |
48 |
49 | def train_net(args):
50 | torch.backends.cudnn.deterministic = True
51 | torch.backends.cudnn.benchmark = False
52 | torch.manual_seed(args.seed)
53 | torch.cuda.manual_seed_all(args.seed)
54 | random.seed(args.seed)
55 | os.environ['OMP_NUM_THREADS'] = '2'
56 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
57 |
58 | a_front = Audio_front()
59 | a_back = Backend()
60 |
61 | if args.subject == 'unseen':
62 | args.checkpoint = './data/GRID_unseen_wer_0.01676_cer_0.00896.ckpt'
63 | elif args.subject == 'overlap':
64 | args.checkpoint = './data/GRID_33_wer_0.00368_cer_0.00120.ckpt'
65 | else:
66 | args.checkpoint = './data/GRID_4_wer_0.00833_cer_0.00252.ckpt'
67 |
68 | if args.checkpoint is not None:
69 | print(f"Loading checkpoint: {args.checkpoint}")
70 | checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage.cuda())
71 | a_front.load_state_dict(checkpoint['a_front_state_dict'])
72 | a_back.load_state_dict(checkpoint['a_back_state_dict'])
73 | del checkpoint
74 |
75 | a_front.cuda()
76 | a_back.cuda()
77 |
78 | if args.dataparallel:
79 | a_front = DP(a_front)
80 | a_back = DP(a_back)
81 |
82 | wer, cer = validate(a_front, a_back)
83 |
84 | def validate(a_front, a_back, fast_validate=False):
85 | with torch.no_grad():
86 | a_front.eval()
87 | a_back.eval()
88 |
89 | val_data = MultiDataset(
90 | grid=args.data,
91 | mode='test',
92 | gtpath=args.gtpath,
93 | subject=args.subject,
94 | max_v_timesteps=args.max_timesteps,
95 | max_text_len=args.max_text_len,
96 | wav=args.wav,
97 | augmentations=False
98 | )
99 |
100 | dataloader = DataLoader(
101 | val_data,
102 | shuffle=False,
103 | batch_size=args.batch_size * 2,
104 | num_workers=args.workers,
105 | drop_last=False
106 | )
107 |
108 | batch_size = dataloader.batch_size
109 | if fast_validate:
110 | samples = min(2 * batch_size, int(len(dataloader.dataset)))
111 | max_batches = 2
112 | else:
113 | samples = int(len(dataloader.dataset))
114 | max_batches = int(len(dataloader))
115 |
116 | wer_sum, cer_sum, tot_num = 0, 0, 0
117 |
118 | description = 'Check validation step' if fast_validate else 'Validation'
119 | print(description)
120 | for i, batch in enumerate(dataloader):
121 | if i % 50 == 0:
122 | if not fast_validate:
123 | print("******** Validation : %d / %d ********" % ((i + 1) * batch_size, samples))
124 | a_in, target, aud_len, txt_len = batch
125 |
126 | a_feat = a_front(a_in.cuda()) # B,S,512
127 | pred = a_back(a_feat)
128 |
129 | cer, wer, sentences = greedy_decode(val_data, F.softmax(pred, dim=2).cpu(), target)
130 |
131 | B, S, _ = a_feat.size()
132 |
133 | tot_num += B
134 | wer_sum += B * wer
135 | cer_sum += B * cer
136 | batch_size = B
137 |
138 | if i % 50 == 0:
139 | for j in range(2):
140 | print('label: ', sentences[j][0])
141 | print('prediction: ', sentences[j][1])
142 |
143 | if i >= max_batches:
144 | break
145 |
146 | a_front.train()
147 | a_back.train()
148 |
149 | wer = wer_sum / tot_num
150 | cer = cer_sum / tot_num
151 |
152 | print('test_cer:', cer)
153 | print('test_wer:', wer)
154 |
155 | if fast_validate:
156 | return {}
157 | else:
158 | return wer, cer
159 |
160 | def decode(dataset, label_tokens, pred_tokens):
161 | label, output = '', ''
162 | for index in range(len(label_tokens)):
163 | label += dataset.int2char[int(label_tokens[index])]
164 | for index in range(len(pred_tokens)):
165 | output += dataset.int2char[int(pred_tokens[index])]
166 |
167 | output = re.sub(' +', ' ', output)
168 |     pattern = re.compile(r"(.)\1{1,}", re.DOTALL)  # collapse runs of repeated characters (greedy CTC-style cleanup)
169 | output = pattern.sub(r"\1", output)
170 |
171 | label = label.replace('_', '')
172 | output = output.replace('_', '')
173 |
174 | output_words, label_words = output.split(" "), label.split(" ")
175 |
176 | cer = editdistance.eval(output, label) / len(label)
177 | wer = editdistance.eval(output_words, label_words) / len(label_words)
178 |
179 | return label, output, cer, wer
180 |
181 | def greedy_decode(dataset, results, target):
182 | _, results = results.topk(1, dim=2)
183 | results = results.squeeze(dim=2)
184 | cer_sum, wer_sum = 0, 0
185 | batch_size = results.size(0)
186 | sentences = []
187 | for batch in range(batch_size):
188 | label, output, cer, wer = decode(dataset, target[batch], results[batch])
189 | sentences.append([label, output])
190 | cer_sum += cer
191 | wer_sum += wer
192 |
193 | return cer_sum / batch_size, wer_sum / batch_size, sentences
194 |
195 | if __name__ == "__main__":
196 | args = parse_args()
197 | train_net(args)
198 |
199 |
--------------------------------------------------------------------------------
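`decode` above post-processes the greedy CTC output: consecutive repeated characters are collapsed, blanks (`'_'`) are stripped, and CER/WER are computed with `editdistance`. Below is a standalone sketch of that cleanup with hypothetical strings (the real label and prediction come from the dataset's int2char table and the model output); note that, as with any greedy CTC cleanup, genuine double letters survive only when a blank separates the repeats.

# Greedy CTC-style cleanup and scoring, mirroring decode(); strings are illustrative only.
import re
import editdistance

def clean(output):
    output = re.sub(' +', ' ', output)                               # merge repeated spaces
    output = re.compile(r"(.)\1{1,}", re.DOTALL).sub(r"\1", output)  # collapse repeated characters
    return output.replace('_', '')                                   # drop CTC blanks

label = 'PLACE BLUE AT F TWO NOW'
prediction = clean('PPLAACE__ BLUE_ AT  F TWO_ NOWW')                # -> 'PLACE BLUE AT F TWO NOW'

cer = editdistance.eval(prediction, label) / len(label)
wer = editdistance.eval(prediction.split(' '), label.split(' ')) / len(label.split(' '))
print(prediction, cer, wer)                                          # ... 0.0 0.0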
/test_LRS.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import torch
4 | from torch import nn, optim
5 | from torch.utils.tensorboard import SummaryWriter
6 | import numpy as np
7 | from src.models.visual_front import Visual_front
8 | from src.models.generator import Decoder, Discriminator, gan_loss, sync_Discriminator, Postnet
9 | import os
10 | from torch.utils.data import DataLoader
11 | from torch.nn import functional as F
12 | from src.data.vid_aud_lrs2 import MultiDataset as LRS2_dataset
13 | from src.data.vid_aud_lrs3 import MultiDataset as LRS3_dataset
14 | from torch.nn import DataParallel as DP
15 | import torch.nn.parallel
16 | import time
17 | import glob
18 | from torch.autograd import grad
19 | import soundfile as sf
20 | from pesq import pesq
21 | from pystoi import stoi
22 | from matplotlib import pyplot as plt
23 | import copy, librosa
24 |
25 |
26 | def parse_args():
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument('--data', default="Data_dir")
29 | parser.add_argument('--data_name', type=str, default="LRS2")
30 | parser.add_argument("--checkpoint_dir", type=str, default='./data/checkpoints/')
31 | parser.add_argument("--checkpoint", type=str, default='checkpoint')
32 | parser.add_argument("--batch_size", type=int, default=16)
33 | parser.add_argument("--epochs", type=int, default=1000)
34 | parser.add_argument("--lr", type=float, default=0.0001)
35 | parser.add_argument("--weight_decay", type=float, default=0.00001)
36 | parser.add_argument("--workers", type=int, default=10)
37 | parser.add_argument("--seed", type=int, default=1)
38 |
39 | parser.add_argument("--subject", type=str, default='overlap')
40 |
41 |     parser.add_argument("--f_min", type=float, default=55.)
42 |     parser.add_argument("--f_max", type=float, default=7600.)
43 |
44 | parser.add_argument("--start_epoch", type=int, default=0)
45 | parser.add_argument("--augmentations", default=True)
46 |
47 | parser.add_argument("--window_size", type=int, default=50)
48 | parser.add_argument("--max_timesteps", type=int, default=160)
49 |
50 | parser.add_argument("--dataparallel", default=False, action='store_true')
51 | parser.add_argument("--gpu", type=str, default='0')
52 |
53 | parser.add_argument("--save_mel", default=True, action='store_true')
54 | parser.add_argument("--save_wav", default=True, action='store_true')
55 |
56 | args = parser.parse_args()
57 | return args
58 |
59 |
60 | def train_net(args):
61 | torch.backends.cudnn.deterministic = False
62 | torch.backends.cudnn.benchmark = True
63 | torch.manual_seed(args.seed)
64 | torch.cuda.manual_seed_all(args.seed)
65 | random.seed(args.seed)
66 | os.environ['OMP_NUM_THREADS'] = '2'
67 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
68 |
69 | v_front = Visual_front(in_channels=1)
70 | gen = Decoder()
71 | post = Postnet()
72 |
73 | print(f"Loading checkpoint: {args.checkpoint}")
74 | checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage.cuda())
75 |
76 | v_front.load_state_dict(checkpoint['v_front_state_dict'])
77 | gen.load_state_dict(checkpoint['gen_state_dict'])
78 | post.load_state_dict(checkpoint['post_state_dict'])
79 | del checkpoint
80 |
81 | v_front.cuda()
82 | gen.cuda()
83 | post.cuda()
84 |
85 | if args.dataparallel:
86 | v_front = DP(v_front)
87 | gen = DP(gen)
88 | post = DP(post)
89 |
90 | _ = test(v_front, gen, post)
91 |
92 | def test(v_front, gen, post, fast_validate=False):
93 | with torch.no_grad():
94 | v_front.eval()
95 | gen.eval()
96 | post.eval()
97 |
98 | if args.data_name == 'LRS2':
99 | val_data = LRS2_dataset(
100 | data=args.data,
101 | mode='test',
102 | max_v_timesteps=args.max_timesteps,
103 | window_size=args.window_size,
104 | augmentations=False,
105 | f_min=args.f_min,
106 | f_max=args.f_max
107 | )
108 | elif args.data_name == 'LRS3':
109 | val_data = LRS3_dataset(
110 | data=args.data,
111 | mode='test',
112 | max_v_timesteps=args.max_timesteps,
113 | window_size=args.window_size,
114 | augmentations=False,
115 | f_min=args.f_min,
116 | f_max=args.f_max
117 | )
118 |
119 | dataloader = DataLoader(
120 | val_data,
121 | shuffle=True if fast_validate else False,
122 | batch_size=args.batch_size,
123 | num_workers=args.workers,
124 | drop_last=False,
125 | collate_fn=lambda x: val_data.collate_fn(x)
126 | )
127 |
128 | stft = copy.deepcopy(val_data.stft).cuda()
129 | stoi_spec_list = []
130 | estoi_spec_list = []
131 | pesq_spec_list = []
132 | batch_size = dataloader.batch_size
133 | if fast_validate:
134 | samples = min(2 * batch_size, int(len(dataloader.dataset)))
135 | max_batches = 2
136 | else:
137 | samples = int(len(dataloader.dataset))
138 | max_batches = int(len(dataloader))
139 |
140 | description = 'Check validation step' if fast_validate else 'Validation'
141 | print(description)
142 | for i, batch in enumerate(dataloader):
143 | if i % 10 == 0:
144 | if not fast_validate:
145 | print("******** Validation : %d / %d ********" % ((i + 1) * batch_size, samples))
146 | mel, spec, vid, vid_len, wav_tr, mel_len, f_name = batch
147 |
148 | vid = vid.cuda()
149 | phon, sent = v_front(vid) # S,B,512
150 | g1, g2, g3 = gen(sent, phon, vid_len)
151 | g3_temp = g3.clone()
152 |
153 | vid = vid.flip(4)
154 | phon, sent = v_front(vid) # S,B,512
155 | g1, g2, g3 = gen(sent, phon, vid_len)
156 |
157 | g3 = (g3_temp + g3) / 2.
158 | gs = post(g3)
159 |
160 | for b in range(g3.size(0)):
161 | wav_spec = val_data.inverse_spec(gs[b, :, :, :mel_len[b]].detach(), stft)[0]
162 | min_length = min(len(wav_spec), len(wav_tr[b]))
163 | stoi_spec_list.append(stoi(wav_tr[b][:min_length], wav_spec[:min_length], 16000, extended=False))
164 | estoi_spec_list.append(stoi(wav_tr[b][:min_length], wav_spec[:min_length], 16000, extended=True))
165 | pesq_spec_list.append(pesq(8000, librosa.resample(wav_tr[b][:min_length].numpy(), 16000, 8000), librosa.resample(wav_spec, 16000, 8000), 'nb'))
166 |
167 | m_name, v_name, file_name = f_name[b].split('/')
168 | if not os.path.exists(f'./test/{args.data_name}/mel/{m_name}/{v_name}'):
169 | os.makedirs(f'./test/{args.data_name}/mel/{m_name}/{v_name}')
170 | np.savez(f'./test/{args.data_name}/mel/{m_name}/{v_name}/{file_name}.npz',
171 | mel=g3[b, :, :, :mel_len[b]].detach().cpu().numpy(),
172 | spec=gs[b, :, :, :mel_len[b]].detach().cpu().numpy())
173 |
174 | if not os.path.exists(f'./test/{args.data_name}/wav/{m_name}/{v_name}'):
175 | os.makedirs(f'./test/{args.data_name}/wav/{m_name}/{v_name}')
176 | sf.write(f'./test/{args.data_name}/wav/{m_name}/{v_name}/{file_name}.wav', wav_spec, 16000, subtype='PCM_16')
177 |
178 |
179 | if i >= max_batches:
180 | break
181 |
182 | print('STOI: ', np.mean(stoi_spec_list))
183 | print('ESTOI: ', np.mean(estoi_spec_list))
184 | print('PESQ: ', np.mean(pesq_spec_list))
185 | with open(f'./test/{args.data_name}/metric.txt', 'w') as f:
186 |             f.write(f'STOI : {np.mean(stoi_spec_list)}\n')
187 |             f.write(f'ESTOI : {np.mean(estoi_spec_list)}\n')
188 |             f.write(f'PESQ : {np.mean(pesq_spec_list)}\n')
189 |
190 | if __name__ == "__main__":
191 | args = parse_args()
192 | train_net(args)
193 |
194 |
--------------------------------------------------------------------------------
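The test loop above scores each generated waveform against the ground truth with STOI/ESTOI (pystoi) at 16 kHz and narrow-band PESQ (pesq) after resampling both signals to 8 kHz, and writes the generated mel/linear spectrograms and waveforms under `./test/<data_name>/`. A minimal sketch for re-scoring one saved sample offline follows; both file paths are hypothetical placeholders, not files the script is guaranteed to produce verbatim.

# Re-scoring one sample offline with the same metric calls as test().
# Both paths are placeholders: a 16 kHz ground-truth clip and the matching generated clip
# written by this script under ./test/<data_name>/wav/.
import librosa
import soundfile as sf
from pystoi import stoi
from pesq import pesq

ref, sr = sf.read('ground_truth_00001.wav')                  # placeholder reference path
gen, _ = sf.read('./test/LRS2/wav/spk/vid/00001.wav')        # placeholder generated path

n = min(len(ref), len(gen))
print('STOI :', stoi(ref[:n], gen[:n], sr, extended=False))
print('ESTOI:', stoi(ref[:n], gen[:n], sr, extended=True))
print('PESQ :', pesq(8000,
                     librosa.resample(ref[:n], orig_sr=sr, target_sr=8000),
                     librosa.resample(gen[:n], orig_sr=sr, target_sr=8000), 'nb'))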
/src/data/vid_aud_grid.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torchaudio
8 | from torchvision import transforms
9 | from torch.utils.data import DataLoader, Dataset
10 | from src.data.transforms import Crop, StatefulRandomHorizontalFlip
11 | from PIL import Image
12 | import librosa
13 | from matplotlib import pyplot as plt
14 | import glob
15 | from scipy import signal
16 | import torchvision
17 | from torch.autograd import Variable
18 | from librosa.filters import mel as librosa_mel_fn
19 | from src.data.audio_processing import dynamic_range_compression, dynamic_range_decompression, griffin_lim
20 | from src.data.stft import STFT
21 | import math
22 | log1e5 = math.log(1e-5)
23 |
24 | class MultiDataset(Dataset):
25 | def __init__(self, grid, mode, max_v_timesteps=155, window_size=40, subject=None, augmentations=False, num_mel_bins=80, fast_validate=False):
26 | assert mode in ['train', 'test', 'val']
27 | self.grid = grid
28 | self.mode = mode
29 | self.sample_window = True if mode == 'train' else False
30 | self.fast_validate = fast_validate
31 | self.max_v_timesteps = window_size if self.sample_window else max_v_timesteps
32 | self.window_size = window_size
33 | self.augmentations = augmentations if mode == 'train' else False
34 | self.num_mel_bins = num_mel_bins
35 | self.file_paths = self.build_file_list(grid, mode, subject)
36 | self.f_min = 55.
37 | self.f_max = 7500.
38 | self.stft = TacotronSTFT(filter_length=640, hop_length=160, win_length=640, n_mel_channels=80, sampling_rate=16000, mel_fmin=self.f_min, mel_fmax=self.f_max)
39 |
40 | def build_file_list(self, grid, mode, subject):
41 | file_list = []
42 | if subject == 'overlap':
43 | if mode == 'train':
44 | with open('./data/overlap_train.txt', 'r') as f:
45 | lines = f.readlines()
46 | for l in lines:
47 | file = l.strip().replace('mpg_6000/', '') + '.mp4'
48 | file_list.append(os.path.join(grid, file))
49 | else:
50 | with open('./data/overlap_val.txt', 'r') as f:
51 | lines = f.readlines()
52 | for l in lines:
53 | file = l.strip().replace('mpg_6000/', '') + '.mp4'
54 | file_list.append(os.path.join(grid, file))
55 | elif subject == 'unseen':
56 | with open('./data/unseen_splits.txt', 'r') as f:
57 | lines = f.readlines()
58 | for l in lines:
59 | if mode in l.strip():
60 | _, sub, fname = l.strip().split('/')
61 | file = f'{sub}/video/{fname}.mp4'
62 | if os.path.exists(os.path.join(grid, file)):
63 | file_list.append(os.path.join(grid, file))
64 | else:
65 | if mode == 'train':
66 | with open('./data/train_4.txt', 'r') as f:
67 | lines = f.readlines()
68 | for l in lines:
69 | file = l.strip()
70 | if subject == 'four':
71 | file_list.append(os.path.join(grid, file))
72 | elif file.split('/')[0] == subject:
73 | file_list.append(os.path.join(grid, file))
74 | elif mode == 'val':
75 | with open('./data/val_4.txt', 'r') as f:
76 | lines = f.readlines()
77 | for l in lines:
78 | file = l.strip()
79 | if subject == 'four':
80 | file_list.append(os.path.join(grid, file))
81 | elif file.split('/')[0] == subject:
82 | file_list.append(os.path.join(grid, file))
83 | else:
84 | with open('./data/test_4.txt', 'r') as f:
85 | lines = f.readlines()
86 | for l in lines:
87 | file = l.strip()
88 | if subject == 'four':
89 | file_list.append(os.path.join(grid, file))
90 | elif file.split('/')[0] == subject:
91 | file_list.append(os.path.join(grid, file))
92 | return file_list
93 |
94 | def build_tensor(self, frames):
95 | if self.augmentations:
96 | augmentations1 = transforms.Compose([StatefulRandomHorizontalFlip(0.5)])
97 | else:
98 | augmentations1 = transforms.Compose([])
99 | crop = [59, 95, 195, 231]
100 |
101 | transform = transforms.Compose([
102 | transforms.ToPILImage(),
103 | Crop(crop),
104 | transforms.Resize([112, 112]),
105 | augmentations1,
106 | transforms.Grayscale(num_output_channels=1),
107 | transforms.ToTensor(),
108 | transforms.Normalize(0.4136, 0.1700)
109 | ])
110 |
111 | temporalVolume = torch.zeros(self.max_v_timesteps, 1, 112, 112)
112 | for i, frame in enumerate(frames):
113 | temporalVolume[i] = transform(frame)
114 |
115 | ### Random Erasing ###
116 | if self.augmentations:
117 | x_s, y_s = [random.randint(-10, 66) for _ in range(2)] # starting point
118 | temporalVolume[:, :, np.maximum(0, y_s):np.minimum(112, y_s + 56), np.maximum(0, x_s):np.minimum(112, x_s + 56)] = 0.
119 |
120 | temporalVolume = temporalVolume.transpose(1, 0) # (C, T, H, W)
121 | return temporalVolume
122 |
123 | def __len__(self):
124 | return len(self.file_paths)
125 |
126 | def __getitem__(self, idx):
127 | file_path = self.file_paths[idx]
128 |
129 | vid, _, info = torchvision.io.read_video(file_path, pts_unit='sec')
130 | audio, info['audio_fps'] = librosa.load(file_path.replace('video', 'audio')[:-4] + '.flac', sr=16000)
131 | audio = torch.FloatTensor(audio).unsqueeze(0)
132 |
133 | if not 'video_fps' in info:
134 | info['video_fps'] = 25
135 | info['audio_fps'] = 16000
136 |
137 | if vid.size(0) < 5 or audio.size(1) < 5:
138 | vid = torch.zeros([1, 112, 112, 3])
139 | audio = torch.zeros([1, 16000//25])
140 |
141 | ## Audio ##
142 | aud = audio / torch.abs(audio).max() * 0.9
143 | aud = torch.FloatTensor(self.preemphasize(aud.squeeze(0))).unsqueeze(0)
144 | aud = torch.clamp(aud, min=-1, max=1)
145 |
146 | melspec, spec = self.stft.mel_spectrogram(aud)
147 |
148 | ## Video ##
149 | vid = vid.permute(0, 3, 1, 2) # T C H W
150 |
151 | if self.sample_window:
152 | vid, melspec, spec, audio = self.extract_window(vid, melspec, spec, audio, info)
153 |
154 | num_v_frames = vid.size(0)
155 | vid = self.build_tensor(vid)
156 |
157 | melspec = self.normalize(melspec)
158 |
159 | num_a_frames = melspec.size(2)
160 | melspec = nn.ConstantPad2d((0, self.max_v_timesteps * 4 - num_a_frames, 0, 0), 0.0)(melspec)
161 | spec = nn.ConstantPad2d((0, self.max_v_timesteps * 4 - num_a_frames, 0, 0), 0.0)(spec)
162 |
163 | if not self.sample_window:
164 | audio = audio[:, :self.max_v_timesteps * 4 * 160]
165 | audio = torch.cat([audio, torch.zeros([1, int(self.max_v_timesteps / info['video_fps'] * info['audio_fps'] - aud.size(1))])], 1)
166 |
167 | if self.mode == 'test':
168 | return melspec, spec, vid, num_v_frames, audio.squeeze(0), num_a_frames, file_path.replace(self.grid, '')[1:-4]
169 | else:
170 | return melspec, spec, vid, num_v_frames, audio.squeeze(0), num_a_frames
171 |
172 | def extract_window(self, vid, mel, spec, aud, info):
173 | # vid : T,C,H,W
174 | vid_2_aud = info['audio_fps'] / info['video_fps'] / 160
175 |
176 | st_fr = random.randint(0, vid.size(0) - self.window_size)
177 | vid = vid[st_fr:st_fr + self.window_size]
178 |
179 | st_mel_fr = int(st_fr * vid_2_aud)
180 | mel_window_size = int(self.window_size * vid_2_aud)
181 |
182 | mel = mel[:, :, st_mel_fr:st_mel_fr + mel_window_size]
183 | spec = spec[:, :, st_mel_fr:st_mel_fr + mel_window_size]
184 |
185 | aud = aud[:, st_mel_fr*160:st_mel_fr*160 + mel_window_size*160]
186 | aud = torch.cat([aud, torch.zeros([1, int(self.window_size / info['video_fps'] * info['audio_fps'] - aud.size(1))])], 1)
187 |
188 | return vid, mel, spec, aud
189 |
190 | def inverse_mel(self, mel, stft):
191 | if len(mel.size()) < 4:
192 | mel = mel.unsqueeze(0) # B,1,80,T
193 |
194 | mel = self.denormalize(mel)
195 | mel = stft.spectral_de_normalize(mel)
196 | mel = mel.transpose(2, 3).contiguous() # B,80,T --> B,T,80
197 | spec_from_mel_scaling = 1000
198 | spec_from_mel = torch.matmul(mel, stft.mel_basis)
199 | spec_from_mel = spec_from_mel.transpose(2, 3).squeeze(1) # B,1,F,T
200 | spec_from_mel = spec_from_mel * spec_from_mel_scaling
201 |
202 | wav = griffin_lim(spec_from_mel, stft.stft_fn, 60).squeeze(1) # B,L
203 | wav = wav.cpu().numpy() if wav.is_cuda else wav.numpy()
204 | wavs = []
205 | for w in wav:
206 | w = self.deemphasize(w)
207 | wavs += [w]
208 | wavs = np.array(wavs)
209 | wavs = np.clip(wavs, -1, 1)
210 | return wavs
211 |
212 | def inverse_spec(self, spec, stft):
213 | if len(spec.size()) < 4:
214 | spec = spec.unsqueeze(0) # B,1,321,T
215 |
216 | wav = griffin_lim(spec.squeeze(1), stft.stft_fn, 60).squeeze(1) # B,L
217 | wav = wav.cpu().numpy() if wav.is_cuda else wav.numpy()
218 | wavs = []
219 | for w in wav:
220 | w = self.deemphasize(w)
221 | wavs += [w]
222 | wavs = np.array(wavs)
223 | wavs = np.clip(wavs, -1, 1)
224 | return wavs
225 |
226 | def preemphasize(self, aud):
227 | aud = signal.lfilter([1, -0.97], [1], aud)
228 | return aud
229 |
230 | def deemphasize(self, aud):
231 | aud = signal.lfilter([1], [1, -0.97], aud)
232 | return aud
233 |
234 | def normalize(self, melspec):
235 | melspec = ((melspec - log1e5) / (-log1e5 / 2)) - 1 #0~2 --> -1~1
236 | return melspec
237 |
238 | def denormalize(self, melspec):
239 | melspec = ((melspec + 1) * (-log1e5 / 2)) + log1e5
240 | return melspec
241 |
242 | def audio_preprocessing(self, aud):
243 | fc = self.f_min
244 | w = fc / (16000 / 2)
245 | b, a = signal.butter(7, w, 'high')
246 | aud = aud.squeeze(0).numpy()
247 | aud = signal.filtfilt(b, a, aud)
248 | return torch.tensor(aud.copy()).unsqueeze(0)
249 |
250 | def plot_spectrogram_to_numpy(self, mels):
251 | fig, ax = plt.subplots(figsize=(15, 4))
252 | im = ax.imshow(np.squeeze(mels, 0), aspect="auto", origin="lower",
253 | interpolation='none')
254 | plt.colorbar(im, ax=ax)
255 | plt.xlabel("Frames")
256 | plt.ylabel("Channels")
257 | plt.tight_layout()
258 |
259 | fig.canvas.draw()
260 | data = self.save_figure_to_numpy(fig)
261 | plt.close()
262 | return data
263 |
264 | def save_figure_to_numpy(self, fig):
265 | # save it to a numpy array.
266 |         data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
267 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
268 | return data.transpose(2, 0, 1)
269 |
270 | class TacotronSTFT(torch.nn.Module):
271 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
272 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
273 | mel_fmax=8000.0):
274 | super(TacotronSTFT, self).__init__()
275 | self.n_mel_channels = n_mel_channels
276 | self.sampling_rate = sampling_rate
277 | self.stft_fn = STFT(filter_length, hop_length, win_length)
278 | mel_basis = librosa_mel_fn(
279 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
280 | mel_basis = torch.from_numpy(mel_basis).float()
281 | self.register_buffer('mel_basis', mel_basis)
282 |
283 | def spectral_normalize(self, magnitudes):
284 | output = dynamic_range_compression(magnitudes)
285 | return output
286 |
287 | def spectral_de_normalize(self, magnitudes):
288 | output = dynamic_range_decompression(magnitudes)
289 | return output
290 |
291 | def mel_spectrogram(self, y):
292 | """Computes mel-spectrograms from a batch of waves
293 | PARAMS
294 | ------
295 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
296 | RETURNS
297 | -------
298 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
299 | """
300 | assert(torch.min(y.data) >= -1)
301 | assert(torch.max(y.data) <= 1)
302 |
303 | magnitudes, phases = self.stft_fn.transform(y)
304 | magnitudes = magnitudes.data
305 | mel_output = torch.matmul(self.mel_basis, magnitudes)
306 | mel_output = self.spectral_normalize(mel_output)
307 | return mel_output, magnitudes
308 |
--------------------------------------------------------------------------------
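The GRID dataset builds its features with `TacotronSTFT` at 16 kHz (640-sample window, hop 160, 80 mel bins) and maps the log-compressed mel values into roughly [-1, 1] with the `normalize`/`denormalize` pair anchored at log(1e-5). A short sketch of that round trip on dummy audio, assuming the module is importable from the repository root:

# Mel pipeline sketch at the GRID settings: TacotronSTFT (win 640, hop 160, 16 kHz),
# then the log(1e-5)-anchored normalize/denormalize pair defined above.
# Dummy audio stands in for a real clip; run from the repository root so src/ is importable.
import math
import torch
from src.data.vid_aud_grid import TacotronSTFT

log1e5 = math.log(1e-5)
normalize = lambda m: ((m - log1e5) / (-log1e5 / 2)) - 1     # [log(1e-5), 0] -> [-1, 1]
denormalize = lambda m: ((m + 1) * (-log1e5 / 2)) + log1e5   # [-1, 1] -> [log(1e-5), 0]

stft = TacotronSTFT(filter_length=640, hop_length=160, win_length=640,
                    n_mel_channels=80, sampling_rate=16000, mel_fmin=55., mel_fmax=7500.)
aud = torch.clamp(torch.randn(1, 16000) * 0.1, -1, 1)        # 1 s of dummy audio in [-1, 1]
melspec, spec = stft.mel_spectrogram(aud)                    # (1, 80, T), (1, 321, T)
mel_norm = normalize(melspec)
assert torch.allclose(denormalize(mel_norm), melspec, atol=1e-5)
print(melspec.shape, spec.shape)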
/src/models/generator.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from src.models.resnet import BasicBlock
7 |
8 | class ResBlk1D(nn.Module):
9 | def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
10 | normalize=False, downsample=False):
11 | super().__init__()
12 | self.actv = actv
13 | self.normalize = normalize
14 | self.downsample = downsample
15 | self.learned_sc = dim_in != dim_out
16 | self._build_weights(dim_in, dim_out)
17 |
18 | def _build_weights(self, dim_in, dim_out):
19 | self.conv1 = nn.Conv1d(dim_in, dim_in, 5, 1, 2)
20 | self.conv2 = nn.Conv1d(dim_in, dim_out, 5, 1, 2)
21 | if self.normalize:
22 | self.norm1 = nn.BatchNorm1d(dim_in)
23 | self.norm2 = nn.BatchNorm1d(dim_in)
24 | if self.learned_sc:
25 | self.conv1x1 = nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False)
26 |
27 | def _shortcut(self, x):
28 | if self.learned_sc:
29 | x = self.conv1x1(x)
30 | if self.downsample:
31 | x = F.avg_pool1d(x, 2)
32 | return x
33 |
34 | def _residual(self, x):
35 | if self.normalize:
36 | x = self.norm1(x)
37 | x = self.actv(x)
38 | x = self.conv1(x)
39 | if self.downsample:
40 | x = F.avg_pool1d(x, 2)
41 | if self.normalize:
42 | x = self.norm2(x)
43 | x = self.actv(x)
44 | x = self.conv2(x)
45 | return x
46 |
47 | def forward(self, x):
48 | x = self._shortcut(x) + self._residual(x)
49 | return x / math.sqrt(2) # unit variance
50 |
51 | class ResBlk(nn.Module):
52 | def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
53 | normalize=False, downsample=False):
54 | super().__init__()
55 | self.actv = actv
56 | self.normalize = normalize
57 | self.downsample = downsample
58 | self.learned_sc = dim_in != dim_out
59 | self._build_weights(dim_in, dim_out)
60 |
61 | def _build_weights(self, dim_in, dim_out):
62 | self.conv1 = nn.Conv2d(dim_in, dim_in, 5, 1, 2)
63 | self.conv2 = nn.Conv2d(dim_in, dim_out, 5, 1, 2)
64 | if self.normalize:
65 | self.norm1 = nn.BatchNorm2d(dim_in)
66 | self.norm2 = nn.BatchNorm2d(dim_in)
67 | if self.learned_sc:
68 | self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
69 |
70 | def _shortcut(self, x):
71 | if self.learned_sc:
72 | x = self.conv1x1(x)
73 | if self.downsample:
74 | x = F.avg_pool2d(x, 2)
75 | return x
76 |
77 | def _residual(self, x):
78 | if self.normalize:
79 | x = self.norm1(x)
80 | x = self.actv(x)
81 | x = self.conv1(x)
82 | if self.downsample:
83 | x = F.avg_pool2d(x, 2)
84 | if self.normalize:
85 | x = self.norm2(x)
86 | x = self.actv(x)
87 | x = self.conv2(x)
88 | return x
89 |
90 | def forward(self, x):
91 | x = self._shortcut(x) + self._residual(x)
92 | return x / math.sqrt(2) # unit variance
93 |
94 | class GenResBlk(nn.Module):
95 | def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), upsample=False):
96 | super().__init__()
97 | self.actv = actv
98 | self.upsample = upsample
99 | self.learned_sc = dim_in != dim_out
100 | self._build_weights(dim_in, dim_out)
101 |
102 | def _build_weights(self, dim_in, dim_out):
103 | self.conv1 = nn.Conv2d(dim_in, dim_out, 5, 1, 2)
104 | self.conv2 = nn.Conv2d(dim_out, dim_out, 5, 1, 2)
105 | self.norm1 = nn.BatchNorm2d(dim_in)
106 | self.norm2 = nn.BatchNorm2d(dim_out)
107 | if self.learned_sc:
108 | self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
109 |
110 | def _shortcut(self, x):
111 | if self.upsample:
112 | x = F.interpolate(x, scale_factor=2, mode='nearest')
113 | if self.learned_sc:
114 | x = self.conv1x1(x)
115 | return x
116 |
117 | def _residual(self, x):
118 | x = self.norm1(x)
119 | x = self.actv(x)
120 | if self.upsample:
121 | x = F.interpolate(x, scale_factor=2, mode='nearest')
122 | x = self.conv1(x)
123 | x = self.norm2(x)
124 | x = self.actv(x)
125 | x = self.conv2(x)
126 | return x
127 |
128 | def forward(self, x):
129 | out = self._residual(x)
130 | out = (out + self._shortcut(x)) / math.sqrt(2)
131 | return out
132 |
133 | class Flatten(nn.Module):
134 | def forward(self, input):
135 | return input.view(input.size(0), -1)
136 |
137 | class Avgpool(nn.Module):
138 | def forward(self, input):
139 | #input:B,C,H,W
140 | return input.mean([2, 3])
141 |
142 | class AVAttention(nn.Module):
143 | def __init__(self, out_dim):
144 | super().__init__()
145 |
146 | self.softmax = nn.Softmax(2)
147 | self.k = nn.Linear(512, out_dim)
148 | self.v = nn.Linear(512, out_dim)
149 | self.q = nn.Linear(2560, out_dim)
150 | self.out_dim = out_dim
151 | dim = 20 * 64
152 | self.mel = nn.Linear(out_dim, dim)
153 |
154 | def forward(self, ph, g, len):
155 | #ph: B,S,512
156 | #g: B,C,F,T
157 | B, C, F, T = g.size()
158 | k = self.k(ph).transpose(1, 2).contiguous() # B,256,S
159 | q = self.q(g.view(B, C * F, T).transpose(1, 2).contiguous()) # B,T,256
160 |
161 | att = torch.bmm(q, k) / math.sqrt(self.out_dim) # B,T,S
162 | for i in range(att.size(0)):
163 | att[i, :, len[i]:] = float('-inf')
164 | att = self.softmax(att) # B,T,S
165 |
166 | v = self.v(ph) # B,S,256
167 | value = torch.bmm(att, v) # B,T,256
168 | out = self.mel(value) # B, T, 20*64
169 | out = out.view(B, T, F, -1).permute(0, 3, 2, 1)
170 |
171 | return out #B,C,F,T
172 |
173 | class Postnet(nn.Module):
174 | def __init__(self):
175 | super().__init__()
176 |
177 | self.postnet = nn.Sequential(
178 | nn.Conv1d(80, 128, 7, 1, 3),
179 | nn.BatchNorm1d(128),
180 | nn.LeakyReLU(0.2),
181 | ResBlk1D(128, 256),
182 | ResBlk1D(256, 256),
183 | ResBlk1D(256, 256),
184 | nn.Conv1d(256, 321, 1, 1, 0, bias=False)
185 | )
186 |
187 | def forward(self, x):
188 | # x: B,1,80,T
189 | x = x.squeeze(1) # B, 80, t
190 | x = self.postnet(x) # B, 321, T
191 | x = x.unsqueeze(1) # B, 1, 321, T
192 | return x
193 |
194 | class Decoder(nn.Module):
195 | def __init__(self):
196 | super().__init__()
197 |
198 | self.decode = nn.ModuleList()
199 | self.g1 = nn.ModuleList()
200 | self.g2 = nn.ModuleList()
201 | self.g3 = nn.ModuleList()
202 |
203 | self.att1 = AVAttention(256)
204 | self.attconv1 = nn.Conv2d(128 + 64, 128, 5, 1, 2)
205 | self.att2 = AVAttention(256)
206 | self.attconv2 = nn.Conv2d(64 + 32, 64, 5, 1, 2)
207 |
208 | self.to_mel1 = nn.Sequential(
209 | nn.BatchNorm2d(128),
210 | nn.LeakyReLU(0.2),
211 | nn.Conv2d(128, 1, 1, 1, 0),
212 | nn.Tanh()
213 | )
214 | self.to_mel2 = nn.Sequential(
215 | nn.BatchNorm2d(64),
216 | nn.LeakyReLU(0.2),
217 | nn.Conv2d(64, 1, 1, 1, 0),
218 | nn.Tanh()
219 | )
220 | self.to_mel3 = nn.Sequential(
221 | nn.BatchNorm2d(32),
222 | nn.LeakyReLU(0.2),
223 | nn.Conv2d(32, 1, 1, 1, 0),
224 | nn.Tanh()
225 | )
226 |
227 | # bottleneck blocks
228 | self.decode.append(GenResBlk(512 + 128, 512)) # 20,T
229 | self.decode.append(GenResBlk(512, 256))
230 | self.decode.append(GenResBlk(256, 256))
231 |
232 | # up-sampling blocks
233 | self.g1.append(GenResBlk(256, 128)) # 20,T
234 | self.g1.append(GenResBlk(128, 128))
235 | self.g1.append(GenResBlk(128, 128))
236 |
237 | self.g2.append(GenResBlk(128, 64, upsample=True)) # 40,2T
238 | self.g2.append(GenResBlk(64, 64))
239 | self.g2.append(GenResBlk(64, 64))
240 |
241 | self.g3.append(GenResBlk(64, 32, upsample=True)) # 80,4T
242 | self.g3.append(GenResBlk(32, 32))
243 | self.g3.append(GenResBlk(32, 32))
244 |
245 | def forward(self, s, x, len):
246 | # s: B,512,T x: B,T,512
247 | s = s.transpose(1, 2).contiguous()
248 | n = torch.randn([x.size(0), 128, 20, x.size(1)]).cuda() # B,128,20,T
249 | x = x.transpose(1, 2).contiguous().unsqueeze(2).repeat(1, 1, 20, 1) # B, 512, 20, T
250 | x = torch.cat([x, n], 1)
251 | for block in self.decode:
252 | x = block(x)
253 | for block in self.g1:
254 | x = block(x)
255 | g1 = x.clone()
256 | c1 = self.att1(s, g1, len)
257 | x = self.attconv1(torch.cat([x, c1], 1))
258 | for block in self.g2:
259 | x = block(x)
260 | g2 = x.clone()
261 | c2 = self.att2(s, g2, len)
262 | x = self.attconv2(torch.cat([x, c2], 1))
263 | for block in self.g3:
264 | x = block(x)
265 | return self.to_mel1(g1), self.to_mel2(g2), self.to_mel3(x)
266 |
267 | class Discriminator(nn.Module):
268 | def __init__(self, num_class=1, max_conv_dim=512, phase='1'):
269 | super().__init__()
270 | dim_in = 32
271 | blocks = []
272 | blocks += [nn.Conv2d(1, dim_in, 5, 1, 2)]
273 |
274 | if phase == '1':
275 | repeat_num = 2
276 | elif phase == '2':
277 | repeat_num = 3
278 | else:
279 | repeat_num = 4
280 |
281 | for _ in range(repeat_num): # 80,4T --> 40,2T --> 20,T --> 10,T/2 --> 5,T/4
282 | dim_out = min(dim_in*2, max_conv_dim)
283 | blocks += [ResBlk(dim_in, dim_out, downsample=True)]
284 | dim_in = dim_out
285 |
286 | self.main = nn.Sequential(*blocks)
287 |
288 | uncond = []
289 | uncond += [nn.LeakyReLU(0.2)]
290 | uncond += [nn.Conv2d(dim_out, dim_out, 5, 1, 0)]
291 | uncond += [nn.LeakyReLU(0.2)]
292 | uncond += [Avgpool()]
293 | uncond += [nn.Linear(dim_out, num_class)]
294 | self.uncond = nn.Sequential(*uncond)
295 |
296 | cond = []
297 | cond += [nn.LeakyReLU(0.2)]
298 | cond += [nn.Conv2d(dim_out + 512, dim_out, 5, 1, 2)]
299 | cond += [nn.LeakyReLU(0.2)]
300 | cond += [nn.Conv2d(dim_out, dim_out, 5, 1, 0)]
301 | cond += [nn.LeakyReLU(0.2)]
302 | cond += [Avgpool()]
303 | cond += [nn.Linear(dim_out, num_class)]
304 | self.cond = nn.Sequential(*cond)
305 |
306 | def forward(self, x, c, vid_max_length):
307 | # c: B,C,T
308 | f_len = final_length(vid_max_length)
309 | c = c.mean(2) #B,C
310 | c = c.unsqueeze(2).unsqueeze(2).repeat(1, 1, 5, f_len)
311 | out = self.main(x).clone()
312 | uout = self.uncond(out)
313 | out = torch.cat([out, c], dim=1)
314 | cout = self.cond(out)
315 | uout = uout.view(uout.size(0), -1) # (batch, num_domains)
316 | cout = cout.view(cout.size(0), -1) # (batch, num_domains)
317 | return uout, cout
318 |
319 | class sync_Discriminator(nn.Module):
320 | def __init__(self, temp=1.0):
321 | super().__init__()
322 |
323 | self.frontend = nn.Sequential(
324 | nn.Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
325 | nn.BatchNorm2d(128),
326 | nn.PReLU(128),
327 | nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
328 | nn.BatchNorm2d(256),
329 | nn.PReLU(256)
330 | )
331 |
332 | self.Res_block = nn.Sequential(
333 | BasicBlock(256, 256)
334 | )
335 |
336 | self.Linear = nn.Linear(256 * 20, 512)
337 | self.temp = temp
338 |
339 | def forward(self, v_feat, aud, gen=False):
340 | # v_feat: B, S, 512
341 | a_feat = self.frontend(aud)
342 | a_feat = self.Res_block(a_feat)
343 | b, c, f, t = a_feat.size()
344 | a_feat = a_feat.view(b, c * f, t).transpose(1, 2).contiguous() # B, T/4, 256 * F/4
345 | a_feat = self.Linear(a_feat) # B, S, 512
346 |
347 | if gen:
348 | sim = torch.abs(F.cosine_similarity(v_feat, a_feat, 2)).mean(1) #B, S
349 | loss = 5 * torch.ones_like(sim) - sim
350 | else:
351 | v_feat_norm = F.normalize(v_feat, dim=2) #B,S,512
352 | a_feat_norm = F.normalize(a_feat, dim=2) #B,S,512
353 |
354 | sim = torch.bmm(v_feat_norm, a_feat_norm.transpose(1, 2)) / self.temp #B,v_S,a_S
355 |
356 | nce_va = torch.mean(torch.diagonal(F.log_softmax(sim, dim=2), dim1=-2, dim2=-1), dim=1)
357 | nce_av = torch.mean(torch.diagonal(F.log_softmax(sim, dim=1), dim1=-2, dim2=-1), dim=1)
358 |
359 | loss = -1/2 * (nce_va + nce_av)
360 |
361 | return loss
362 |
363 | def gan_loss(inputs, label=None):
364 |     # non-saturating GAN loss in softplus form (R1 regularization is not computed here)
365 | l = -1 if label else 1
366 | return F.softplus(l*inputs).mean()
367 |
368 | def final_length(vid_length):
369 | half = (vid_length // 2)
370 | quad = (half // 2)
371 | return quad
--------------------------------------------------------------------------------
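`Decoder` takes the sentence-level stream (B, 512, T) and the per-frame stream (B, T, 512) from the visual front-end, concatenates 128 channels of noise, and produces mel outputs at three resolutions (20 x T, 40 x 2T, 80 x 4T) through the two `AVAttention` stages; `Postnet` then lifts the final mel to a 321-bin linear spectrogram. A shape-check sketch with random tensors in place of real visual features follows; it assumes a CUDA device (the noise inside `Decoder.forward` is created with `.cuda()`) and that the repository root is on the import path.

# Shape check for Decoder/Postnet with random stand-ins for visual features.
# Assumes a CUDA device (Decoder.forward allocates its noise on the GPU) and that
# the repository root is on PYTHONPATH so src.models is importable.
import torch
from src.models.generator import Decoder, Postnet

B, T = 2, 50                                     # batch size, number of video frames
sent = torch.randn(B, 512, T).cuda()             # sentence-level features (B, 512, T)
phon = torch.randn(B, T, 512).cuda()             # per-frame features (B, T, 512)
vid_len = torch.full((B,), T, dtype=torch.long)  # valid lengths for the attention mask

gen, post = Decoder().cuda(), Postnet().cuda()
with torch.no_grad():
    g1, g2, g3 = gen(sent, phon, vid_len)        # (B,1,20,T), (B,1,40,2T), (B,1,80,4T)
    spec = post(g3)                              # (B,1,321,4T)
print(g1.shape, g2.shape, g3.shape, spec.shape)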
/src/data/vid_aud_lrs3.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torchaudio
8 | from torchvision import transforms
9 | from torch.utils.data import DataLoader, Dataset
10 | from src.data.transforms import Crop, StatefulRandomHorizontalFlip
11 | from PIL import Image
12 | import librosa
13 | import cv2
14 | from matplotlib import pyplot as plt
15 | import glob
16 | from scipy import signal
17 | import torchvision
18 | from torch.autograd import Variable
19 | from librosa.filters import mel as librosa_mel_fn
20 | from src.data.audio_processing import dynamic_range_compression, dynamic_range_decompression, griffin_lim
21 | from src.data.stft import STFT
22 | import math
23 | log1e5 = math.log(1e-5)
24 |
25 | class MultiDataset(Dataset):
26 | def __init__(self, data, mode, max_v_timesteps=155, window_size=40, augmentations=False, num_mel_bins=80, fast_validate=False, f_min=55., f_max=7600.):
27 | assert mode in ['pretrain', 'train', 'test', 'val']
28 | self.data = data
29 | self.sample_window = True if (mode == 'pretrain' or mode == 'train') else False
30 | self.fast_validate = fast_validate
31 | self.max_v_timesteps = window_size if self.sample_window else max_v_timesteps
32 | self.window_size = window_size
33 | self.augmentations = augmentations if mode == 'train' else False
34 | self.num_mel_bins = num_mel_bins
35 | self.file_paths, self.file_names, self.crops = self.build_file_list(data, mode)
36 | self.stft = TacotronSTFT(filter_length=640, hop_length=160, win_length=640, n_mel_channels=num_mel_bins, sampling_rate=16000, mel_fmin=f_min, mel_fmax=f_max)
37 |
38 | def build_file_list(self, lrs3, mode):
39 | file_list, paths = [], []
40 | crops = {}
41 |
42 |         ## Load LRS3 crop files (lip-center coordinates) ##
43 | file = open(f"./data/LRS3/LRS3_crop/preprocess_pretrain.txt", "r")
44 | content = file.read()
45 | file.close()
46 | for i, line in enumerate(content.splitlines()):
47 | split = line.split(".")
48 | file = split[0]
49 | crop_str = split[1][4:]
50 | crops['pretrain/' + file] = crop_str
51 | file = open(f"./data/LRS3/LRS3_crop/preprocess_test.txt", "r")
52 | content = file.read()
53 | file.close()
54 | for i, line in enumerate(content.splitlines()):
55 | split = line.split(".")
56 | file = split[0]
57 | crop_str = split[1][4:]
58 | crops['test/' + file] = crop_str
59 | file = open(f"./data/LRS3/LRS3_crop/preprocess_trainval.txt", "r")
60 | content = file.read()
61 | file.close()
62 | for i, line in enumerate(content.splitlines()):
63 | split = line.split(".")
64 | file = split[0]
65 | crop_str = split[1][4:]
66 | crops['trainval/' + file] = crop_str
67 |
68 | ## LRS3 file lists##
69 | file = open(f"./data/LRS3/lrs3_unseen_{mode}.txt", "r")
70 | content = file.read()
71 | file.close()
72 | for file in content.splitlines():
73 | if file in crops:
74 | file_list.append(file)
75 | paths.append(f"{lrs3}/{file}")
76 |
77 | print(f'Mode: {mode}, File Num: {len(file_list)}')
78 | return paths, file_list, crops
79 |
80 | def build_tensor(self, frames, crops):
81 | if self.augmentations:
82 | s = random.randint(-5, 5)
83 | else:
84 | s = 0
85 | crop = []
86 | for i in range(0, len(crops), 2):
87 | left = int(crops[i]) - 40 + s
88 | upper = int(crops[i + 1]) - 40 + s
89 | right = int(crops[i]) + 40 + s
90 | bottom = int(crops[i + 1]) + 40 + s
91 | crop.append([left, upper, right, bottom])
92 | crops = crop
93 |
94 | if self.augmentations:
95 | augmentations1 = transforms.Compose([StatefulRandomHorizontalFlip(0.5)])
96 | else:
97 | augmentations1 = transforms.Compose([])
98 |
99 | temporalVolume = torch.zeros(self.max_v_timesteps, 1, 112, 112)
100 | for i, frame in enumerate(frames):
101 | transform = transforms.Compose([
102 | transforms.ToPILImage(),
103 | Crop(crops[i]),
104 | transforms.Resize([112, 112]),
105 | augmentations1,
106 | transforms.Grayscale(num_output_channels=1),
107 | transforms.ToTensor(),
108 | transforms.Normalize(0.4136, 0.1700),
109 | ])
110 | temporalVolume[i] = transform(frame)
111 |
112 | temporalVolume = temporalVolume.transpose(1, 0) # (C, T, H, W)
113 | return temporalVolume
114 |
115 | def __len__(self):
116 | return len(self.file_paths)
117 |
118 | def __getitem__(self, idx):
119 | file = self.file_names[idx]
120 | file_path = self.file_paths[idx]
121 | crops = self.crops[file].split("/")
122 |
123 | info = {}
124 | info['video_fps'] = 25
125 | cap = cv2.VideoCapture(file_path + '.mp4')
126 | frames = []
127 | while (cap.isOpened()):
128 | ret, frame = cap.read()
129 | if ret:
130 | frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
131 | else:
132 | break
133 | cap.release()
134 | audio, info['audio_fps'] = librosa.load(file_path.replace('LRS3-TED', 'LRS3-TED_audio') + '.wav', sr=16000)
135 | vid = torch.tensor(np.stack(frames, 0))
136 | audio = torch.tensor(audio).unsqueeze(0)
137 |
138 | if not 'video_fps' in info:
139 | info['video_fps'] = 25
140 | info['audio_fps'] = 16000
141 |
142 |         assert vid.size(0) > 5 and audio.size(1) > 5, 'video or audio is too short'
143 |
144 | ## Audio ##
145 | audio = audio / torch.abs(audio).max() * 0.9
146 | aud = torch.FloatTensor(self.preemphasize(audio.squeeze(0))).unsqueeze(0)
147 | aud = torch.clamp(aud, min=-1, max=1)
148 |
149 | melspec, spec = self.stft.mel_spectrogram(aud)
150 |
151 | ## Video ##
152 | vid = vid.permute(0, 3, 1, 2) # T C H W
153 |
154 | if self.sample_window:
155 | vid, melspec, spec, audio, crops = self.extract_window(vid, melspec, spec, audio, info, crops)
156 | elif vid.size(0) > self.max_v_timesteps:
157 |             print('Sample is longer than max_v_timesteps; trimming to', self.max_v_timesteps)
158 | vid = vid[:self.max_v_timesteps]
159 | melspec = melspec[:, :, :int(self.max_v_timesteps * info['audio_fps'] / info['video_fps'] / 160)]
160 | spec = spec[:, :, :int(self.max_v_timesteps * info['audio_fps'] / info['video_fps'] / 160)]
161 | audio = audio[:, :int(self.max_v_timesteps * info['audio_fps'] / info['video_fps'])]
162 | crops = crops[:self.max_v_timesteps * 2]
163 |
164 | num_v_frames = vid.size(0)
165 | vid = self.build_tensor(vid, crops)
166 |
167 | melspec = self.normalize(melspec) #0~2 --> -1~1
168 |
169 | spec = self.normalize_spec(spec) # 0 ~ 1
170 | spec = self.stft.spectral_normalize(spec) # log(1e-5) ~ 0 # in log scale
171 | spec = self.normalize(spec) # -1 ~ 1
172 |
173 | num_a_frames = melspec.size(2)
174 | melspec = nn.ConstantPad2d((0, self.max_v_timesteps * 4 - num_a_frames, 0, 0), -1.0)(melspec)
175 | spec = nn.ConstantPad2d((0, self.max_v_timesteps * 4 - num_a_frames, 0, 0), -1.0)(spec)
176 |
177 | return melspec, spec, vid, num_v_frames, audio.squeeze(0), num_a_frames, file_path.replace(self.data, '')[1:]
178 |
179 | def extract_window(self, vid, mel, spec, aud, info, crops):
180 | # vid : T,C,H,W
181 | st_fr = random.randint(0, max(0, vid.size(0) - self.window_size))
182 | vid = vid[st_fr:st_fr + self.window_size]
183 | crops = crops[st_fr * 2: st_fr * 2 + self.window_size * 2]
184 |
185 | assert vid.size(0) * 2 == len(crops), f'vid length: {vid.size(0)}, crop length: {len(crops)}'
186 |
187 | st_mel_fr = int(st_fr * info['audio_fps'] / info['video_fps'] / 160)
188 | mel_window_size = int(self.window_size * info['audio_fps'] / info['video_fps'] / 160)
189 | mel = mel[:, :, st_mel_fr:st_mel_fr + mel_window_size]
190 | spec = spec[:, :, st_mel_fr:st_mel_fr + mel_window_size]
191 | aud = aud[:, st_mel_fr*160:st_mel_fr*160 + mel_window_size*160]
192 | aud = torch.cat([aud, torch.zeros([1, int(self.window_size / info['video_fps'] * info['audio_fps'] - aud.size(1))])], 1)
193 |
194 | return vid, mel, spec, aud, crops
195 |
196 | def collate_fn(self, batch):
197 | vid_lengths, spec_lengths, padded_spec_lengths, aud_lengths = [], [], [], []
198 | for data in batch:
199 | vid_lengths.append(data[3])
200 | spec_lengths.append(data[5])
201 | padded_spec_lengths.append(data[0].size(2))
202 | aud_lengths.append(data[4].size(0))
203 |
204 | max_aud_length = max(aud_lengths)
205 | max_spec_length = max(padded_spec_lengths)
206 | padded_vid = []
207 | padded_melspec = []
208 | padded_spec = []
209 | padded_audio = []
210 | f_names = []
211 |
212 | for i, (melspec, spec, vid, num_v_frames, audio, spec_len, f_name) in enumerate(batch):
213 | padded_vid.append(vid) # B, C, T, H, W
214 | padded_melspec.append(nn.ConstantPad2d((0, max_spec_length - melspec.size(2), 0, 0), -1.0)(melspec))
215 | padded_spec.append(nn.ConstantPad2d((0, max_spec_length - spec.size(2), 0, 0), -1.0)(spec))
216 | padded_audio.append(torch.cat([audio, torch.zeros([max_aud_length - audio.size(0)])], 0))
217 | f_names.append(f_name)
218 |
219 | vid = torch.stack(padded_vid, 0).float()
220 | vid_length = torch.IntTensor(vid_lengths)
221 | melspec = torch.stack(padded_melspec, 0).float()
222 | spec = torch.stack(padded_spec, 0).float()
223 | spec_length = torch.IntTensor(spec_lengths)
224 | audio = torch.stack(padded_audio, 0).float()
225 |
226 | return melspec, spec, vid, vid_length, audio, spec_length, f_names
227 |
228 | def inverse_mel(self, mel, stft):
229 | if len(mel.size()) < 4:
230 | mel = mel.unsqueeze(0) #B,1,80,T
231 |
232 | mel = self.denormalize(mel)
233 | mel = stft.spectral_de_normalize(mel)
234 | mel = mel.transpose(2, 3).contiguous() #B,80,T --> B,T,80
235 | spec_from_mel_scaling = 1000
236 | spec_from_mel = torch.matmul(mel, stft.mel_basis)
237 | spec_from_mel = spec_from_mel.transpose(2, 3).squeeze(1) # B,1,F,T
238 | spec_from_mel = spec_from_mel * spec_from_mel_scaling
239 |
240 | wav = griffin_lim(spec_from_mel, stft.stft_fn, 60).squeeze(1) #B,L
241 | wav = wav.cpu().numpy() if wav.is_cuda else wav.numpy()
242 | wavs = []
243 | for w in wav:
244 | w = self.deemphasize(w)
245 | wavs += [w]
246 | wavs = np.array(wavs)
247 | wavs = np.clip(wavs, -1, 1)
248 | return wavs
249 |
250 | def inverse_spec(self, spec, stft):
251 | if len(spec.size()) < 4:
252 | spec = spec.unsqueeze(0) #B,1,321,T
253 |
254 | spec = self.denormalize(spec) # log1e5 ~ 0
255 | spec = stft.spectral_de_normalize(spec) # 0 ~ 1
256 | spec = self.denormalize_spec(spec) # 0 ~ 14
257 | wav = griffin_lim(spec.squeeze(1), stft.stft_fn, 60).squeeze(1) #B,L
258 | wav = wav.cpu().numpy() if wav.is_cuda else wav.numpy()
259 | wavs = []
260 | for w in wav:
261 | w = self.deemphasize(w)
262 | wavs += [w]
263 | wavs = np.array(wavs)
264 | wavs = np.clip(wavs, -1, 1)
265 | return wavs
266 |
267 | def preemphasize(self, aud):
268 | aud = signal.lfilter([1, -0.97], [1], aud)
269 | return aud
270 |
271 | def deemphasize(self, aud):
272 | aud = signal.lfilter([1], [1, -0.97], aud)
273 | return aud
274 |
275 | def normalize(self, melspec):
276 | melspec = ((melspec - log1e5) / (-log1e5 / 2)) - 1 #0~2 --> -1~1
277 | return melspec
278 |
279 | def denormalize(self, melspec):
280 | melspec = ((melspec + 1) * (-log1e5 / 2)) + log1e5
281 | return melspec
282 |
283 | def normalize_spec(self, spec):
284 | spec = (spec - spec.min()) / (spec.max() - spec.min()) # 0 ~ 1
285 | return spec
286 |
287 | def denormalize_spec(self, spec):
288 | spec = spec * 14. # 0 ~ 14
289 | return spec
290 |
291 | def plot_spectrogram_to_numpy(self, mels):
292 | fig, ax = plt.subplots(figsize=(15, 4))
293 | im = ax.imshow(np.squeeze(mels, 0), aspect="auto", origin="lower",
294 | interpolation='none')
295 | plt.colorbar(im, ax=ax)
296 | plt.xlabel("Frames")
297 | plt.ylabel("Channels")
298 | plt.tight_layout()
299 |
300 | fig.canvas.draw()
301 | data = self.save_figure_to_numpy(fig)
302 | plt.close()
303 | return data
304 |
305 | def save_figure_to_numpy(self, fig):
306 | # save it to a numpy array.
307 |         data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
308 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
309 | return data.transpose(2, 0, 1)
310 |
311 | class TacotronSTFT(torch.nn.Module):
312 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
313 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
314 | mel_fmax=8000.0):
315 | super(TacotronSTFT, self).__init__()
316 | self.n_mel_channels = n_mel_channels
317 | self.sampling_rate = sampling_rate
318 | self.stft_fn = STFT(filter_length, hop_length, win_length)
319 | mel_basis = librosa_mel_fn(
320 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
321 | mel_basis = torch.from_numpy(mel_basis).float()
322 | self.register_buffer('mel_basis', mel_basis)
323 |
324 | def spectral_normalize(self, magnitudes):
325 | output = dynamic_range_compression(magnitudes)
326 | return output
327 |
328 | def spectral_de_normalize(self, magnitudes):
329 | output = dynamic_range_decompression(magnitudes)
330 | return output
331 |
332 | def mel_spectrogram(self, y):
333 | """Computes mel-spectrograms from a batch of waves
334 | PARAMS
335 | ------
336 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
337 | RETURNS
338 | -------
339 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
340 | """
341 | assert(torch.min(y.data) >= -1)
342 | assert(torch.max(y.data) <= 1)
343 |
344 | magnitudes, phases = self.stft_fn.transform(y)
345 | magnitudes = magnitudes.data
346 | mel_output = torch.matmul(self.mel_basis, magnitudes)
347 | mel_output = self.spectral_normalize(mel_output)
348 | return mel_output, magnitudes
349 |
--------------------------------------------------------------------------------
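Both LRS datasets rely on a fixed video/audio frame correspondence: 25 fps video against 16 kHz audio with a 160-sample hop gives exactly 4 mel/spec frames (640 raw samples) per video frame, which is why `extract_window` scales indices by `audio_fps / video_fps / 160` and the padding targets are `max_v_timesteps * 4`. A small arithmetic sketch of that bookkeeping with an arbitrary example window start:

# Frame-rate bookkeeping used by extract_window: 4 spectrogram frames (hop 160 at 16 kHz)
# and 640 raw samples per 25 fps video frame. The window start below is an arbitrary example.
audio_fps, video_fps, hop = 16000, 25, 160

spec_per_vid_frame = audio_fps / video_fps / hop            # 4.0 spectrogram frames
samples_per_vid_frame = audio_fps / video_fps               # 640.0 raw samples

window_size = 40                                            # video frames (1.6 s)
st_fr = 13                                                  # example window start (video frame)
st_mel_fr = int(st_fr * spec_per_vid_frame)                 # 52
mel_window_size = int(window_size * spec_per_vid_frame)     # 160 mel/spec frames
aud_slice = (st_mel_fr * hop, (st_mel_fr + mel_window_size) * hop)   # (8320, 33920) samples
print(st_mel_fr, mel_window_size, aud_slice)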
/src/data/vid_aud_lrs2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torchaudio
8 | from torchvision import transforms
9 | from torch.utils.data import DataLoader, Dataset
10 | from src.data.transforms import Crop, StatefulRandomHorizontalFlip
11 | from PIL import Image
12 | import librosa
13 | import cv2
14 | import matplotlib
15 | matplotlib.use('Agg')
16 | from matplotlib import pyplot as plt
17 | import glob
18 | from scipy import signal
19 | import torchvision
20 | from torch.autograd import Variable
21 | from librosa.filters import mel as librosa_mel_fn
22 | from src.data.audio_processing import dynamic_range_compression, dynamic_range_decompression, griffin_lim
23 | from src.data.stft import STFT
24 | import math
25 | log1e5 = math.log(1e-5)
26 |
27 | class MultiDataset(Dataset):
28 | def __init__(self, data, mode, max_v_timesteps=155, window_size=40, augmentations=False, num_mel_bins=80, fast_validate=False, f_min=55., f_max=7600.):
29 | assert mode in ['train', 'test', 'val']
30 | self.data = data
31 | self.sample_window = True if (mode == 'train') else False
32 | self.fast_validate = fast_validate
33 | self.max_v_timesteps = window_size if self.sample_window else max_v_timesteps
34 | self.window_size = window_size
35 | self.augmentations = augmentations if mode == 'train' else False
36 | self.num_mel_bins = num_mel_bins
37 | self.file_paths, self.file_names, self.crops = self.build_file_list(data, mode)
38 | self.stft = TacotronSTFT(filter_length=640, hop_length=160, win_length=640, n_mel_channels=num_mel_bins, sampling_rate=16000, mel_fmin=f_min, mel_fmax=f_max)
39 |
40 | def build_file_list(self, lrs2, mode):
41 | file_list, paths = [], []
42 | crops = {}
43 | lrs2_mode = 'main'
44 |
45 |         ## Load LRS2 crop files (lip-center coordinates) ##
46 | file = open(f"./data/LRS2/LRS2_crop/preprocess_{lrs2_mode}.txt", "r")
47 | content = file.read()
48 | file.close()
49 | for i, line in enumerate(content.splitlines()):
50 | split = line.split(".")
51 | file = split[0]
52 | crop_str = split[1][4:]
53 | crops[f'{lrs2_mode}/{file}'] = crop_str
54 | ## LRS2 file in crop ##
55 | file = open(f"./data/LRS2/{mode}.txt", "r")
56 | content = file.readlines()
57 | file.close()
58 | for file in content:
59 | file = file.strip().split()[0]
60 | file = f'{lrs2_mode}/{file}'
61 | if file in crops:
62 | file_list.append(file)
63 | paths.append(f"{lrs2}/{file}")
64 | if mode == 'train':
65 | file = open(f"./data/LRS2/LRS2_crop/preprocess_pretrain.txt", "r")
66 | content = file.read()
67 | file.close()
68 | for i, line in enumerate(content.splitlines()):
69 | split = line.split(".")
70 | file = split[0]
71 | crop_str = split[1][4:]
72 | crops[f'pretrain/{file}'] = crop_str
73 | ## LRS2 file in crop ##
74 | file = open(f"./data/LRS2/pretrain.txt", "r")
75 | content = file.readlines()
76 | file.close()
77 | for file in content:
78 | file = file.strip().split()[0]
79 | file = f'pretrain/{file}'
80 | if file in crops:
81 | file_list.append(file)
82 | paths.append(f"{lrs2}/{file}")
83 |
84 | print(f'Mode: {mode}, File Num: {len(file_list)}')
85 | return paths, file_list, crops
86 |
87 | def build_tensor(self, frames, crops):
88 | if self.augmentations:
89 | s = random.randint(-5, 5)
90 | else:
91 | s = 0
92 | crop = []
93 | for i in range(0, len(crops), 2):
94 | left = int(crops[i]) - 40 + s
95 | upper = int(crops[i + 1]) - 40 + s
96 | right = int(crops[i]) + 40 + s
97 | bottom = int(crops[i + 1]) + 40 + s
98 | crop.append([left, upper, right, bottom])
99 | crops = crop
100 |
101 | if self.augmentations:
102 | augmentations1 = transforms.Compose([StatefulRandomHorizontalFlip(0.5)])
103 | else:
104 | augmentations1 = transforms.Compose([])
105 |
106 | temporalVolume = torch.zeros(self.max_v_timesteps, 1, 112, 112)
107 | for i, frame in enumerate(frames):
108 | transform = transforms.Compose([
109 | transforms.ToPILImage(),
110 | Crop(crops[i]),
111 | transforms.Resize([112, 112]),
112 | augmentations1,
113 | transforms.Grayscale(num_output_channels=1),
114 | transforms.ToTensor(),
115 | transforms.Normalize(0.4136, 0.1700),
116 | ])
117 | temporalVolume[i] = transform(frame)
118 |
119 | temporalVolume = temporalVolume.transpose(1, 0) # (C, T, H, W)
120 | return temporalVolume
121 |
122 | def __len__(self):
123 | return len(self.file_paths)
124 |
125 | def __getitem__(self, idx):
126 | file = self.file_names[idx]
127 | file_path = self.file_paths[idx]
128 | crops = self.crops[file].split("/")
129 |
130 | info = {}
131 | info['video_fps'] = 25
132 | cap = cv2.VideoCapture(file_path + '.mp4')
133 | frames = []
134 | while (cap.isOpened()):
135 | ret, frame = cap.read()
136 | if ret:
137 | frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
138 | else:
139 | break
140 | cap.release()
141 | audio, info['audio_fps'] = librosa.load(file_path.replace('LRS2-BBC', 'LRS2-BBC_audio') + '.wav', sr=16000)
142 | vid = torch.tensor(np.stack(frames, 0))
143 | audio = torch.tensor(audio).unsqueeze(0)
144 |
145 | if not 'video_fps' in info:
146 | info['video_fps'] = 25
147 | info['audio_fps'] = 16000
148 |
149 |         assert vid.size(0) > 5 and audio.size(1) > 5, 'video or audio is too short'
150 |
151 | ## Audio ##
152 | audio = audio / torch.abs(audio).max() * 0.9
153 | aud = torch.FloatTensor(self.preemphasize(audio.squeeze(0))).unsqueeze(0)
154 | aud = torch.clamp(aud, min=-1, max=1)
155 |
156 | melspec, spec = self.stft.mel_spectrogram(aud)
157 |
158 | ## Video ##
159 | vid = vid.permute(0, 3, 1, 2) # T C H W
160 |
161 | if self.sample_window:
162 | vid, melspec, spec, audio, crops = self.extract_window(vid, melspec, spec, audio, info, crops)
163 | elif vid.size(0) > self.max_v_timesteps:
164 |             print('Sample is longer than max_v_timesteps; trimming to', self.max_v_timesteps)
165 | vid = vid[:self.max_v_timesteps]
166 | melspec = melspec[:, :, :int(self.max_v_timesteps * info['audio_fps'] / info['video_fps'] / 160)]
167 | spec = spec[:, :, :int(self.max_v_timesteps * info['audio_fps'] / info['video_fps'] / 160)]
168 | audio = audio[:, :int(self.max_v_timesteps * info['audio_fps'] / info['video_fps'])]
169 | crops = crops[:self.max_v_timesteps * 2]
170 |
171 | num_v_frames = vid.size(0)
172 | vid = self.build_tensor(vid, crops)
173 |
174 | melspec = self.normalize(melspec) #0~2 --> -1~1
175 |
176 | spec = self.normalize_spec(spec) # 0 ~ 1
177 | spec = self.stft.spectral_normalize(spec) # log(1e-5) ~ 0 # in log scale
178 | spec = self.normalize(spec) # -1 ~ 1
179 |
180 | num_a_frames = melspec.size(2)
181 | melspec = nn.ConstantPad2d((0, self.max_v_timesteps * 4 - num_a_frames, 0, 0), -1.0)(melspec)
182 | spec = nn.ConstantPad2d((0, self.max_v_timesteps * 4 - num_a_frames, 0, 0), -1.0)(spec)
183 |
184 | return melspec, spec, vid, num_v_frames, audio.squeeze(0), num_a_frames, file_path.replace(self.data, '')[1:]
185 |
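# extract_window: samples a random window of self.window_size video frames and slices
# the matching spectrogram and waveform segments. With 25 fps video, 16 kHz audio and a
# 160-sample hop, one video frame corresponds to 16000 / 25 / 160 = 4 spectrogram
# frames, which is why the mel/spec indices are scaled by audio_fps / video_fps / 160
# and the raw-audio indices by 160x that.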
186 | def extract_window(self, vid, mel, spec, aud, info, crops):
187 | # vid : T,C,H,W
188 | st_fr = random.randint(0, max(0, vid.size(0) - self.window_size))
189 | vid = vid[st_fr:st_fr + self.window_size]
190 | crops = crops[st_fr * 2: st_fr * 2 + self.window_size * 2]
191 |
192 | assert vid.size(0) * 2 == len(crops), f'vid length: {vid.size(0)}, crop length: {len(crops)}'
193 |
194 | st_mel_fr = int(st_fr * info['audio_fps'] / info['video_fps'] / 160)
195 | mel_window_size = int(self.window_size * info['audio_fps'] / info['video_fps'] / 160)
196 | mel = mel[:, :, st_mel_fr:st_mel_fr + mel_window_size]
197 | spec = spec[:, :, st_mel_fr:st_mel_fr + mel_window_size]
198 | aud = aud[:, st_mel_fr*160:st_mel_fr*160 + mel_window_size*160]
199 | aud = torch.cat([aud, torch.zeros([1, int(self.window_size / info['video_fps'] * info['audio_fps'] - aud.size(1))])], 1)
200 |
201 | return vid, mel, spec, aud, crops
202 |
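# collate_fn: pads every sample to the longest mel/linear spectrogram in the batch
# (with -1, i.e. silence in the normalized log scale) and the longest waveform (with
# zeros), then stacks videos, spectrograms and audio into batch tensors alongside the
# per-sample frame counts and file names.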
203 | def collate_fn(self, batch):
204 | vid_lengths, spec_lengths, padded_spec_lengths, aud_lengths = [], [], [], []
205 | for data in batch:
206 | vid_lengths.append(data[3])
207 | spec_lengths.append(data[5])
208 | padded_spec_lengths.append(data[0].size(2))
209 | aud_lengths.append(data[4].size(0))
210 |
211 | max_aud_length = max(aud_lengths)
212 | max_spec_length = max(padded_spec_lengths)
213 | padded_vid = []
214 | padded_melspec = []
215 | padded_spec = []
216 | padded_audio = []
217 | f_names = []
218 |
219 | for i, (melspec, spec, vid, num_v_frames, audio, spec_len, f_name) in enumerate(batch):
220 | padded_vid.append(vid) # B, C, T, H, W
221 | padded_melspec.append(nn.ConstantPad2d((0, max_spec_length - melspec.size(2), 0, 0), -1.0)(melspec))
222 | padded_spec.append(nn.ConstantPad2d((0, max_spec_length - spec.size(2), 0, 0), -1.0)(spec))
223 | padded_audio.append(torch.cat([audio, torch.zeros([max_aud_length - audio.size(0)])], 0))
224 | f_names.append(f_name)
225 |
226 | vid = torch.stack(padded_vid, 0).float()
227 | vid_length = torch.IntTensor(vid_lengths)
228 | melspec = torch.stack(padded_melspec, 0).float()
229 | spec = torch.stack(padded_spec, 0).float()
230 | spec_length = torch.IntTensor(spec_lengths)
231 | audio = torch.stack(padded_audio, 0).float()
232 |
233 | return melspec, spec, vid, vid_length, audio, spec_length, f_names
234 |
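# inverse_mel: approximate waveform reconstruction from a normalized mel spectrogram.
# It de-normalizes and de-compresses the mel, maps it back to a linear spectrogram via
# the (transposed) mel filterbank with a fixed empirical scaling of 1000, runs 60
# Griffin-Lim iterations, then de-emphasizes and clips the waveform to [-1, 1].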
235 | def inverse_mel(self, mel, stft):
236 | if len(mel.size()) < 4:
237 | mel = mel.unsqueeze(0) #B,1,80,T
238 |
239 | mel = self.denormalize(mel)
240 | mel = stft.spectral_de_normalize(mel)
241 | mel = mel.transpose(2, 3).contiguous() #B,80,T --> B,T,80
242 | spec_from_mel_scaling = 1000
243 | spec_from_mel = torch.matmul(mel, stft.mel_basis)
244 | spec_from_mel = spec_from_mel.transpose(2, 3).squeeze(1) # B,F,T
245 | spec_from_mel = spec_from_mel * spec_from_mel_scaling
246 |
247 | wav = griffin_lim(spec_from_mel, stft.stft_fn, 60).squeeze(1) #B,L
248 | wav = wav.cpu().numpy() if wav.is_cuda else wav.numpy()
249 | wavs = []
250 | for w in wav:
251 | w = self.deemphasize(w)
252 | wavs += [w]
253 | wavs = np.array(wavs)
254 | wavs = np.clip(wavs, -1, 1)
255 | return wavs
256 |
257 | def inverse_spec(self, spec, stft):
258 | if len(spec.size()) < 4:
259 | spec = spec.unsqueeze(0) #B,1,321,T
260 |
261 | spec = self.denormalize(spec) # log1e5 ~ 0
262 | spec = stft.spectral_de_normalize(spec) # 0 ~ 1
263 | spec = self.denormalize_spec(spec) # 0 ~ 14
264 | wav = griffin_lim(spec.squeeze(1), stft.stft_fn, 60).squeeze(1) #B,L
265 | wav = wav.cpu().numpy() if wav.is_cuda else wav.numpy()
266 | wavs = []
267 | for w in wav:
268 | w = self.deemphasize(w)
269 | wavs += [w]
270 | wavs = np.array(wavs)
271 | wavs = np.clip(wavs, -1, 1)
272 | return wavs
273 |
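# preemphasize / deemphasize: first-order pre-emphasis filter pair. preemphasize applies
# y[n] = x[n] - 0.97 * x[n-1] (boosting high frequencies before spectrogram computation);
# deemphasize applies the inverse IIR filter to undo it after reconstruction.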
274 | def preemphasize(self, aud):
275 | aud = signal.lfilter([1, -0.97], [1], aud)
276 | return aud
277 |
278 | def deemphasize(self, aud):
279 | aud = signal.lfilter([1], [1, -0.97], aud)
280 | return aud
281 |
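# normalize / denormalize map log-magnitudes between [log(1e-5), 0] and [-1, 1];
# log1e5 (defined elsewhere in this module) is log(1e-5), the clamp floor used by
# dynamic_range_compression, so -1 corresponds to silence.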
282 | def normalize(self, melspec):
283 | melspec = ((melspec - log1e5) / (-log1e5 / 2)) - 1 #0~2 --> -1~1
284 | return melspec
285 |
286 | def denormalize(self, melspec):
287 | melspec = ((melspec + 1) * (-log1e5 / 2)) + log1e5
288 | return melspec
289 |
290 | def normalize_spec(self, spec):
291 | spec = (spec - spec.min()) / (spec.max() - spec.min()) # 0 ~ 1
292 | return spec
293 |
294 | def denormalize_spec(self, spec):
295 | spec = spec * 14. # 0 ~ 14
296 | return spec
297 |
298 | def plot_spectrogram_to_numpy(self, mels):
299 | fig, ax = plt.subplots(figsize=(15, 4))
300 | im = ax.imshow(np.squeeze(mels, 0), aspect="auto", origin="lower",
301 | interpolation='none')
302 | plt.colorbar(im, ax=ax)
303 | plt.xlabel("Frames")
304 | plt.ylabel("Channels")
305 | plt.tight_layout()
306 |
307 | fig.canvas.draw()
308 | data = self.save_figure_to_numpy(fig)
309 | plt.close()
310 | return data
311 |
312 | def save_figure_to_numpy(self, fig):
313 | # save it to a numpy array.
314 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.frombuffer: fromstring is deprecated for binary buffers
315 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
316 | return data.transpose(2, 0, 1)
317 |
318 | class TacotronSTFT(torch.nn.Module):
319 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
320 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
321 | mel_fmax=8000.0):
322 | super(TacotronSTFT, self).__init__()
323 | self.n_mel_channels = n_mel_channels
324 | self.sampling_rate = sampling_rate
325 | self.stft_fn = STFT(filter_length, hop_length, win_length)
326 | mel_basis = librosa_mel_fn(
327 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
328 | mel_basis = torch.from_numpy(mel_basis).float()
329 | self.register_buffer('mel_basis', mel_basis)
330 |
331 | def spectral_normalize(self, magnitudes):
332 | output = dynamic_range_compression(magnitudes)
333 | return output
334 |
335 | def spectral_de_normalize(self, magnitudes):
336 | output = dynamic_range_decompression(magnitudes)
337 | return output
338 |
339 | def mel_spectrogram(self, y):
340 | """Computes mel-spectrograms from a batch of waves
341 | PARAMS
342 | ------
343 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
344 | RETURNS
345 | -------
346 | (mel_output, magnitudes): mel spectrogram of shape (B, n_mel_channels, T) and linear magnitude spectrogram of shape (B, n_fft // 2 + 1, T)
347 | """
348 | assert(torch.min(y.data) >= -1)
349 | assert(torch.max(y.data) <= 1)
350 |
351 | magnitudes, phases = self.stft_fn.transform(y)
352 | magnitudes = magnitudes.data
353 | mel_output = torch.matmul(self.mel_basis, magnitudes)
354 | mel_output = self.spectral_normalize(mel_output)
355 | return mel_output, magnitudes
356 |
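# A minimal usage sketch for TacotronSTFT (illustrative only; the 640/160 filter and hop
# lengths are inferred from the 321-bin spectrograms and the /160 frame arithmetic used
# above, not read from this repository's actual configuration):
#
#     stft = TacotronSTFT(filter_length=640, hop_length=160, win_length=640,
#                         n_mel_channels=80, sampling_rate=16000)
#     wav = torch.zeros(2, 16000)              # (B, T) waveforms in [-1, 1]
#     mel, mag = stft.mel_spectrogram(wav)     # mel: (B, 80, T'), mag: (B, 321, T')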
--------------------------------------------------------------------------------
/preprocess/Ref_face.txt:
--------------------------------------------------------------------------------
1 | ABOUT/test/ABOUT_00001.mp4:71 98,71 106,71 115,72 124,72 132,73 141,74 149,76 158,79 166,82 174,87 181,92 188,97 194,103 199,111 202,118 205,126 207,134 206,142 204,150 201,157 196,163 190,169 184,174 177,178 169,181 161,184 153,185 144,187 135,188 126,189 117,190 108,190 99,81 82,89 74,97 73,106 73,113 76,113 81,105 80,97 80,89 80,144 75,153 71,161 71,171 72,178 81,170 78,162 77,153 79,144 80,129 94,129 106,128 116,129 127,114 133,121 135,129 136,135 135,143 134,89 95,94 93,99 92,106 93,111 96,105 97,100 98,94 97,146 96,152 93,159 91,164 93,169 95,164 98,158 98,152 97,103 155,113 151,124 149,129 150,133 151,145 153,154 157,147 165,137 169,127 170,118 169,110 163,105 156,117 154,129 155,141 154,152 157,141 162,128 164,116 161,100 95,158 95|71 96,71 104,71 114,71 122,72 131,73 140,74 148,75 157,78 166,82 173,86 181,91 188,96 193,103 199,111 202,118 205,126 207,134 206,142 204,149 200,157 195,163 189,168 183,174 176,178 168,181 160,183 152,185 143,186 134,188 125,188 116,189 107,189 98,81 82,89 74,97 72,105 73,113 75,113 81,105 80,96 80,89 80,144 75,153 73,161 73,170 73,177 82,169 79,161 78,152 80,144 81,128 94,129 106,128 116,129 127,114 133,121 135,128 136,135 135,143 134,89 95,94 93,99 91,106 92,111 96,106 97,100 98,94 97,145 96,151 93,158 91,164 93,168 96,163 98,158 99,152 98,102 155,112 151,123 149,129 151,133 151,144 152,154 157,146 164,137 168,127 169,117 168,109 162,105 155,117 154,128 155,140 154,152 157,140 161,127 163,116 160,100 95,157 95|70 96,70 104,70 114,70 122,71 131,72 140,73 148,75 157,78 166,81 174,86 181,91 188,96 194,103 199,110 203,117 205,126 207,134 206,142 204,149 200,156 195,163 190,168 184,173 177,177 169,180 160,183 152,184 144,186 135,188 125,188 117,189 108,189 99,80 79,88 71,96 70,105 71,113 74,113 80,104 78,96 77,88 77,144 74,153 71,161 71,170 72,177 81,169 78,161 76,153 78,144 80,128 94,128 106,128 117,128 128,113 134,121 136,128 137,135 136,142 135,89 94,94 92,99 91,106 92,111 96,105 97,99 97,94 97,145 96,151 94,158 92,164 94,168 96,163 98,158 99,152 98,102 155,112 152,123 149,128 151,133 151,144 153,154 157,146 165,137 169,127 171,117 169,109 163,104 156,116 154,128 156,140 155,151 157,140 162,127 164,115 161,100 95,157 95|69 96,69 104,69 114,70 122,70 131,71 140,72 148,74 157,77 166,80 174,85 181,90 188,96 194,102 199,110 203,117 205,126 207,133 207,141 204,149 200,156 195,162 190,167 183,173 177,177 168,180 160,183 152,184 143,186 134,187 125,188 116,189 107,189 98,79 81,87 72,96 71,105 72,113 75,112 81,104 79,96 78,88 79,144 75,153 72,161 72,170 73,178 82,170 79,161 78,152 79,143 81,128 95,128 106,128 118,128 129,113 134,120 136,127 137,134 136,142 135,89 95,93 93,99 91,106 92,111 96,105 98,99 98,93 97,144 97,151 94,157 92,163 94,168 96,163 99,157 99,151 98,102 155,112 152,123 150,128 151,133 151,143 153,153 157,146 164,137 169,126 170,117 168,109 163,104 156,116 155,128 156,139 155,151 157,140 162,127 164,115 161,100 95,157 95|71 97,71 105,71 115,71 123,72 132,73 140,74 148,76 157,79 166,82 173,87 181,92 187,97 193,104 199,111 202,118 204,127 206,135 206,142 203,149 200,157 195,163 189,168 183,174 177,178 169,181 161,184 152,185 144,187 135,188 126,189 118,190 108,190 99,81 82,89 73,97 72,106 73,114 75,114 81,105 80,97 79,89 80,145 75,154 72,162 72,171 73,179 82,171 79,162 78,154 79,145 81,129 95,129 107,129 118,129 129,114 134,121 136,128 137,136 136,143 135,90 96,95 93,100 92,107 93,112 97,106 98,100 99,95 98,146 97,152 94,159 92,165 94,169 97,164 99,159 100,152 98,102 155,113 152,124 150,129 151,134 151,145 153,155 157,147 164,138 168,127 
169,118 168,110 162,105 155,117 154,129 156,141 155,153 156,141 162,128 163,116 160,101 95,158 96|70 96,70 104,70 114,71 122,72 131,73 139,73 147,75 156,78 165,81 172,86 180,91 186,96 192,103 197,111 201,118 203,126 205,134 205,142 202,150 199,157 194,163 188,169 182,174 176,178 167,181 159,184 151,185 143,187 134,188 125,188 116,190 107,190 98,80 84,88 75,97 73,106 74,114 76,114 82,105 81,96 81,88 81,145 76,154 73,163 73,172 74,179 84,171 80,163 78,154 80,144 82,129 96,129 108,129 119,129 130,114 135,121 137,128 138,136 137,143 136,90 97,95 94,100 93,107 94,113 98,106 99,101 100,95 99,145 98,151 94,158 93,164 94,169 97,164 100,158 101,152 99,103 155,113 153,124 151,129 152,134 153,144 154,154 157,146 164,137 167,127 169,118 167,110 162,105 156,117 156,129 157,141 156,152 157,140 161,128 162,116 160,101 96,158 96|70 96,70 104,70 113,71 122,71 130,72 139,73 147,74 156,77 165,81 172,85 179,91 186,96 192,103 197,111 200,118 203,127 203,135 204,143 201,150 198,157 193,163 188,169 182,174 176,178 168,181 160,184 152,185 143,187 134,188 125,189 117,190 108,190 99,79 81,87 73,96 72,105 73,113 76,113 82,104 80,96 79,87 79,145 76,154 73,162 73,172 74,179 84,171 80,162 79,154 80,144 82,129 97,129 109,129 120,129 131,114 137,121 139,129 140,136 139,143 138,88 97,93 94,99 93,106 94,112 99,106 100,99 101,93 100,145 100,151 96,159 94,165 96,169 99,164 102,159 103,152 101,103 157,113 154,124 152,129 154,134 154,144 155,154 158,146 164,137 166,128 168,119 166,110 162,105 157,117 157,129 159,140 158,152 158,140 161,128 161,117 159,100 97,158 98|69 100,69 108,69 117,70 125,70 134,71 142,72 150,73 159,76 167,80 174,85 181,90 187,96 193,102 198,110 200,118 203,126 204,134 204,142 201,149 198,156 194,162 189,168 183,173 177,177 169,180 161,183 154,184 146,186 137,187 128,188 120,189 111,189 103,79 83,87 75,95 74,104 76,112 79,112 84,104 82,95 81,87 81,144 79,153 76,161 75,170 76,178 85,170 82,161 81,152 83,143 85,128 98,127 110,127 122,128 133,113 138,120 140,127 141,134 140,142 139,87 98,92 95,98 94,105 95,110 100,104 101,98 102,92 100,145 101,151 97,158 95,164 97,168 100,163 103,158 103,151 102,102 157,112 155,123 153,128 154,133 154,143 155,153 159,145 164,136 167,126 168,117 167,109 163,104 158,116 158,127 159,139 158,151 159,139 161,127 162,115 160,99 98,157 99|69 100,69 108,69 117,70 125,70 134,71 142,72 150,73 159,76 167,80 174,85 181,90 187,96 193,103 197,110 200,118 203,126 203,134 204,142 201,148 198,156 194,162 189,168 183,173 177,177 169,180 162,183 154,184 146,186 137,187 128,188 121,188 111,189 103,79 82,87 74,95 73,104 75,112 78,112 84,104 82,95 81,87 81,144 79,153 75,161 75,170 75,177 85,169 81,161 80,152 82,143 84,128 98,128 110,128 121,128 132,113 137,121 139,128 140,135 139,142 138,88 97,93 95,98 94,106 95,110 99,104 100,98 101,93 100,145 100,150 97,157 95,163 97,168 99,163 102,157 103,151 101,102 156,113 153,123 151,128 153,133 153,143 154,152 158,145 164,136 166,127 168,118 166,110 162,105 156,117 156,128 158,139 157,150 158,139 161,127 161,116 159,99 97,157 98|70 100,69 108,70 117,70 125,71 133,71 142,72 149,74 159,77 167,80 174,85 181,90 187,96 192,103 197,110 200,118 202,126 203,134 204,142 201,148 198,156 194,162 189,168 183,173 177,177 169,181 161,183 154,184 146,186 137,188 128,188 120,189 111,189 102,79 82,87 74,95 73,104 75,112 78,112 83,104 81,95 80,87 80,144 79,153 75,161 75,170 75,177 84,169 81,161 80,152 82,143 84,128 97,128 109,128 120,129 131,114 136,121 138,128 139,134 138,142 137,88 97,93 95,99 94,106 95,111 99,105 100,99 101,93 99,144 99,150 96,157 95,163 96,167 99,162 
101,157 102,151 101,103 156,113 153,123 151,128 152,133 152,143 154,152 158,144 164,136 167,127 168,118 167,110 163,105 156,117 156,128 157,139 156,149 158,139 161,127 162,116 160,99 97,156 98|69 100,69 108,70 117,71 125,71 133,72 142,73 149,74 159,77 167,80 174,85 181,91 187,96 192,103 197,111 200,118 202,126 203,134 203,142 201,149 198,156 193,163 188,168 182,173 176,178 168,181 160,183 153,185 145,186 136,188 127,188 119,189 110,189 102,79 82,87 74,96 73,104 75,112 77,112 83,104 81,95 80,87 80,144 78,153 75,161 74,170 75,177 84,169 80,161 80,152 82,143 84,128 96,128 108,128 119,128 130,114 135,121 137,128 138,134 137,141 136,89 96,93 93,99 92,106 94,111 98,105 99,99 100,93 98,144 98,150 95,157 93,163 95,167 98,162 100,157 101,151 100,102 156,113 153,123 151,128 152,133 152,143 153,152 157,144 164,136 167,127 168,118 167,110 163,105 156,116 155,128 157,139 156,150 157,139 161,127 162,116 160,100 96,156 97|69 97,69 105,70 114,70 123,71 132,72 140,73 148,74 157,77 166,80 173,85 180,91 187,96 193,103 198,111 201,118 203,127 205,135 205,143 202,149 199,157 193,163 188,168 182,174 176,178 167,181 159,183 152,185 143,186 134,188 126,188 117,189 108,189 99,79 83,87 75,96 74,104 75,112 78,112 84,104 82,96 80,87 80,144 79,152 75,161 75,169 75,177 84,169 81,161 80,152 82,143 84,128 96,128 107,128 118,129 129,114 135,121 137,128 138,135 137,142 136,88 96,93 93,99 92,106 93,111 97,105 99,99 99,93 98,145 98,151 94,158 93,163 94,168 97,163 100,158 101,151 99,103 156,113 153,123 151,128 152,133 152,143 154,152 158,145 164,136 167,127 168,118 167,110 163,105 157,117 156,128 157,139 156,150 158,139 161,127 162,116 160,100 96,157 96|70 98,69 106,70 116,71 124,71 133,72 141,73 149,74 159,77 167,81 174,85 181,91 188,96 193,102 199,110 202,118 204,126 205,134 205,142 202,149 199,157 194,163 189,168 183,173 177,177 169,181 160,183 153,184 145,186 136,188 127,188 119,189 110,189 100,80 83,88 76,96 76,104 77,112 81,113 86,104 83,96 82,88 82,144 81,152 78,160 77,169 77,176 86,168 83,160 82,152 84,143 87,128 96,128 108,128 119,128 130,114 136,121 138,127 139,134 138,141 137,89 96,94 93,100 92,106 94,111 98,105 99,100 100,94 98,144 98,150 95,157 93,163 95,167 98,162 100,157 101,151 100,102 157,112 155,123 153,128 154,133 154,142 156,152 158,144 163,135 165,127 166,118 165,110 162,105 157,116 158,128 159,139 159,150 159,139 160,127 160,115 159,100 96,157 96|70 98,70 107,70 116,71 124,71 133,72 141,73 149,75 159,78 167,81 174,86 182,91 188,96 194,103 199,111 202,118 205,127 206,134 206,143 203,149 200,157 195,163 190,168 184,173 177,177 169,181 161,183 153,184 145,186 136,188 128,188 119,189 110,189 100,80 85,88 78,97 78,105 79,113 82,113 88,105 85,97 84,89 84,143 83,152 79,159 79,168 79,176 87,168 85,160 84,151 86,143 88,128 97,128 109,128 120,128 131,114 136,121 138,128 139,134 138,141 137,89 97,94 94,100 93,106 94,111 99,106 100,100 100,94 99,145 99,151 96,157 94,163 96,167 99,162 101,157 102,151 101,102 157,113 156,123 154,128 155,133 155,143 157,152 159,144 164,135 165,127 166,118 165,110 162,105 158,117 159,128 160,139 160,150 160,139 160,127 159,116 158,100 96,157 97|72 99,71 107,72 116,72 125,73 133,74 141,75 149,77 158,80 166,83 174,87 181,92 188,98 194,104 200,111 203,118 205,127 207,135 207,143 204,150 201,158 196,164 191,169 185,174 178,178 171,181 163,184 155,185 146,187 138,189 130,189 120,190 112,190 102,82 84,90 76,98 76,107 78,114 81,114 86,106 84,98 83,90 82,144 82,153 78,161 78,169 78,177 86,169 84,161 83,153 85,144 87,129 96,129 108,129 118,129 129,115 135,122 136,129 137,135 137,142 136,90 95,96 
92,101 92,108 93,113 97,107 98,101 99,95 98,146 98,152 94,159 93,165 95,169 98,164 100,159 101,153 100,104 156,114 154,124 152,129 153,134 153,144 155,153 159,145 165,136 167,128 168,119 166,111 162,106 157,118 156,129 158,140 158,151 159,140 161,128 162,117 160,102 95,159 96|71 99,70 107,71 117,72 125,72 134,73 142,75 150,77 160,80 168,83 175,88 183,94 189,99 195,105 201,113 205,121 206,129 208,137 207,145 204,151 200,159 195,165 189,169 183,174 176,178 168,181 160,184 152,185 144,186 135,188 127,189 118,189 109,189 100,82 83,90 75,98 75,107 77,115 80,115 86,106 83,98 82,90 82,144 80,152 77,160 76,168 76,176 85,168 82,160 82,152 84,143 86,129 96,129 107,129 118,130 128,115 134,122 135,129 137,136 135,143 135,90 96,95 94,101 93,107 94,112 98,107 98,101 99,95 98,146 97,152 95,158 93,164 95,168 98,163 99,159 100,152 99,103 156,114 152,125 150,130 151,134 151,145 153,154 157,147 165,138 169,129 170,119 169,110 164,106 156,118 154,130 156,141 155,152 157,141 163,129 164,116 162,101 96,158 96|71 94,70 103,70 113,71 122,72 131,73 140,74 148,76 158,79 166,82 175,87 182,92 189,98 196,105 201,113 206,120 208,129 209,137 208,146 205,152 202,160 196,165 190,171 182,176 176,180 167,183 159,185 151,187 142,188 133,190 124,191 115,190 106,191 97,80 81,88 74,97 74,106 77,115 80,115 86,106 83,97 81,88 80,144 81,153 77,161 75,170 76,179 83,170 81,161 81,153 84,143 86,130 94,130 106,130 117,130 127,116 132,122 135,129 136,136 134,143 133,90 94,96 93,101 93,107 93,112 95,106 96,101 96,95 95,147 95,154 95,160 94,165 95,170 96,165 97,160 97,154 96,103 156,114 151,125 149,130 151,135 150,145 153,154 158,147 165,139 170,129 172,118 170,110 165,105 156,118 154,130 155,141 155,152 158,141 164,129 164,116 163,102 95,160 95|69 94,68 103,69 112,70 121,70 130,71 139,73 147,75 158,77 166,81 175,85 182,91 190,96 196,103 202,111 206,119 208,127 209,136 209,144 206,150 202,158 196,163 190,168 183,173 176,178 168,181 160,184 152,185 143,187 134,189 125,189 116,189 107,190 98,78 79,87 72,96 73,105 76,114 79,114 85,105 82,96 80,87 78,141 81,151 77,159 75,168 76,177 84,168 81,159 81,150 84,141 86,128 94,128 105,128 116,129 126,114 131,121 134,128 135,135 134,142 133,89 93,94 92,100 92,106 93,111 94,105 95,99 95,94 94,146 96,152 95,159 95,164 96,168 97,163 97,159 98,152 97,101 155,112 151,123 149,128 150,133 150,144 152,153 158,146 165,137 171,127 173,117 170,108 164,104 155,116 153,128 154,140 154,151 158,140 164,127 165,115 163,100 94,159 96|69 95,69 103,69 113,70 122,70 131,71 140,73 148,75 158,77 166,81 175,86 182,91 190,96 196,103 201,111 205,118 208,127 209,135 208,143 206,150 202,157 196,163 190,168 183,173 176,178 168,181 160,184 152,185 144,187 134,188 126,189 117,189 108,190 98,79 80,88 73,97 73,106 76,114 79,114 84,105 82,97 80,88 79,142 80,151 76,160 75,169 76,177 84,168 81,160 81,151 83,142 85,128 93,128 105,128 116,129 126,114 131,121 134,128 135,135 134,142 133,89 93,94 92,100 92,106 92,111 94,105 94,100 95,94 94,146 95,153 94,159 94,164 95,168 96,163 97,158 97,153 96,101 155,112 150,123 148,128 150,133 150,144 152,153 157,146 165,137 171,127 173,116 170,108 164,104 156,116 153,128 154,140 154,151 157,140 164,127 165,114 163,100 93,159 96|71 96,70 104,71 114,71 123,72 132,73 140,74 148,76 158,79 167,82 175,87 182,92 190,98 196,104 201,112 205,120 207,128 208,137 208,145 206,151 202,159 196,164 190,170 183,175 177,179 168,182 161,185 152,186 144,188 135,190 126,190 118,191 108,191 99,80 80,89 73,98 72,107 75,115 78,115 84,106 81,97 80,89 79,144 79,153 76,161 75,171 75,179 83,170 81,161 80,153 83,143 85,129 93,130 
105,129 116,130 127,115 131,122 134,129 135,136 134,144 133,90 92,95 91,100 90,106 91,112 94,106 94,100 94,95 93,147 95,154 94,160 93,166 94,170 96,165 97,160 97,154 96,102 155,113 150,124 148,130 150,134 150,146 152,155 157,147 165,139 170,128 172,118 170,109 164,105 156,117 153,129 154,141 154,152 157,141 164,129 165,116 163,101 93,160 95|71 99,70 107,71 117,72 125,72 134,73 143,74 150,76 160,79 169,82 176,87 183,92 190,98 196,104 201,112 205,120 207,128 208,136 207,144 205,151 201,158 196,164 190,169 184,174 177,179 169,182 161,184 154,186 145,187 136,189 127,190 119,190 110,190 101,81 82,89 74,97 73,106 76,114 78,114 84,105 82,97 80,89 80,145 80,154 76,162 75,170 76,178 84,170 82,162 81,153 83,145 85,129 95,129 106,129 117,130 127,115 133,122 135,129 136,136 135,143 134,90 94,95 93,100 93,106 93,111 96,105 96,100 97,94 96,147 96,153 95,159 94,165 95,169 97,164 98,159 98,153 97,102 156,113 152,124 149,129 151,134 151,145 153,154 157,147 165,138 170,128 172,118 170,110 164,105 156,117 154,129 155,141 155,152 157,141 164,128 165,116 163,100 95,159 96|70 101,70 109,71 119,71 127,71 136,72 144,74 152,75 161,78 170,82 177,87 184,92 191,98 196,104 201,113 204,120 206,128 207,136 206,144 203,151 200,158 195,164 189,169 183,174 177,178 169,181 161,184 153,185 145,186 136,188 127,188 119,188 110,189 102,81 83,89 76,97 75,106 77,114 80,114 86,105 83,97 82,89 82,144 80,153 77,160 76,169 77,176 85,168 82,160 82,152 84,144 86,128 96,129 107,128 118,129 127,115 134,121 135,128 136,135 135,142 135,90 97,94 95,100 94,106 95,111 98,105 99,100 99,94 98,146 97,152 95,158 94,163 96,167 98,162 99,158 100,152 99,103 157,113 152,124 150,129 151,133 151,144 153,153 158,146 164,137 168,128 170,119 168,110 164,106 157,117 154,129 155,140 155,151 158,140 162,128 163,116 162,100 97,157 97|72 102,71 110,72 119,72 128,72 136,73 145,74 152,76 162,79 170,82 177,87 184,92 190,98 196,104 201,112 203,120 205,128 206,136 206,144 203,151 200,158 196,164 191,169 185,175 179,179 171,182 163,184 155,185 147,187 138,188 129,189 121,189 112,189 104,82 86,90 78,99 77,107 79,115 82,115 88,107 86,98 84,90 84,144 83,153 79,161 78,169 79,177 87,169 85,160 84,152 86,144 88,129 98,129 109,129 120,129 130,115 136,122 138,129 139,136 138,142 137,91 98,96 96,101 95,107 96,112 99,106 100,101 101,95 100,146 99,152 97,158 95,164 97,168 99,163 101,158 101,152 100,103 157,114 154,124 152,129 153,134 152,144 154,153 159,146 164,137 166,128 167,119 166,111 163,106 158,118 156,129 157,140 157,152 159,140 161,128 161,117 160,101 98,158 98|71 105,70 112,71 122,71 129,71 138,72 146,73 154,75 163,78 171,81 178,86 185,92 190,97 196,104 200,112 202,119 204,127 205,135 205,143 203,150 199,157 195,163 191,168 185,174 179,178 171,181 164,184 156,185 148,186 140,188 131,188 123,188 115,188 106,82 86,90 78,98 78,107 80,114 84,114 89,106 86,98 85,90 84,144 84,152 80,160 79,168 80,175 88,168 85,160 85,152 87,144 89,129 100,129 111,129 121,129 132,115 137,122 139,128 140,135 139,142 138,90 99,94 97,100 96,106 97,111 101,105 102,100 102,94 101,146 101,151 98,158 96,164 98,167 101,163 102,158 103,152 102,104 158,114 154,124 152,129 154,133 153,143 155,152 159,144 164,136 166,127 167,119 166,111 163,106 158,118 157,129 158,139 158,150 159,139 161,128 161,117 160,100 99,157 99|70 105,70 112,70 122,71 130,71 138,72 146,73 154,75 163,78 171,81 178,86 184,92 190,97 195,104 200,112 202,119 203,127 204,135 204,143 202,150 199,157 195,164 190,168 185,174 179,178 171,181 164,184 156,185 148,186 139,188 131,188 123,189 114,189 106,82 86,89 79,97 78,106 80,113 83,113 
88,105 86,97 85,89 85,145 83,153 80,160 79,169 80,176 88,168 85,160 84,153 86,144 88,129 100,129 111,129 121,129 131,115 137,122 139,128 140,135 139,142 138,90 100,94 97,100 96,106 97,111 101,106 102,100 103,94 102,146 101,151 98,158 96,164 98,168 101,163 102,158 103,152 102,104 158,114 154,124 152,129 153,134 153,144 155,152 160,144 164,136 166,128 167,119 167,111 163,107 159,118 157,129 158,140 157,150 159,139 161,128 161,117 160,100 99,157 99|70 104,70 112,70 121,71 129,71 138,72 146,73 154,75 163,78 171,81 177,86 184,91 190,97 195,104 200,112 202,119 204,127 205,135 204,143 202,150 199,157 195,163 191,168 185,174 179,178 171,181 164,183 156,185 148,186 139,187 131,188 123,188 114,188 105,82 86,89 79,97 78,106 79,113 83,113 88,105 86,97 85,89 85,144 83,152 79,160 79,168 79,175 88,168 85,160 84,152 86,144 88,129 99,129 110,129 121,129 131,115 137,122 139,128 140,136 139,142 138,89 99,94 96,100 95,106 96,111 100,106 101,100 102,94 101,146 100,151 97,158 96,164 97,168 100,163 102,158 103,152 102,104 158,114 154,124 152,129 153,134 152,144 155,152 160,144 164,136 166,128 167,119 166,111 163,107 159,118 157,129 158,139 157,150 159,139 161,128 161,117 160,100 98,157 99|70 105,70 113,71 122,71 130,72 139,72 147,74 155,75 164,78 172,82 179,87 185,92 191,98 196,104 201,112 203,119 205,128 205,135 205,143 203,150 200,157 196,163 192,168 186,174 180,177 172,180 165,183 157,185 149,186 140,187 132,188 124,188 115,188 107,82 87,90 79,98 78,106 80,114 83,114 88,106 86,98 85,90 85,145 83,153 80,160 79,169 79,176 88,168 85,160 84,153 86,145 88,129 99,129 111,129 122,130 132,115 138,122 140,129 141,136 140,142 139,90 99,95 97,101 96,107 97,112 101,106 102,101 103,95 102,146 101,152 98,158 96,164 98,168 101,163 102,158 103,152 102,104 159,114 155,124 153,130 154,134 154,144 156,152 160,144 164,136 166,128 167,120 166,112 163,107 159,118 158,129 159,140 158,150 160,139 161,128 161,117 160,101 99,158 99|70 101,70 109,71 119,71 127,72 136,72 144,74 152,75 162,78 170,82 177,86 184,92 190,98 196,104 201,112 204,120 205,128 206,136 206,144 204,150 200,157 196,164 191,168 185,174 178,178 170,181 163,183 155,185 146,186 137,188 129,188 120,189 112,189 103,82 84,90 77,98 76,106 78,114 82,114 87,106 85,98 83,90 83,144 82,152 79,160 78,168 78,176 87,168 84,160 83,152 86,144 88,129 97,129 109,129 120,130 130,116 136,122 138,129 139,136 138,142 137,90 97,95 95,101 93,107 94,112 99,106 100,101 101,95 99,146 99,151 96,158 94,164 96,168 99,163 100,158 101,152 100,105 158,115 154,125 152,130 153,134 153,143 155,151 159,144 164,136 166,128 167,120 166,112 163,107 158,118 157,129 158,139 157,150 159,139 161,128 160,118 160,101 97,157 97|70 98,70 106,70 116,71 124,71 133,73 142,74 150,75 159,78 168,82 176,86 183,92 190,97 196,104 201,112 205,119 206,128 208,136 207,144 204,150 201,158 196,164 190,169 184,174 177,178 169,181 161,184 153,185 145,187 135,188 127,189 118,189 109,189 99,82 82,89 74,98 74,106 75,114 79,114 85,106 82,97 81,90 81,145 79,153 76,161 75,169 76,177 84,169 82,161 81,153 83,145 85,129 96,129 107,129 117,130 128,116 134,122 136,129 137,136 136,143 135,90 96,94 93,100 92,107 93,112 97,106 98,100 99,95 98,146 97,152 94,158 92,164 94,168 97,163 99,159 100,152 99,105 158,114 153,125 151,129 152,134 151,144 154,151 159,144 165,137 168,128 169,119 168,112 164,108 158,118 156,129 156,139 156,150 159,139 162,128 163,118 161,100 95,158 96
--------------------------------------------------------------------------------