├── ASR_model ├── GRID │ ├── data │ │ └── README.md │ ├── src │ │ ├── models │ │ │ ├── classifier.py │ │ │ ├── audio_front.py │ │ │ └── resnet.py │ │ └── data │ │ │ ├── transforms.py │ │ │ ├── audio_processing.py │ │ │ ├── stft.py │ │ │ └── vid_aud_GRID_test.py │ └── test.py └── LRW │ ├── data │ ├── README.md │ └── class.txt │ ├── src │ ├── data │ │ ├── transforms.py │ │ ├── audio_processing.py │ │ ├── vid_aud_lrw_test.py │ │ └── stft.py │ └── models │ │ ├── classifier.py │ │ ├── audio_front.py │ │ └── resnet.py │ └── test.py ├── data ├── LRS2 │ ├── LRS2_crop │ │ └── README.md │ └── README.md ├── LRS3 │ └── LRS3_crop │ │ └── README.md ├── val_4.txt └── test_4.txt ├── img └── Img.PNG ├── src ├── data │ ├── transforms.py │ ├── audio_processing.py │ ├── stft.py │ ├── vid_aud_grid.py │ ├── vid_aud_lrs3.py │ └── vid_aud_lrs2.py └── models │ ├── audio_front.py │ ├── visual_front.py │ ├── resnet.py │ └── generator.py ├── preprocess ├── Extract_frames.py ├── Extract_audio_LRS.py ├── Preprocess.py └── Ref_face.txt ├── README.md ├── README_LRS.md ├── README_GRID.md ├── test.py └── test_LRS.py /ASR_model/GRID/data/README.md: -------------------------------------------------------------------------------- 1 | Put the checkpoints of ASR here. 2 | -------------------------------------------------------------------------------- /ASR_model/LRW/data/README.md: -------------------------------------------------------------------------------- 1 | Put the checkpoints of ASR here. 2 | -------------------------------------------------------------------------------- /data/LRS2/LRS2_crop/README.md: -------------------------------------------------------------------------------- 1 | Put the lip coordinates files of LRS2. 2 | -------------------------------------------------------------------------------- /data/LRS3/LRS3_crop/README.md: -------------------------------------------------------------------------------- 1 | Put the lip coordinates files of LRS3. 2 | -------------------------------------------------------------------------------- /img/Img.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ms-dot-k/Visual-Context-Attentional-GAN/HEAD/img/Img.PNG -------------------------------------------------------------------------------- /data/LRS2/README.md: -------------------------------------------------------------------------------- 1 | Put here the data split files of LRS2. 
2 | -pretrain.txt 3 | -test.txt 4 | -train.txt 5 | -val.txt 6 | -------------------------------------------------------------------------------- /ASR_model/GRID/src/models/classifier.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class Backend(nn.Module): 4 | def __init__(self): 5 | super().__init__() 6 | self.gru = nn.GRU(256, 256, 2, bidirectional=True, dropout=0.3) 7 | self.fc = nn.Linear(512, 27 + 1) 8 | 9 | def forward(self, x): 10 | x = x.permute(1, 0, 2).contiguous() # S,B,96*7*7 11 | self.gru.flatten_parameters() 12 | 13 | x, _ = self.gru(x) # S,B,512 14 | x = x.permute(1, 0, 2).contiguous() 15 | x = self.fc(x) # B, S, 28 16 | return x 17 | 18 | -------------------------------------------------------------------------------- /src/data/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision.transforms.functional as F 3 | 4 | class StatefulRandomHorizontalFlip(): 5 | def __init__(self, probability=0.5): 6 | self.probability = probability 7 | self.rand = random.random() 8 | 9 | def __call__(self, img): 10 | if self.rand < self.probability: 11 | return F.hflip(img) 12 | return img 13 | 14 | def __repr__(self): 15 | return self.__class__.__name__ + '(probability={})'.format(self.probability) 16 | 17 | 18 | class Crop(object): 19 | def __init__(self, crop): 20 | self.crop = crop 21 | 22 | def __call__(self, img): 23 | return img.crop(self.crop) 24 | -------------------------------------------------------------------------------- /ASR_model/GRID/src/data/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision.transforms.functional as F 3 | 4 | class StatefulRandomHorizontalFlip(): 5 | def __init__(self, probability=0.5): 6 | self.probability = probability 7 | self.rand = random.random() 8 | 9 | def __call__(self, img): 10 | if self.rand < self.probability: 11 | return F.hflip(img) 12 | return img 13 | 14 | def __repr__(self): 15 | return self.__class__.__name__ + '(probability={})'.format(self.probability) 16 | 17 | 18 | class Crop(object): 19 | def __init__(self, crop): 20 | self.crop = crop 21 | 22 | def __call__(self, img): 23 | return img.crop(self.crop) 24 | -------------------------------------------------------------------------------- /ASR_model/LRW/src/data/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision.transforms.functional as F 3 | 4 | class StatefulRandomHorizontalFlip(): 5 | def __init__(self, probability=0.5): 6 | self.probability = probability 7 | self.rand = random.random() 8 | 9 | def __call__(self, img): 10 | if self.rand < self.probability: 11 | return F.hflip(img) 12 | return img 13 | 14 | def __repr__(self): 15 | return self.__class__.__name__ + '(probability={})'.format(self.probability) 16 | 17 | 18 | class Crop(object): 19 | def __init__(self, crop): 20 | self.crop = crop 21 | 22 | def __call__(self, img): 23 | return img.crop(self.crop) 24 | -------------------------------------------------------------------------------- /ASR_model/LRW/src/models/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class Backend(nn.Module): 5 | def __init__(self, logits=True): 6 | super().__init__() 7 | self.logits = logits 8 | 9 | self.gru = nn.GRU(512, 512, 2, 
bidirectional=True, dropout=0.3) 10 | if logits: 11 | self.fc = nn.Linear(1024, 500) 12 | 13 | def forward(self, x): 14 | x = x.permute(1, 0, 2).contiguous() # S,B,512 15 | self.gru.flatten_parameters() 16 | 17 | x, _ = self.gru(x) 18 | x = x.mean(0, keepdim=False) 19 | 20 | if self.logits: 21 | pred = self.fc(x) # B, 500 22 | return pred 23 | else: 24 | return x 25 | 26 | -------------------------------------------------------------------------------- /ASR_model/LRW/src/models/audio_front.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from src.models.resnet import BasicBlock 3 | import torch 4 | 5 | class Audio_front(nn.Module): 6 | def __init__(self, in_channels=1): 7 | super().__init__() 8 | 9 | self.in_channels = in_channels 10 | 11 | self.frontend = nn.Sequential( 12 | nn.Conv2d(self.in_channels, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), 13 | nn.BatchNorm2d(128), 14 | nn.PReLU(128), 15 | nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), 16 | nn.BatchNorm2d(256), 17 | nn.PReLU(256) 18 | ) 19 | 20 | self.Res_block = nn.Sequential( 21 | BasicBlock(256, 256) 22 | ) 23 | 24 | self.Linear = nn.Linear(256 * 20, 512) 25 | 26 | self.dropout = nn.Dropout(0.3) 27 | 28 | def forward(self, x): 29 | x = self.frontend(x) #B, 256, F/4, T/4 30 | x = self.Res_block(x) #B, 256, F/4, T/4 31 | b, c, f, t = x.size() 32 | x = x.view(b, c*f, t).transpose(1, 2).contiguous() #B, T/4, 256 * F/4 33 | x = self.dropout(x) 34 | x = self.Linear(x) #B, T/4, 512 35 | return x 36 | 37 | -------------------------------------------------------------------------------- /ASR_model/GRID/src/models/audio_front.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from src.models.resnet import BasicBlock 3 | import torch 4 | 5 | class Audio_front(nn.Module): 6 | def __init__(self, in_channels=1): 7 | super().__init__() 8 | 9 | self.in_channels = in_channels 10 | 11 | self.frontend = nn.Sequential( 12 | nn.Conv2d(self.in_channels, 32, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2)), 13 | nn.BatchNorm2d(32), 14 | nn.PReLU(32), 15 | 16 | nn.Conv2d(32, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2)), 17 | nn.BatchNorm2d(64), 18 | nn.PReLU(64) 19 | ) 20 | 21 | self.Res_block = nn.Sequential( 22 | BasicBlock(64, 64, relu_type='prelu') 23 | ) 24 | 25 | self.Linear = nn.Linear(64 * 20, 256) 26 | 27 | self.dropout = nn.Dropout(0.3) 28 | 29 | def forward(self, x): 30 | x = self.frontend(x) #B, 64, F/4, T/4 31 | x = self.Res_block(x) #B, 64, F/4, T/4 32 | b, c, f, t = x.size() 33 | x = x.view(b, c*f, t).transpose(1, 2).contiguous() #B, T/4, 64 * F/4 34 | x = self.dropout(x) 35 | x = self.Linear(x) #B, T/4, 96 36 | return x 37 | 38 | -------------------------------------------------------------------------------- /src/models/audio_front.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from src.models.resnet import BasicBlock 3 | import torch 4 | 5 | class Audio_front(nn.Module): 6 | def __init__(self, in_channels=1): 7 | super().__init__() 8 | 9 | self.in_channels = in_channels 10 | 11 | self.frontend = nn.Sequential( 12 | nn.Conv2d(self.in_channels, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), 13 | nn.BatchNorm2d(128), 14 | nn.PReLU(128), 15 | nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), 16 | nn.BatchNorm2d(256), 17 | nn.PReLU(256) 18 | ) 19 | 20 | self.Res_block 
= nn.Sequential( 21 | BasicBlock(256, 256) 22 | ) 23 | 24 | self.Linear = nn.Linear(256 * 20, 512) 25 | 26 | self.dropout = nn.Dropout(0.3) 27 | 28 | def forward(self, x): 29 | x = self.frontend(x) #B, 256, F/4, T/4 30 | x = self.Res_block(x) #B, 256, F/4, T/4 31 | b, c, f, t = x.size() 32 | x = x.view(b, c*f, t).transpose(1, 2).contiguous() #B, T/4, 256 * F/4 33 | x = self.dropout(x) 34 | x = self.Linear(x) #B, T/4, 512 35 | x = x.permute(1, 0, 2).contiguous() # S,B,256 36 | return x 37 | 38 | -------------------------------------------------------------------------------- /preprocess/Extract_frames.py: -------------------------------------------------------------------------------- 1 | import os, glob, subprocess 2 | import argparse 3 | 4 | def parse_args(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--Grid_dir', type=str, default="Data dir to GRID_corpus") 7 | parser.add_argument("--Output_dir", type=str, default='Output dir Ex) ./GRID_imgs_aud') 8 | args = parser.parse_args() 9 | return args 10 | 11 | args = parse_args() 12 | 13 | vid_files = sorted(glob.glob(os.path.join(args.Grid_dir, '*', 'video', '*.mpg'))) #suppose the directory: Data_dir/subject/video/mpg files 14 | for k, v in enumerate(vid_files): 15 | t, f_name = os.path.split(v) 16 | t, _ = os.path.split(t) 17 | _, sub_name = os.path.split(t) 18 | out_im = os.path.join(args.Output_dir, sub_name, 'video', f_name[:-4]) 19 | if len(glob.glob(os.path.join(out_im, '*.png'))) < 75: # Can resume after killed 20 | if not os.path.exists(out_im): 21 | os.makedirs(out_im) 22 | out_aud = os.path.join(args.Output_dir, sub_name, 'audio') 23 | if not os.path.exists(out_aud): 24 | os.makedirs(out_aud) 25 | subprocess.call(f'ffmpeg -y -i {v} -qscale:v 2 -r 25 {out_im}/%02d.png', shell=True) 26 | subprocess.call(f'ffmpeg -y -i {v} -ac 1 -acodec pcm_s16le -ar 16000 {os.path.join(out_aud, f_name[:-4] + ".wav")}', shell=True) 27 | print(f'{k}/{len(vid_files)}') 28 | -------------------------------------------------------------------------------- /src/models/visual_front.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from src.models.resnet import BasicBlock, ResNet 3 | 4 | class Visual_front(nn.Module): 5 | def __init__(self, in_channels=1): 6 | super().__init__() 7 | 8 | self.in_channels = in_channels 9 | 10 | self.frontend = nn.Sequential( 11 | nn.Conv3d(self.in_channels, 64, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False), #44,44 12 | nn.BatchNorm3d(64), 13 | nn.PReLU(64), 14 | nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) #28,28 15 | ) 16 | 17 | self.resnet = ResNet(BasicBlock, [2, 2, 2, 2], relu_type='prelu') 18 | self.dropout = nn.Dropout(0.3) 19 | 20 | self.sentence_encoder = nn.GRU(512, 512, 2, bidirectional=True, dropout=0.3) 21 | self.fc = nn.Linear(1024, 512) 22 | 23 | def forward(self, x): 24 | #B,C,S,H,W 25 | x = self.frontend(x) #B,C,T,H,W 26 | B, C, T, H, W = x.size() 27 | x = x.transpose(1, 2).contiguous().view(B*T, C, H, W) 28 | x = self.resnet(x) # B*T, 512 29 | x = self.dropout(x) 30 | x = x.view(B, T, -1) 31 | phons = x.permute(1, 0, 2).contiguous() # S,B,512 32 | 33 | self.sentence_encoder.flatten_parameters() 34 | sentence, _ = self.sentence_encoder(phons) 35 | sentence = self.fc(sentence).permute(1, 2, 0).contiguous() # B,512,T 36 | 37 | return phons.permute(1, 0, 2), sentence 38 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Lip to Speech Synthesis with Visual Context Attentional GAN 2 | 3 | This repository contains the PyTorch implementation of the following paper: 4 | > **Lip to Speech Synthesis with Visual Context Attentional GAN**
5 | > Minsu Kim, Joanna Hong, and Yong Man Ro
6 | > \[[Paper](https://proceedings.neurips.cc/paper/2021/file/16437d40c29a1a7b1e78143c9c38f289-Paper.pdf)\] \[[Demo Video](https://kaistackr-my.sharepoint.com/:v:/g/personal/ms_k_kaist_ac_kr/EQp2Zao1ZQFDm9xDVuZubKIB_ns_6gk0L6LB3U5jd4jYKw?e=Qw8ddt)\] 7 | 8 |
9 | 10 | ## Requirements 11 | - python 3.7 12 | - pytorch 1.6 ~ 1.8 13 | - torchvision 14 | - torchaudio 15 | - ffmpeg 16 | - av 17 | - tensorboard 18 | - scikit-image 0.17.0 ~ 19 | - opencv-python 3.4 ~ 20 | - pillow 21 | - librosa 22 | - pystoi 23 | - pesq 24 | - scipy 25 | 26 | ## GRID 27 | Please refer [here](README_GRID.md) to run the code on GRID dataset. 28 | 29 | ## LRS2/LRS3 30 | Please refer [here](README_LRS.md) to run the code and model on LRS2 and LRS3 datasets. 31 | 32 | ## Citation 33 | If you find this work useful in your research, please cite the papers: 34 | ``` 35 | @article{kim2021vcagan, 36 | title={Lip to Speech Synthesis with Visual Context Attentional GAN}, 37 | author={Kim, Minsu and Hong, Joanna and Ro, Yong Man}, 38 | journal={Advances in Neural Information Processing Systems}, 39 | volume={34}, 40 | year={2021} 41 | } 42 | 43 | @inproceedings{kim2023lip, 44 | title={Lip-to-speech synthesis in the wild with multi-task learning}, 45 | author={Kim, Minsu and Hong, Joanna and Ro, Yong Man}, 46 | booktitle={ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 47 | pages={1--5}, 48 | year={2023}, 49 | organization={IEEE} 50 | } 51 | ``` 52 | -------------------------------------------------------------------------------- /preprocess/Extract_audio_LRS.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | import argparse 3 | from tqdm import tqdm 4 | from joblib import Parallel, delayed 5 | 6 | 7 | def build_file_list(data_path, data_type): 8 | if data_type == 'LRS2': 9 | files = sorted(glob.glob(os.path.join(data_path, 'main', '*', '*.mp4'))) 10 | files.extend(glob.glob(os.path.join(data_path, 'pretrain', '*', '*.mp4'))) 11 | elif data_type == 'LRS3': 12 | files = sorted(glob.glob(os.path.join(data_path, 'trainval', '*', '*.mp4'))) 13 | files.extend(glob.glob(os.path.join(data_path, 'pretrain', '*', '*.mp4'))) 14 | files.extend(glob.glob(os.path.join(data_path, 'test', '*', '*.mp4'))) 15 | else: 16 | raise NotImplementedError 17 | return [f.replace(data_path + '/', '')[:-4] for f in files] 18 | 19 | def per_file(f, args): 20 | save_path = os.path.join(args.save_path, f) 21 | if os.path.exists(save_path + '.wav'): return 22 | if not os.path.exists(os.path.dirname(save_path)): os.makedirs(os.path.dirname(save_path), exist_ok=True) 23 | vid_name = os.path.join(args.data_path, f + '.mp4') 24 | os.system( 25 | f'ffmpeg -loglevel panic -nostdin -y -i {vid_name} -acodec pcm_s16le -ar 16000 -ac 1 {save_path}.wav') 26 | 27 | def main(): 28 | parser = get_parser() 29 | args = parser.parse_args() 30 | file_lists = build_file_list(args.data_path, args.data_type) 31 | Parallel(n_jobs=3)(delayed(per_file)(f, args) for f in tqdm(file_lists)) 32 | 33 | def get_parser(): 34 | parser = argparse.ArgumentParser( 35 | description="Command-line script for preprocessing." 
36 | ) 37 | parser.add_argument( 38 | "--data_path", type=str, required=True, help="path for original data" 39 | ) 40 | parser.add_argument( 41 | "--save_path", type=str, required=True, help="path for saving" 42 | ) 43 | parser.add_argument( 44 | "--data_type", type=str, required=True, help="LRS2 or LRS3" 45 | ) 46 | return parser 47 | 48 | 49 | if __name__ == "__main__": 50 | main() -------------------------------------------------------------------------------- /src/data/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.signal import get_window 4 | import librosa.util as librosa_util 5 | 6 | 7 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800, 8 | n_fft=800, dtype=np.float32, norm=None): 9 | """ 10 | # from librosa 0.6 11 | Compute the sum-square envelope of a window function at a given hop length. 12 | This is used to estimate modulation effects induced by windowing 13 | observations in short-time fourier transforms. 14 | Parameters 15 | ---------- 16 | window : string, tuple, number, callable, or list-like 17 | Window specification, as in `get_window` 18 | n_frames : int > 0 19 | The number of analysis frames 20 | hop_length : int > 0 21 | The number of samples to advance between frames 22 | win_length : [optional] 23 | The length of the window function. By default, this matches `n_fft`. 24 | n_fft : int > 0 25 | The length of each analysis frame. 26 | dtype : np.dtype 27 | The data type of the output 28 | Returns 29 | ------- 30 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 31 | The sum-squared envelope of the window function 32 | """ 33 | if win_length is None: 34 | win_length = n_fft 35 | 36 | n = n_fft + hop_length * (n_frames - 1) 37 | x = np.zeros(n, dtype=dtype) 38 | 39 | # Compute the squared window at the desired length 40 | win_sq = get_window(window, win_length, fftbins=True) 41 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 42 | win_sq = librosa_util.pad_center(win_sq, n_fft) 43 | 44 | # Fill the envelope 45 | for i in range(n_frames): 46 | sample = i * hop_length 47 | x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 48 | return x 49 | 50 | 51 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 52 | """ 53 | PARAMS 54 | ------ 55 | magnitudes: spectrogram magnitudes 56 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 57 | """ 58 | 59 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 60 | angles = angles.astype(np.float32) 61 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 62 | angles = angles.cuda() if magnitudes.is_cuda else angles 63 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 64 | 65 | for i in range(n_iters): 66 | _, angles = stft_fn.transform(signal) 67 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 68 | return signal 69 | 70 | 71 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 72 | """ 73 | PARAMS 74 | ------ 75 | C: compression factor 76 | """ 77 | return torch.log(torch.clamp(x, min=clip_val) * C) 78 | 79 | 80 | def dynamic_range_decompression(x, C=1): 81 | """ 82 | PARAMS 83 | ------ 84 | C: compression factor used to compress 85 | """ 86 | return torch.exp(x) / C -------------------------------------------------------------------------------- /ASR_model/GRID/src/data/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import numpy as np 3 | from scipy.signal import get_window 4 | import librosa.util as librosa_util 5 | 6 | 7 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800, 8 | n_fft=800, dtype=np.float32, norm=None): 9 | """ 10 | # from librosa 0.6 11 | Compute the sum-square envelope of a window function at a given hop length. 12 | This is used to estimate modulation effects induced by windowing 13 | observations in short-time fourier transforms. 14 | Parameters 15 | ---------- 16 | window : string, tuple, number, callable, or list-like 17 | Window specification, as in `get_window` 18 | n_frames : int > 0 19 | The number of analysis frames 20 | hop_length : int > 0 21 | The number of samples to advance between frames 22 | win_length : [optional] 23 | The length of the window function. By default, this matches `n_fft`. 24 | n_fft : int > 0 25 | The length of each analysis frame. 26 | dtype : np.dtype 27 | The data type of the output 28 | Returns 29 | ------- 30 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 31 | The sum-squared envelope of the window function 32 | """ 33 | if win_length is None: 34 | win_length = n_fft 35 | 36 | n = n_fft + hop_length * (n_frames - 1) 37 | x = np.zeros(n, dtype=dtype) 38 | 39 | # Compute the squared window at the desired length 40 | win_sq = get_window(window, win_length, fftbins=True) 41 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 42 | win_sq = librosa_util.pad_center(win_sq, n_fft) 43 | 44 | # Fill the envelope 45 | for i in range(n_frames): 46 | sample = i * hop_length 47 | x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 48 | return x 49 | 50 | 51 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 52 | """ 53 | PARAMS 54 | ------ 55 | magnitudes: spectrogram magnitudes 56 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 57 | """ 58 | 59 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 60 | angles = angles.astype(np.float32) 61 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 62 | angles = angles.cuda() if magnitudes.is_cuda else angles 63 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 64 | 65 | for i in range(n_iters): 66 | _, angles = stft_fn.transform(signal) 67 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 68 | return signal 69 | 70 | 71 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 72 | """ 73 | PARAMS 74 | ------ 75 | C: compression factor 76 | """ 77 | return torch.log(torch.clamp(x, min=clip_val) * C) 78 | 79 | 80 | def dynamic_range_decompression(x, C=1): 81 | """ 82 | PARAMS 83 | ------ 84 | C: compression factor used to compress 85 | """ 86 | return torch.exp(x) / C -------------------------------------------------------------------------------- /ASR_model/LRW/src/data/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.signal import get_window 4 | import librosa.util as librosa_util 5 | 6 | 7 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800, 8 | n_fft=800, dtype=np.float32, norm=None): 9 | """ 10 | # from librosa 0.6 11 | Compute the sum-square envelope of a window function at a given hop length. 12 | This is used to estimate modulation effects induced by windowing 13 | observations in short-time fourier transforms. 
14 | Parameters 15 | ---------- 16 | window : string, tuple, number, callable, or list-like 17 | Window specification, as in `get_window` 18 | n_frames : int > 0 19 | The number of analysis frames 20 | hop_length : int > 0 21 | The number of samples to advance between frames 22 | win_length : [optional] 23 | The length of the window function. By default, this matches `n_fft`. 24 | n_fft : int > 0 25 | The length of each analysis frame. 26 | dtype : np.dtype 27 | The data type of the output 28 | Returns 29 | ------- 30 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 31 | The sum-squared envelope of the window function 32 | """ 33 | if win_length is None: 34 | win_length = n_fft 35 | 36 | n = n_fft + hop_length * (n_frames - 1) 37 | x = np.zeros(n, dtype=dtype) 38 | 39 | # Compute the squared window at the desired length 40 | win_sq = get_window(window, win_length, fftbins=True) 41 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 42 | win_sq = librosa_util.pad_center(win_sq, n_fft) 43 | 44 | # Fill the envelope 45 | for i in range(n_frames): 46 | sample = i * hop_length 47 | x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 48 | return x 49 | 50 | 51 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 52 | """ 53 | PARAMS 54 | ------ 55 | magnitudes: spectrogram magnitudes 56 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 57 | """ 58 | 59 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 60 | angles = angles.astype(np.float32) 61 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 62 | angles = angles.cuda() if magnitudes.is_cuda else angles 63 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 64 | 65 | for i in range(n_iters): 66 | _, angles = stft_fn.transform(signal) 67 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 68 | return signal 69 | 70 | 71 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 72 | """ 73 | PARAMS 74 | ------ 75 | C: compression factor 76 | """ 77 | return torch.log(torch.clamp(x, min=clip_val) * C) 78 | 79 | 80 | def dynamic_range_decompression(x, C=1): 81 | """ 82 | PARAMS 83 | ------ 84 | C: compression factor used to compress 85 | """ 86 | return torch.exp(x) / C -------------------------------------------------------------------------------- /src/models/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 7 | padding=1, bias=False) 8 | 9 | 10 | def downsample_basic_block(inplanes, outplanes, stride): 11 | return nn.Sequential( 12 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False), 13 | nn.BatchNorm2d(outplanes), 14 | ) 15 | 16 | 17 | def downsample_basic_block_v2(inplanes, outplanes, stride): 18 | return nn.Sequential( 19 | nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False), 20 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False), 21 | nn.BatchNorm2d(outplanes), 22 | ) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, relu_type='relu'): 29 | super(BasicBlock, self).__init__() 30 | 31 | assert relu_type in ['relu', 'prelu'] 32 | 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = nn.BatchNorm2d(planes) 35 | 36 | # type of ReLU is an input 
option 37 | if relu_type == 'relu': 38 | self.relu1 = nn.ReLU(inplace=True) 39 | self.relu2 = nn.ReLU(inplace=True) 40 | elif relu_type == 'prelu': 41 | self.relu1 = nn.PReLU(num_parameters=planes) 42 | self.relu2 = nn.PReLU(num_parameters=planes) 43 | else: 44 | raise Exception('relu type not implemented') 45 | # -------- 46 | 47 | self.conv2 = conv3x3(planes, planes) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | 50 | self.downsample = downsample 51 | self.stride = stride 52 | 53 | def forward(self, x): 54 | residual = x 55 | out = self.conv1(x) 56 | out = self.bn1(out) 57 | out = self.relu1(out) 58 | out = self.conv2(out) 59 | out = self.bn2(out) 60 | if self.downsample is not None: 61 | residual = self.downsample(x) 62 | 63 | out += residual 64 | out = self.relu2(out) 65 | 66 | return out 67 | 68 | 69 | class ResNet(nn.Module): 70 | 71 | def __init__(self, block, layers, num_classes=1000, relu_type='relu', gamma_zero=False, avg_pool_downsample=False): 72 | self.inplanes = 64 73 | self.relu_type = relu_type 74 | self.gamma_zero = gamma_zero 75 | self.downsample_block = downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block 76 | 77 | super(ResNet, self).__init__() 78 | self.layer1 = self._make_layer(block, 64, layers[0]) 79 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 80 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 81 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 82 | self.avgpool = nn.AvgPool2d(4) 83 | 84 | # default init 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 88 | m.weight.data.normal_(0, math.sqrt(2. / n)) 89 | elif isinstance(m, nn.BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | # nn.init.ones_(m.weight) 93 | # nn.init.zeros_(m.bias) 94 | 95 | if self.gamma_zero: 96 | for m in self.modules(): 97 | if isinstance(m, BasicBlock): 98 | m.bn2.weight.data.zero_() 99 | 100 | def _make_layer(self, block, planes, blocks, stride=1): 101 | 102 | downsample = None 103 | if stride != 1 or self.inplanes != planes * block.expansion: 104 | downsample = self.downsample_block(inplanes=self.inplanes, 105 | outplanes=planes * block.expansion, 106 | stride=stride) 107 | 108 | layers = [] 109 | layers.append(block(self.inplanes, planes, stride, downsample, relu_type=self.relu_type)) 110 | self.inplanes = planes * block.expansion 111 | for i in range(1, blocks): 112 | layers.append(block(self.inplanes, planes, relu_type=self.relu_type)) 113 | 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x): 117 | x = self.layer1(x) 118 | x = self.layer2(x) 119 | x = self.layer3(x) 120 | x = self.layer4(x) 121 | x = self.avgpool(x) 122 | x = x.view(x.size(0), -1) 123 | return x 124 | -------------------------------------------------------------------------------- /ASR_model/GRID/src/models/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 7 | padding=1, bias=False) 8 | 9 | def downsample_basic_block(inplanes, outplanes, stride): 10 | return nn.Sequential( 11 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False), 12 | nn.BatchNorm2d(outplanes), 13 | ) 14 | 15 | 16 | def downsample_basic_block_v2(inplanes, outplanes, stride): 17 | return nn.Sequential( 18 | 
nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False), 19 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False), 20 | nn.BatchNorm2d(outplanes), 21 | ) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, relu_type='relu'): 28 | super(BasicBlock, self).__init__() 29 | 30 | assert relu_type in ['relu', 'prelu'] 31 | 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | 35 | # type of ReLU is an input option 36 | if relu_type == 'relu': 37 | self.relu1 = nn.ReLU(inplace=True) 38 | self.relu2 = nn.ReLU(inplace=True) 39 | elif relu_type == 'prelu': 40 | self.relu1 = nn.PReLU(num_parameters=planes) 41 | self.relu2 = nn.PReLU(num_parameters=planes) 42 | else: 43 | raise Exception('relu type not implemented') 44 | # -------- 45 | 46 | self.conv2 = conv3x3(planes, planes) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | 49 | self.downsample = downsample 50 | self.stride = stride 51 | 52 | def forward(self, x): 53 | residual = x 54 | out = self.conv1(x) 55 | out = self.bn1(out) 56 | out = self.relu1(out) 57 | out = self.conv2(out) 58 | out = self.bn2(out) 59 | if self.downsample is not None: 60 | residual = self.downsample(x) 61 | 62 | out += residual 63 | out = self.relu2(out) 64 | 65 | return out 66 | 67 | 68 | class ResNet(nn.Module): 69 | 70 | def __init__(self, block, layers, num_classes=1000, relu_type='relu', gamma_zero=False, avg_pool_downsample=False): 71 | self.inplanes = 64 72 | self.relu_type = relu_type 73 | self.gamma_zero = gamma_zero 74 | self.downsample_block = downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block 75 | 76 | super(ResNet, self).__init__() 77 | self.layer1 = self._make_layer(block, 64, layers[0]) 78 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 79 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 80 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 81 | self.avgpool = nn.AvgPool2d(3) 82 | 83 | # default init 84 | for m in self.modules(): 85 | if isinstance(m, nn.Conv2d): 86 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 87 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 88 | elif isinstance(m, nn.BatchNorm2d): 89 | m.weight.data.fill_(1) 90 | m.bias.data.zero_() 91 | # nn.init.ones_(m.weight) 92 | # nn.init.zeros_(m.bias) 93 | 94 | if self.gamma_zero: 95 | for m in self.modules(): 96 | if isinstance(m, BasicBlock): 97 | m.bn2.weight.data.zero_() 98 | 99 | def _make_layer(self, block, planes, blocks, stride=1): 100 | 101 | downsample = None 102 | if stride != 1 or self.inplanes != planes * block.expansion: 103 | downsample = self.downsample_block(inplanes=self.inplanes, 104 | outplanes=planes * block.expansion, 105 | stride=stride) 106 | 107 | layers = [] 108 | layers.append(block(self.inplanes, planes, stride, downsample, relu_type=self.relu_type)) 109 | self.inplanes = planes * block.expansion 110 | for i in range(1, blocks): 111 | layers.append(block(self.inplanes, planes, relu_type=self.relu_type)) 112 | 113 | return nn.Sequential(*layers) 114 | 115 | def forward(self, x): 116 | x = self.layer1(x) 117 | x = self.layer2(x) 118 | x = self.layer3(x) 119 | x = self.layer4(x) 120 | x = self.avgpool(x) 121 | x = x.view(x.size(0), -1) 122 | return x 123 | -------------------------------------------------------------------------------- /ASR_model/LRW/src/models/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 7 | padding=1, bias=False) 8 | 9 | 10 | def downsample_basic_block(inplanes, outplanes, stride): 11 | return nn.Sequential( 12 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False), 13 | nn.BatchNorm2d(outplanes), 14 | ) 15 | 16 | 17 | def downsample_basic_block_v2(inplanes, outplanes, stride): 18 | return nn.Sequential( 19 | nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False), 20 | nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False), 21 | nn.BatchNorm2d(outplanes), 22 | ) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, relu_type='relu'): 29 | super(BasicBlock, self).__init__() 30 | 31 | assert relu_type in ['relu', 'prelu'] 32 | 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = nn.BatchNorm2d(planes) 35 | 36 | # type of ReLU is an input option 37 | if relu_type == 'relu': 38 | self.relu1 = nn.ReLU(inplace=True) 39 | self.relu2 = nn.ReLU(inplace=True) 40 | elif relu_type == 'prelu': 41 | self.relu1 = nn.PReLU(num_parameters=planes) 42 | self.relu2 = nn.PReLU(num_parameters=planes) 43 | else: 44 | raise Exception('relu type not implemented') 45 | # -------- 46 | 47 | self.conv2 = conv3x3(planes, planes) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | 50 | self.downsample = downsample 51 | self.stride = stride 52 | 53 | def forward(self, x): 54 | residual = x 55 | out = self.conv1(x) 56 | out = self.bn1(out) 57 | out = self.relu1(out) 58 | out = self.conv2(out) 59 | out = self.bn2(out) 60 | if self.downsample is not None: 61 | residual = self.downsample(x) 62 | 63 | out += residual 64 | out = self.relu2(out) 65 | 66 | return out 67 | 68 | 69 | class ResNet(nn.Module): 70 | 71 | def __init__(self, block, layers, num_classes=1000, relu_type='relu', gamma_zero=False, avg_pool_downsample=False): 72 | self.inplanes = 64 73 | self.relu_type = relu_type 74 | self.gamma_zero = gamma_zero 75 | self.downsample_block = downsample_basic_block_v2 if avg_pool_downsample else 
downsample_basic_block 76 | 77 | super(ResNet, self).__init__() 78 | self.layer1 = self._make_layer(block, 64, layers[0]) 79 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 80 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 81 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 82 | self.avgpool = nn.AvgPool2d(3) 83 | 84 | # default init 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 88 | m.weight.data.normal_(0, math.sqrt(2. / n)) 89 | elif isinstance(m, nn.BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | # nn.init.ones_(m.weight) 93 | # nn.init.zeros_(m.bias) 94 | 95 | if self.gamma_zero: 96 | for m in self.modules(): 97 | if isinstance(m, BasicBlock): 98 | m.bn2.weight.data.zero_() 99 | 100 | def _make_layer(self, block, planes, blocks, stride=1): 101 | 102 | downsample = None 103 | if stride != 1 or self.inplanes != planes * block.expansion: 104 | downsample = self.downsample_block(inplanes=self.inplanes, 105 | outplanes=planes * block.expansion, 106 | stride=stride) 107 | 108 | layers = [] 109 | layers.append(block(self.inplanes, planes, stride, downsample, relu_type=self.relu_type)) 110 | self.inplanes = planes * block.expansion 111 | for i in range(1, blocks): 112 | layers.append(block(self.inplanes, planes, relu_type=self.relu_type)) 113 | 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x): 117 | x = self.layer1(x) 118 | x = self.layer2(x) 119 | x = self.layer3(x) 120 | x = self.layer4(x) 121 | x = self.avgpool(x) 122 | x = x.view(x.size(0), -1) 123 | return x 124 | -------------------------------------------------------------------------------- /data/val_4.txt: -------------------------------------------------------------------------------- 1 | s1/video/srih1s.mp4 2 | s1/video/srih2p.mp4 3 | s1/video/lwbz6p.mp4 4 | s1/video/bwwh5a.mp4 5 | s1/video/lbbq8n.mp4 6 | s1/video/lbbk5s.mp4 7 | s1/video/lwbl6n.mp4 8 | s1/video/srin5s.mp4 9 | s1/video/lwie6p.mp4 10 | s1/video/bwwn8p.mp4 11 | s1/video/pbib9a.mp4 12 | s1/video/sgai6n.mp4 13 | s1/video/bgaa9a.mp4 14 | s1/video/praj2p.mp4 15 | s1/video/bgig9a.mp4 16 | s1/video/bbwg1s.mp4 17 | s1/video/lwaz1s.mp4 18 | s1/video/sgbj2p.mp4 19 | s1/video/bbas1s.mp4 20 | s1/video/pwbd6n.mp4 21 | s1/video/lrwr8n.mp4 22 | s1/video/lwws5s.mp4 23 | s1/video/swao5s.mp4 24 | s1/video/pwix2p.mp4 25 | s1/video/priv4n.mp4 26 | s1/video/bgwu8p.mp4 27 | s1/video/sgav4n.mp4 28 | s1/video/pwaj7s.mp4 29 | s1/video/lgwm8p.mp4 30 | s1/video/pwwk7a.mp4 31 | s1/video/pgwl3a.mp4 32 | s1/video/pwbd7s.mp4 33 | s1/video/pgby6p.mp4 34 | s1/video/lbwr4p.mp4 35 | s1/video/sbwu6p.mp4 36 | s1/video/swbc2p.mp4 37 | s1/video/lwws7a.mp4 38 | s1/video/sgai9a.mp4 39 | s1/video/lwal2n.mp4 40 | s1/video/brwg6n.mp4 41 | s1/video/bbal7s.mp4 42 | s1/video/srbo5a.mp4 43 | s1/video/prip2p.mp4 44 | s1/video/pwwe1s.mp4 45 | s1/video/bbiz1s.mp4 46 | s1/video/srwi2n.mp4 47 | s1/video/bgahzn.mp4 48 | s1/video/lbwk9s.mp4 49 | s1/video/pgiq4p.mp4 50 | s2/video/lbakzs.mp4 51 | s2/video/pwwe1p.mp4 52 | s2/video/brwa1n.mp4 53 | s2/video/lriyzs.mp4 54 | s2/video/bwim4s.mp4 55 | s2/video/pgby6a.mp4 56 | s2/video/lrwr7n.mp4 57 | s2/video/lbwlza.mp4 58 | s2/video/bbir8a.mp4 59 | s2/video/lwbs2a.mp4 60 | s2/video/lgwtzs.mp4 61 | s2/video/pwaq2a.mp4 62 | s2/video/lbad6s.mp4 63 | s2/video/lgir9p.mp4 64 | s2/video/lbij7p.mp4 65 | s2/video/pgayzs.mp4 66 | s2/video/brimza.mp4 67 | s2/video/lrwz1n.mp4 68 | s2/video/lwbs1p.mp4 69 | 
s2/video/srah3n.mp4 70 | s2/video/sbbt9n.mp4 71 | s2/video/praj2a.mp4 72 | s2/video/lbbk3n.mp4 73 | s2/video/swao5p.mp4 74 | s2/video/bwbt8a.mp4 75 | s2/video/swwi8s.mp4 76 | s2/video/lgil3n.mp4 77 | s2/video/bwbn4a.mp4 78 | s2/video/lbaq5p.mp4 79 | s2/video/bbws7n.mp4 80 | s2/video/bbir5n.mp4 81 | s2/video/swbv2s.mp4 82 | s2/video/srwb9p.mp4 83 | s2/video/sgai7p.mp4 84 | s2/video/sgbjzs.mp4 85 | s2/video/bril7n.mp4 86 | s2/video/srwb7n.mp4 87 | s2/video/brbg4a.mp4 88 | s2/video/bwbn1n.mp4 89 | s2/video/sgbp4s.mp4 90 | s2/video/bgan4s.mp4 91 | s2/video/lbbk5p.mp4 92 | s2/video/pbwj3p.mp4 93 | s2/video/srah6a.mp4 94 | s2/video/pgiq4a.mp4 95 | s2/video/prwd4s.mp4 96 | s2/video/lgas4a.mp4 97 | s2/video/lrak5n.mp4 98 | s2/video/lwws4s.mp4 99 | s2/video/lbby3p.mp4 100 | s4/video/brwazs.mp4 101 | s4/video/lraq7n.mp4 102 | s4/video/sgai4s.mp4 103 | s4/video/srwb5n.mp4 104 | s4/video/lgaz3n.mp4 105 | s4/video/srau1p.mp4 106 | s4/video/pbbv5p.mp4 107 | s4/video/sgbc6a.mp4 108 | s4/video/prap2s.mp4 109 | s4/video/lgir5n.mp4 110 | s4/video/bwwn6a.mp4 111 | s4/video/sgwj1n.mp4 112 | s4/video/lgbl9n.mp4 113 | s4/video/lbip9p.mp4 114 | s4/video/sbig3p.mp4 115 | s4/video/pgix4s.mp4 116 | s4/video/sbbuza.mp4 117 | s4/video/pbbi5n.mp4 118 | s4/video/lbwk6s.mp4 119 | s4/video/pbwp6a.mp4 120 | s4/video/pwbj7n.mp4 121 | s4/video/lwwl9p.mp4 122 | s4/video/bwwn5p.mp4 123 | s4/video/brwt3p.mp4 124 | s4/video/sgwx1p.mp4 125 | s4/video/bras4s.mp4 126 | s4/video/swao4a.mp4 127 | s4/video/srwu8s.mp4 128 | s4/video/pgbq7n.mp4 129 | s4/video/sgih9n.mp4 130 | s4/video/pwwx9n.mp4 131 | s4/video/lwar3n.mp4 132 | s4/video/bgan2s.mp4 133 | s4/video/pbib5p.mp4 134 | s4/video/pwbq4a.mp4 135 | s4/video/lbaq1n.mp4 136 | s4/video/sbim6s.mp4 137 | s4/video/brif4a.mp4 138 | s4/video/pbwxza.mp4 139 | s4/video/lwbr9p.mp4 140 | s4/video/priv3p.mp4 141 | s4/video/prwd3p.mp4 142 | s4/video/lbwe4a.mp4 143 | s4/video/lrae2a.mp4 144 | s4/video/bwam7p.mp4 145 | s4/video/sbat3n.mp4 146 | s4/video/pgwk8s.mp4 147 | s4/video/swbv2a.mp4 148 | s4/video/lbij5p.mp4 149 | s4/video/lbaj9p.mp4 150 | s29/video/bric7s.mp4 151 | s29/video/swiz1s.mp4 152 | s29/video/pbif4p.mp4 153 | s29/video/sgbg3s.mp4 154 | s29/video/pwin1a.mp4 155 | s29/video/lwwj5a.mp4 156 | s29/video/lrwc5a.mp4 157 | s29/video/bbbp6n.mp4 158 | s29/video/srar4n.mp4 159 | s29/video/pratzn.mp4 160 | s29/video/lgax1a.mp4 161 | s29/video/lbauzn.mp4 162 | s29/video/sbwl4p.mp4 163 | s29/video/srwz7s.mp4 164 | s29/video/lgwx6n.mp4 165 | s29/video/sgbt1s.mp4 166 | s29/video/bwwk9s.mp4 167 | s29/video/pwbb1a.mp4 168 | s29/video/lgipzn.mp4 169 | s29/video/pgbu9a.mp4 170 | s29/video/lgic4p.mp4 171 | s29/video/sbix8n.mp4 172 | s29/video/sgwt4n.mp4 173 | s29/video/lbau3a.mp4 174 | s29/video/sbikzn.mp4 175 | s29/video/lrao2n.mp4 176 | s29/video/lgbq1a.mp4 177 | s29/video/bgwl6p.mp4 178 | s29/video/pgwi4p.mp4 179 | s29/video/pwwozn.mp4 180 | s29/video/lwav4p.mp4 181 | s29/video/lrin9s.mp4 182 | s29/video/pwag9s.mp4 183 | s29/video/srbl4n.mp4 184 | s29/video/sbiq7a.mp4 185 | s29/video/lwwp7s.mp4 186 | s29/video/bgik2n.mp4 187 | s29/video/swir8p.mp4 188 | s29/video/pbaz8n.mp4 189 | s29/video/srwm1a.mp4 190 | s29/video/lwbc4n.mp4 191 | s29/video/lrwi9a.mp4 192 | s29/video/sray9s.mp4 193 | s29/video/srae6n.mp4 194 | s29/video/srazzp.mp4 195 | s29/video/bbac6p.mp4 196 | s29/video/pgwo6n.mp4 197 | s29/video/lrbv3a.mp4 198 | s29/video/sbwzzn.mp4 199 | s29/video/bbbj5a.mp4 200 | -------------------------------------------------------------------------------- /data/test_4.txt: 
-------------------------------------------------------------------------------- 1 | s1/video/sbag8n.mp4 2 | s1/video/pbwj2n.mp4 3 | s1/video/bwat5a.mp4 4 | s1/video/pwbyzp.mp4 5 | s1/video/lgiz2n.mp4 6 | s1/video/sbwu4n.mp4 7 | s1/video/sbwo3a.mp4 8 | s1/video/brbm6n.mp4 9 | s1/video/swwc5s.mp4 10 | s1/video/bwaa3a.mp4 11 | s1/video/bwig2p.mp4 12 | s1/video/bwaa2p.mp4 13 | s1/video/sgbc6n.mp4 14 | s1/video/briz7s.mp4 15 | s1/video/lrie1a.mp4 16 | s1/video/pwip9a.mp4 17 | s1/video/swau8n.mp4 18 | s1/video/lrbr4n.mp4 19 | s1/video/lgwm9a.mp4 20 | s1/video/lwbl7s.mp4 21 | s1/video/pwidzp.mp4 22 | s1/video/bril9s.mp4 23 | s1/video/swiu5s.mp4 24 | s1/video/swwp2n.mp4 25 | s1/video/lrwszp.mp4 26 | s1/video/bwat4p.mp4 27 | s1/video/lwazzn.mp4 28 | s1/video/pwwq8n.mp4 29 | s1/video/sgav7a.mp4 30 | s1/video/brwg7s.mp4 31 | s1/video/bgbb3a.mp4 32 | s1/video/sgbc9a.mp4 33 | s1/video/swbc3a.mp4 34 | s1/video/sgbx1a.mp4 35 | s1/video/srit9s.mp4 36 | s1/video/bbwm4n.mp4 37 | s1/video/bbaf2n.mp4 38 | s1/video/srbu6n.mp4 39 | s1/video/bwwn9a.mp4 40 | s1/video/srwv1s.mp4 41 | s1/video/prii8p.mp4 42 | s1/video/lgbs6n.mp4 43 | s1/video/lwal3s.mp4 44 | s1/video/bwit1a.mp4 45 | s1/video/pbio6p.mp4 46 | s1/video/pwwq9s.mp4 47 | s1/video/lrbe6n.mp4 48 | s1/video/pwaq3a.mp4 49 | s1/video/lbad6n.mp4 50 | s1/video/sgii3s.mp4 51 | s2/video/pwwezs.mp4 52 | s2/video/lwar7p.mp4 53 | s2/video/pbao7n.mp4 54 | s2/video/lwae9p.mp4 55 | s2/video/bbbz9p.mp4 56 | s2/video/pgby4s.mp4 57 | s2/video/lgws9n.mp4 58 | s2/video/pbwp7p.mp4 59 | s2/video/sgbi9n.mp4 60 | s2/video/lwbszs.mp4 61 | s2/video/priv6a.mp4 62 | s2/video/sbaa5p.mp4 63 | s2/video/prwd5p.mp4 64 | s2/video/lgal7n.mp4 65 | s2/video/swio1p.mp4 66 | s2/video/lbaq3n.mp4 67 | s2/video/pwbq6a.mp4 68 | s2/video/prbc9n.mp4 69 | s2/video/brwt6a.mp4 70 | s2/video/bgaa5n.mp4 71 | s2/video/brwa4a.mp4 72 | s2/video/lgwaza.mp4 73 | s2/video/swah9n.mp4 74 | s2/video/bgag9n.mp4 75 | s2/video/pbih9n.mp4 76 | s2/video/lbby1n.mp4 77 | s2/video/brif6a.mp4 78 | s2/video/pwip7p.mp4 79 | s2/video/lbbk6a.mp4 80 | s2/video/bbal5n.mp4 81 | s2/video/lbij6s.mp4 82 | s2/video/bbie7n.mp4 83 | s2/video/brwg8a.mp4 84 | s2/video/lgaz8a.mp4 85 | s2/video/prip1p.mp4 86 | s2/video/pwwy3p.mp4 87 | s2/video/swao6a.mp4 88 | s2/video/lgal9p.mp4 89 | s2/video/lrwz4a.mp4 90 | s2/video/prbx1n.mp4 91 | s2/video/sban4a.mp4 92 | s2/video/pgad7n.mp4 93 | s2/video/lwwf6s.mp4 94 | s2/video/srwi4a.mp4 95 | s2/video/bbar9n.mp4 96 | s2/video/prwq2s.mp4 97 | s2/video/prio9n.mp4 98 | s2/video/srih1p.mp4 99 | s2/video/priv5p.mp4 100 | s2/video/srwv1p.mp4 101 | s4/video/briszs.mp4 102 | s4/video/bbwm1n.mp4 103 | s4/video/sgwxzs.mp4 104 | s4/video/srig7n.mp4 105 | s4/video/lriyza.mp4 106 | s4/video/braf8a.mp4 107 | s4/video/pwbx8a.mp4 108 | s4/video/pbai1n.mp4 109 | s4/video/bris2a.mp4 110 | s4/video/lbwe3p.mp4 111 | s4/video/bgia1p.mp4 112 | s4/video/bwbm9n.mp4 113 | s4/video/pwij1p.mp4 114 | s4/video/brwt4a.mp4 115 | s4/video/pbbi6s.mp4 116 | s4/video/lrby7p.mp4 117 | s4/video/pwbx7p.mp4 118 | s4/video/sgio6a.mp4 119 | s4/video/sgbp1n.mp4 120 | s4/video/swih4s.mp4 121 | s4/video/bbbf5p.mp4 122 | s4/video/lbwe2s.mp4 123 | s4/video/prap3p.mp4 124 | s4/video/bbir6a.mp4 125 | s4/video/lbidzs.mp4 126 | s4/video/prii6a.mp4 127 | s4/video/bgim9p.mp4 128 | s4/video/brwm8s.mp4 129 | s4/video/pbwp5p.mp4 130 | s4/video/pgaq3n.mp4 131 | s4/video/lwbe9n.mp4 132 | s4/video/briz5p.mp4 133 | s4/video/pgij5n.mp4 134 | s4/video/swbi4a.mp4 135 | s4/video/pwwq5n.mp4 136 | s4/video/lway9p.mp4 137 | s4/video/swbo8a.mp4 138 | 
s4/video/bwbg5n.mp4 139 | s4/video/bgit2s.mp4 140 | s4/video/pwijzs.mp4 141 | s4/video/lrbr3p.mp4 142 | s4/video/priv4a.mp4 143 | s4/video/pric2a.mp4 144 | s4/video/lbwk7p.mp4 145 | s4/video/pwiv9p.mp4 146 | s4/video/lbwy3n.mp4 147 | s4/video/sgwp6s.mp4 148 | s4/video/lwbz1n.mp4 149 | s4/video/srbb2s.mp4 150 | s4/video/prai8s.mp4 151 | s29/video/sgil8n.mp4 152 | s29/video/lbbozn.mp4 153 | s29/video/lbwi3a.mp4 154 | s29/video/sgam2n.mp4 155 | s29/video/lrwc3s.mp4 156 | s29/video/sbwr7s.mp4 157 | s29/video/prwhzn.mp4 158 | s29/video/lriu4p.mp4 159 | s29/video/lgic5a.mp4 160 | s29/video/bbwj7s.mp4 161 | s29/video/lrwv4n.mp4 162 | s29/video/bbbx3a.mp4 163 | s29/video/lrbi2n.mp4 164 | s29/video/brbkzp.mp4 165 | s29/video/swil4p.mp4 166 | s29/video/prwh2p.mp4 167 | s29/video/brap8n.mp4 168 | s29/video/bbipzp.mp4 169 | s29/video/pwig5s.mp4 170 | s29/video/pbwm8n.mp4 171 | s29/video/brbq4p.mp4 172 | s29/video/bgbl1s.mp4 173 | s29/video/lwio6p.mp4 174 | s29/video/bbap2n.mp4 175 | s29/video/bbai8n.mp4 176 | s29/video/lwbv8p.mp4 177 | s29/video/sgam4p.mp4 178 | s29/video/sbwe8n.mp4 179 | s29/video/briv9s.mp4 180 | s29/video/bwij6n.mp4 181 | s29/video/lwai5s.mp4 182 | s29/video/pwim9s.mp4 183 | s29/video/lbwo6p.mp4 184 | s29/video/srws5a.mp4 185 | s29/video/brap9s.mp4 186 | s29/video/pbbtzp.mp4 187 | s29/video/pgih3a.mp4 188 | s29/video/bgws1a.mp4 189 | s29/video/sbbe6p.mp4 190 | s29/video/lgwd6p.mp4 191 | s29/video/bgblzn.mp4 192 | s29/video/bwiqzn.mp4 193 | s29/video/pbwt2n.mp4 194 | s29/video/brbx8p.mp4 195 | s29/video/sriy5s.mp4 196 | s29/video/bgae5a.mp4 197 | s29/video/sgif6p.mp4 198 | s29/video/prim5a.mp4 199 | s29/video/lgac9a.mp4 200 | s29/video/srwz6n.mp4 201 | -------------------------------------------------------------------------------- /ASR_model/LRW/src/data/vid_aud_lrw_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchaudio 8 | import torchvision 9 | from torchvision import transforms 10 | from torch.utils.data import DataLoader, Dataset 11 | from librosa.filters import mel as librosa_mel_fn 12 | from src.data.stft import STFT 13 | from src.data.audio_processing import dynamic_range_compression, dynamic_range_decompression 14 | import glob, math 15 | from scipy import signal 16 | import librosa 17 | 18 | log1e5 = math.log(1e-5) 19 | 20 | class MultiDataset(Dataset): 21 | def __init__(self, lrw, mode, max_v_timesteps=155, augmentations=False, num_mel_bins=80, wav=False): 22 | self.max_v_timesteps = max_v_timesteps 23 | self.augmentations = augmentations if mode == 'train' else False 24 | self.num_mel_bins = num_mel_bins 25 | self.skip_long_samples = True 26 | self.wav = wav 27 | self.file_paths, self.word_list = self.build_file_list(lrw) 28 | self.word2int = {word: index for index, word in self.word_list.items()} 29 | self.stft = TacotronSTFT(filter_length=640, hop_length=160, win_length=640, n_mel_channels=80, sampling_rate=16000, mel_fmin=55., mel_fmax=7600.) 
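        # Note: with sampling_rate=16000 and hop_length=160 (a 10 ms hop), the STFT produces 100 mel frames
        # per second, i.e. 4 mel frames per video frame for 25-fps clips; this matches the `max_v_timesteps * 4`
        # cropping/padding applied to the spectrogram in __getitem__ below.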
30 | 31 | def build_file_list(self, lrw): 32 | word = {} 33 | # data_dir: spec_mel (or wav) / class / test (train, val) / class_#.npz(or .wav) 34 | if self.wav: 35 | files = sorted(glob.glob(os.path.join(lrw, '*', '*', '*.wav'))) 36 | else: 37 | files = sorted(glob.glob(os.path.join(lrw, '*', '*', '*.npz'))) 38 | 39 | with open('./data/class.txt', 'r') as f: 40 | lines = f.readlines() 41 | for i, l in enumerate(lines): 42 | word[i] = l.strip().upper() 43 | 44 | return files, word 45 | 46 | def __len__(self): 47 | return len(self.file_paths) 48 | 49 | def __getitem__(self, idx): 50 | file_path = self.file_paths[idx] 51 | content = os.path.split(file_path)[-1].split('_')[0].upper() 52 | target = self.word2int[content] 53 | 54 | if self.wav: 55 | aud, sr = torchaudio.load(file_path) 56 | if round(sr) != 16000: 57 | aud = torch.tensor(librosa.resample(aud.squeeze(0).numpy(), sr, 16000)).unsqueeze(0) 58 | 59 | aud = aud / torch.abs(aud).max() * 0.9 60 | aud = torch.FloatTensor(self.preemphasize(aud.squeeze(0))).unsqueeze(0) 61 | aud = torch.clamp(aud, min=-1, max=1) 62 | 63 | spec = self.stft.mel_spectrogram(aud) 64 | 65 | else: 66 | data = np.load(file_path) 67 | spec = data['mel'] 68 | data.close() 69 | 70 | spec = torch.FloatTensor(self.denormalize(spec)) 71 | 72 | spec = spec[:, :, :self.max_v_timesteps * 4] 73 | num_a_frames = spec.size(2) 74 | spec = nn.ConstantPad2d((0, self.max_v_timesteps * 4 - num_a_frames, 0, 0), 0.0)(spec) 75 | 76 | assert spec.size(2) == 116 77 | return spec, target 78 | 79 | def preemphasize(self, aud): 80 | aud = signal.lfilter([1, -0.97], [1], aud) 81 | return aud 82 | 83 | def denormalize(self, melspec): 84 | melspec = ((melspec + 1) * (-log1e5 / 2)) + log1e5 85 | return melspec 86 | 87 | class TacotronSTFT(torch.nn.Module): 88 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 89 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 90 | mel_fmax=8000.0): 91 | super(TacotronSTFT, self).__init__() 92 | self.n_mel_channels = n_mel_channels 93 | self.sampling_rate = sampling_rate 94 | self.stft_fn = STFT(filter_length, hop_length, win_length) 95 | mel_basis = librosa_mel_fn( 96 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 97 | mel_basis = torch.from_numpy(mel_basis).float() 98 | self.register_buffer('mel_basis', mel_basis) 99 | 100 | def spectral_normalize(self, magnitudes): 101 | output = dynamic_range_compression(magnitudes) 102 | return output 103 | 104 | def spectral_de_normalize(self, magnitudes): 105 | output = dynamic_range_decompression(magnitudes) 106 | return output 107 | 108 | def mel_spectrogram(self, y): 109 | """Computes mel-spectrograms from a batch of waves 110 | PARAMS 111 | ------ 112 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 113 | RETURNS 114 | ------- 115 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 116 | """ 117 | assert(torch.min(y.data) >= -1) 118 | assert(torch.max(y.data) <= 1) 119 | 120 | magnitudes, phases = self.stft_fn.transform(y) 121 | magnitudes = magnitudes.data 122 | mel_output = torch.matmul(self.mel_basis, magnitudes) 123 | mel_output = self.spectral_normalize(mel_output) 124 | return mel_output -------------------------------------------------------------------------------- /README_LRS.md: -------------------------------------------------------------------------------- 1 | We provide the training code and trained VCA-GAN model on LRS2 and LRS3.
2 | The performances are reported in our ICASSP23 paper '[Lip-to-Speech Synthesis in the Wild with Multi-task Learning](https://arxiv.org/abs/2302.08841)'. 3 | ### Datasets 4 | #### Download 5 | The LRS2/LRS3 datasets can be downloaded from the link below. 6 | - https://www.robots.ox.ac.uk/~vgg/data/lip_reading/ 7 | 8 | For data preprocessing, download the lip coordinates of LRS2 and LRS3 from the links below. 9 | - [LRS2](https://drive.google.com/file/d/10cnzNRRr-LQbS5kjc393FLvmNxPJ_u1N/view?usp=drive_link) 10 | - [LRS3](https://drive.google.com/file/d/10eAVKBuy7TyslcPdv4xmf5dxSYx4NMrS/view?usp=drive_link) 11 | 12 | Unzip and put the files in 13 | ``` 14 | ./data/LRS2/LRS2_crop/*.txt 15 | ./data/LRS3/LRS3_crop/*.txt 16 | ``` 17 | 18 | #### Preprocessing 19 | After downloading the dataset, extract the audio files (.wav) from the videos using `./preprocess/Extract_audio_LRS.py`. 20 | 21 | ```shell 22 | python ./preprocess/Extract_audio_LRS.py \ 23 | --data_path 'original_video_path/LRS2-BBC' \ 24 | --save_path 'path_to_save/LRS2-BBC_audio' \ 25 | --data_type 'LRS2 or LRS3' 26 | ``` 27 | 28 | We assume the data directories are structured as follows: 29 | ``` 30 | LRS2-BBC 31 | ├── main 32 | | ├── * 33 | | | └── *.mp4 34 | | | └── *.txt 35 | 36 | LRS2-BBC_audio 37 | ├── main 38 | | ├── * 39 | | | └── *.wav 40 | ``` 41 | 42 | ``` 43 | LRS3-TED 44 | ├── trainval 45 | | ├── * 46 | | | └── *.mp4 47 | | | └── *.txt 48 | 49 | LRS3-TED_audio 50 | ├── trainval 51 | | ├── * 52 | | | └── *.wav 53 | ``` 54 | 55 | Moreover, put the train/val/test split files in<br
56 | ``` 57 | ./data/LRS2/*.txt 58 | ./data/LRS3/*.txt 59 | ``` 60 | 61 | For LRS2, we use the original data splits provided with the dataset.<br
62 | For LRS3, we use the unseen-split setting of [SVTS](https://arxiv.org/abs/2205.02058); the corresponding split files are already placed in the directory. 63 | 64 | ## Training the Model 65 | The `data_name` argument selects which dataset to train on (LRS2 or LRS3).<br
66 | To train the model, run the following command: 67 | 68 | ```shell 69 | # Data Parallel training example using 4 GPUs on LRS2 70 | python train_LRS.py \ 71 | --data '/data_dir_as_like/LRS2-BBC' \ 72 | --data_name 'LRS2' \ 73 | --checkpoint_dir 'enter_the_path_to_save' \ 74 | --batch_size 80 \ 75 | --epochs 200 \ 76 | --dataparallel \ 77 | --gpu 0,1,2,3 78 | ``` 79 | 80 | ```shell 81 | # 1 GPU training example on LRS3 82 | python train_LRS.py \ 83 | --data '/data_dir_as_like/LRS3-TED' \ 84 | --data_name 'LRS3' \ 85 | --checkpoint_dir 'enter_the_path_to_save' \ 86 | --batch_size 80 \ 87 | --epochs 200 \ 88 | --gpu 0 89 | ``` 90 | 91 | Descriptions of the training parameters are as follows: 92 | - `--data`: Dataset location (LRS2 or LRS3) 93 | - `--data_name`: Choose to train on LRS2 or LRS3 94 | - `--checkpoint_dir`: directory for saving checkpoints 95 | - `--checkpoint`: saved checkpoint from which training is resumed 96 | - `--batch_size`: batch size 97 | - `--epochs`: number of epochs 98 | - `--augmentations`: whether to perform augmentation 99 | - `--dataparallel`: Use DataParallel 100 | - `--gpu`: gpu number for training 101 | - `--lr`: learning rate 102 | - `--window_size`: number of frames to be used for training 103 | - Refer to `train_LRS.py` for the other training parameters 104 | 105 | The evaluation during training is performed on a subset of the validation dataset due to the heavy time cost of waveform conversion (Griffin-Lim).<br
106 | In order to evaluate the full performance of the trained model, run the test code (refer to the "Testing the Model" section).
107 | 
108 | ### check the training logs
109 | ```shell
110 | tensorboard --logdir='./runs/logs to watch' --host='ip address of the server'
111 | ```
112 | The tensorboard shows the training and validation loss, evaluation metrics, generated mel-spectrograms, and audio.
113 | 
114 | 
115 | ## Testing the Model
116 | To test the model, run the following command:
117 | ```shell
118 | # test example on LRS2
119 | python test_LRS.py \
120 | --data 'data_directory_path' \
121 | --data_name 'LRS2' \
122 | --checkpoint 'enter_the_checkpoint_path' \
123 | --batch_size 20 \
124 | --save_mel \
125 | --save_wav \
126 | --gpu 0
127 | ```
128 | 
129 | Descriptions of testing parameters are as follows:
130 | - `--data`: Dataset location (LRS2 or LRS3)
131 | - `--data_name`: Choose to test on LRS2 or LRS3
132 | - `--checkpoint` : saved checkpoint to be evaluated
133 | - `--batch_size`: batch size
134 | - `--dataparallel`: Use DataParallel
135 | - `--gpu`: gpu number for testing
136 | - `--save_mel`: whether to save the 'mel_spectrogram' and 'spectrogram' in `.npz` format
137 | - `--save_wav`: whether to save the 'waveform' in `.wav` format
138 | - Refer to `test_LRS.py` for the other parameters
139 | 
140 | 
141 | ## Pre-trained model checkpoints
142 | We provide pre-trained VCA-GAN models trained on LRS2 and LRS3.
143 | The performances are reported in our ICASSP23 [paper](https://arxiv.org/abs/2302.08841). 144 | 145 | | Dataset | STOI | 146 | |:-------------------:|:--------:| 147 | |LRS2 | [0.407](https://drive.google.com/file/d/11ixych8CrmrrUWsHC35LG0nipFEOZw-W/view?usp=sharing) | 148 | |LRS3 | [0.474](https://drive.google.com/file/d/11jUyjr_hnsaeOzxUDcVywhUTl3lemXT5/view?usp=sharing) | 149 | -------------------------------------------------------------------------------- /preprocess/Preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import glob 4 | import cv2 5 | import torchvision 6 | import torch 7 | from skimage import transform 8 | import numpy as np 9 | from scipy import signal 10 | from torch.utils.data import DataLoader, Dataset 11 | import torchvision.transforms.functional as F 12 | from torchvision import transforms 13 | import librosa 14 | import soundfile as sf 15 | import argparse 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--Data_dir', type=str, default="Data dir of images and audio of GRID") 20 | parser.add_argument('--Landmark', type=str, default="Data dir of GRID Landmark") 21 | parser.add_argument('--FPS', type=int, default=25, help="25 for GRID") 22 | parser.add_argument('--reference', type=str, default='./Ref_face.txt') 23 | parser.add_argument("--Output_dir", type=str, default='Output dir Ex) ./GRID_processed') 24 | args = parser.parse_args() 25 | return args 26 | 27 | args = parse_args() 28 | eps = 1e-8 29 | 30 | class Crop(object): 31 | def __init__(self, crop): 32 | self.crop = crop 33 | 34 | def __call__(self, img): 35 | return img.crop(self.crop) 36 | 37 | ###################################################### 38 | f = open(args.reference, 'r') 39 | lm = f.readlines()[0] 40 | f.close() 41 | lm = lm.split(':')[-1].split('|')[6] 42 | lms = lm.split(',') 43 | temp_lm = [] 44 | 45 | for lm in lms: 46 | x, y = lm.split() 47 | temp_lm.append([x, y]) 48 | temp_lm = np.array(temp_lm, dtype=float) 49 | refer_lm = temp_lm 50 | ###################################################### 51 | 52 | 53 | class Preprocessing(Dataset): 54 | def __init__(self,): 55 | self.file_paths = self.build_file_list() 56 | 57 | def build_file_list(self): 58 | file_list = [] 59 | landmarks = sorted(glob.glob(os.path.join(args.Landmark, '*', '*', '*.txt'))) 60 | for lm in landmarks: 61 | if not os.path.exists(lm.replace(args.Landmark, args.Output_dir)[:-4] + '.mp4'): 62 | file_list.append(lm) 63 | return file_list 64 | 65 | def __len__(self): 66 | return len(self.file_paths) 67 | 68 | def __getitem__(self, idx): 69 | file_path = self.file_paths[idx] 70 | ims = sorted(glob.glob(os.path.join(file_path.replace(args.Landmark, args.Data_dir)[:-4], '*.png'))) 71 | frames = [] 72 | 73 | for im in ims: 74 | frames += [cv2.cvtColor(cv2.imread(im), cv2.COLOR_BGR2RGB)] 75 | v = np.stack(frames, 0) 76 | 77 | t, f_name = os.path.split(file_path) 78 | t, m_name = os.path.split(t) 79 | _, s_name = os.path.split(t) 80 | save_path = os.path.join(args.Output_dir, s_name, m_name) 81 | try: 82 | with open(file_path, 'r', encoding='utf-8') as lf: 83 | lms = lf.readlines()[0] 84 | except: 85 | with open(file_path, 'r', encoding='cp949') as lf: 86 | lms = lf.readlines()[0] 87 | lms = lms.split(':')[-1].split('|') 88 | assert v.shape[0] == len(lms), 'the video frame length differs to the landmark frames' 89 | 90 | aligned_video = [] 91 | for i, frame in enumerate(v): 92 | lm = lms[i].split(',') 93 | 
temp_lm = [] 94 | for l in lm: 95 | x, y = l.split() 96 | temp_lm.append([x, y]) 97 | temp_lm = np.array(temp_lm, dtype=float) # 98,2 98 | 99 | source_lm = temp_lm 100 | 101 | tform.estimate(source_lm, refer_lm) 102 | mat = tform.params[0:2, :] 103 | aligned_im = cv2.warpAffine(frame, mat, (np.shape(frame)[0], np.shape(frame)[1])) 104 | aligned_video += [aligned_im[:256, :256, :]] 105 | 106 | aligned_video = np.array(aligned_video) 107 | 108 | #### audio preprocessing #### 109 | aud, _ = librosa.load(os.path.join(file_path.replace(args.Landmark, args.Data_dir).replace('video', 'audio')[:-4] + '.wav'), 16000) 110 | fc = 55. # Cut-off frequency of the filter 111 | w = fc / (16000 / 2) # Normalize the frequency 112 | b, a = signal.butter(7, w, 'high') 113 | aud = signal.filtfilt(b, a, aud) 114 | 115 | return torch.tensor(aligned_video), save_path, f_name, torch.tensor(aud.copy()) 116 | 117 | Data = Preprocessing() 118 | Data_loader = DataLoader(Data, shuffle=False, batch_size=1, num_workers=3) 119 | tform = transform.SimilarityTransform() 120 | 121 | for kk, data in enumerate(Data_loader): 122 | cropped_video, save_path, f_name, aud = data 123 | aud = aud[0] 124 | cropped_video = cropped_video[0] 125 | save_path = save_path[0] 126 | f_name = f_name[0] 127 | if not os.path.exists(save_path): 128 | os.makedirs(save_path) 129 | if not os.path.exists(save_path.replace('video', 'audio')): 130 | os.makedirs(save_path.replace('video', 'audio')) 131 | torchvision.io.write_video(os.path.join(save_path, f_name[:-4] + '.mp4'), video_array=cropped_video, fps=args.FPS) 132 | sf.write(os.path.join(save_path.replace('video', 'audio'), f_name[:-4] + ".flac"), aud.numpy(), samplerate=16000) 133 | print('##########', kk + 1, ' / ', len(Data_loader), '##########') 134 | -------------------------------------------------------------------------------- /ASR_model/LRW/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import torch 4 | from torch import nn, optim 5 | from torch.utils.tensorboard import SummaryWriter 6 | import numpy as np 7 | from src.models.audio_front import Audio_front 8 | from src.models.classifier import Backend 9 | import os 10 | from torch.utils.data import DataLoader 11 | from torch.nn import functional as F 12 | from src.data.vid_aud_lrw_test import MultiDataset 13 | from torch.nn import DataParallel as DP 14 | import torch.nn.parallel 15 | import math 16 | from matplotlib import pyplot as plt 17 | import time 18 | import glob 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--data', default="TEST_DIR", help='./../../test/spec_mel') 23 | parser.add_argument('--wav', default=False, action='store_true') 24 | parser.add_argument("--checkpoint_dir", type=str, default='./data') 25 | parser.add_argument("--checkpoint", type=str, default='./data/LRW_acc_0.98464.ckpt') 26 | parser.add_argument("--batch_size", type=int, default=320) 27 | parser.add_argument("--epochs", type=int, default=100) 28 | parser.add_argument("--lr", type=float, default=0.01) 29 | parser.add_argument("--weight_decay", type=float, default=0.00001) 30 | parser.add_argument("--workers", type=int, default=5) 31 | parser.add_argument("--resnet", type=int, default=18) 32 | parser.add_argument("--seed", type=int, default=1) 33 | 34 | parser.add_argument("--max_timesteps", type=int, default=29) 35 | 36 | parser.add_argument("--dataparallel", default=False, action='store_true') 37 | parser.add_argument("--gpu", 
type=str, default='0') 38 | args = parser.parse_args() 39 | return args 40 | 41 | 42 | def train_net(args): 43 | torch.backends.cudnn.deterministic = False 44 | torch.backends.cudnn.benchmark = True 45 | torch.manual_seed(args.seed) 46 | torch.cuda.manual_seed_all(args.seed) 47 | random.seed(args.seed) 48 | os.environ['OMP_NUM_THREADS'] = '2' 49 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 50 | 51 | a_front = Audio_front(in_channels=1) 52 | a_back = Backend() 53 | 54 | if args.checkpoint is not None: 55 | print(f"Loading checkpoint: {args.checkpoint}") 56 | checkpoint = torch.load(args.checkpoint) 57 | a_front.load_state_dict(checkpoint['a_front_state_dict']) 58 | a_back.load_state_dict(checkpoint['a_back_state_dict']) 59 | del checkpoint 60 | 61 | a_front.cuda() 62 | a_back.cuda() 63 | 64 | if args.dataparallel: 65 | a_front = DP(a_front) 66 | a_back = DP(a_back) 67 | 68 | _ = validate(a_front, a_back, epoch=0, writer=None) 69 | 70 | def validate(a_front, a_back, fast_validate=False, epoch=0, writer=None): 71 | with torch.no_grad(): 72 | a_front.eval() 73 | a_back.eval() 74 | 75 | val_data = MultiDataset( 76 | lrw=args.data, 77 | mode='test', 78 | max_v_timesteps=args.max_timesteps, 79 | augmentations=False, 80 | wav=args.wav 81 | ) 82 | 83 | dataloader = DataLoader( 84 | val_data, 85 | shuffle=False, 86 | batch_size=args.batch_size * 2, 87 | num_workers=args.workers, 88 | drop_last=False 89 | ) 90 | 91 | criterion = nn.CrossEntropyLoss().cuda() 92 | batch_size = dataloader.batch_size 93 | if fast_validate: 94 | samples = min(2 * batch_size, int(len(dataloader.dataset))) 95 | max_batches = 2 96 | else: 97 | samples = int(len(dataloader.dataset)) 98 | max_batches = int(len(dataloader)) 99 | 100 | val_loss = [] 101 | tot_cor, tot_v_cor, tot_a_cor, tot_num = 0, 0, 0, 0 102 | 103 | description = 'Check validation step' if fast_validate else 'Validation' 104 | print(description) 105 | for i, batch in enumerate(dataloader): 106 | if i % 10 == 0: 107 | if not fast_validate: 108 | print("******** Validation : %d / %d ********" % ((i + 1) * batch_size, samples)) 109 | a_in, target = batch 110 | 111 | a_feat = a_front(a_in.cuda()) # S,B,51 112 | a_pred = a_back(a_feat) 113 | 114 | loss = criterion(a_pred, target.long().cuda()).cpu().item() 115 | prediction = torch.argmax(a_pred.cpu(), dim=1).numpy() 116 | tot_cor += np.sum(prediction == target.long().numpy()) 117 | tot_num += len(prediction) 118 | 119 | batch_size = a_pred.size(0) 120 | val_loss.append(loss) 121 | 122 | if i >= max_batches: 123 | break 124 | 125 | if writer is not None: 126 | writer.add_scalar('Val/loss', np.mean(np.array(val_loss)), epoch) 127 | writer.add_scalar('Val/acc', tot_cor / tot_num, epoch) 128 | 129 | a_front.train() 130 | a_back.train() 131 | print('test_ACC:', tot_cor / tot_num, 'WER:', 1. - tot_cor / tot_num) 132 | if fast_validate: 133 | return {} 134 | else: 135 | return np.mean(np.array(val_loss)), tot_cor / tot_num 136 | 137 | if __name__ == "__main__": 138 | args = parse_args() 139 | train_net(args) 140 | 141 | -------------------------------------------------------------------------------- /src/data/stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | Copyright (c) 2017, Prem Seetharaman 4 | All rights reserved. 
5 | * Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | * Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above copyright notice, this 10 | list of conditions and the following disclaimer in the 11 | documentation and/or other materials provided with the distribution. 12 | * Neither the name of the copyright holder nor the names of its 13 | contributors may be used to endorse or promote products derived from this 14 | software without specific prior written permission. 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 22 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | """ 26 | 27 | import torch 28 | import numpy as np 29 | import torch.nn.functional as F 30 | from torch.autograd import Variable 31 | from scipy.signal import get_window 32 | from librosa.util import pad_center, tiny 33 | from src.data.audio_processing import window_sumsquare 34 | 35 | class STFT(torch.nn.Module): 36 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 37 | def __init__(self, filter_length=800, hop_length=200, win_length=800, 38 | window='hann'): 39 | super(STFT, self).__init__() 40 | self.filter_length = filter_length 41 | self.hop_length = hop_length 42 | self.win_length = win_length 43 | self.window = window 44 | self.forward_transform = None 45 | scale = self.filter_length / self.hop_length 46 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 47 | 48 | cutoff = int((self.filter_length / 2 + 1)) 49 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 50 | np.imag(fourier_basis[:cutoff, :])]) 51 | 52 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 53 | inverse_basis = torch.FloatTensor( 54 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 55 | 56 | if window is not None: 57 | assert(filter_length >= win_length) 58 | # get window and zero center pad it to filter_length 59 | fft_window = get_window(window, win_length, fftbins=True) 60 | fft_window = pad_center(fft_window, filter_length) 61 | fft_window = torch.from_numpy(fft_window).float() 62 | 63 | # window the bases 64 | forward_basis *= fft_window 65 | inverse_basis *= fft_window 66 | 67 | self.register_buffer('forward_basis', forward_basis.float()) 68 | self.register_buffer('inverse_basis', inverse_basis.float()) 69 | 70 | def transform(self, input_data): 71 | num_batches = input_data.size(0) 72 | num_samples = input_data.size(1) 73 | 74 | self.num_samples = num_samples 75 | 76 | # similar to librosa, reflect-pad the input 77 | input_data = input_data.view(num_batches, 1, num_samples) 78 | input_data = F.pad( 79 | input_data.unsqueeze(1), 80 | 
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 81 | mode='reflect') 82 | input_data = input_data.squeeze(1) 83 | 84 | forward_transform = F.conv1d( 85 | input_data, 86 | Variable(self.forward_basis, requires_grad=False), 87 | stride=self.hop_length, 88 | padding=0) 89 | 90 | cutoff = int((self.filter_length / 2) + 1) 91 | real_part = forward_transform[:, :cutoff, :] 92 | imag_part = forward_transform[:, cutoff:, :] 93 | 94 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 95 | phase = torch.autograd.Variable( 96 | torch.atan2(imag_part.data, real_part.data)) 97 | 98 | return magnitude, phase 99 | 100 | def inverse(self, magnitude, phase): 101 | recombine_magnitude_phase = torch.cat( 102 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 103 | 104 | inverse_transform = F.conv_transpose1d( 105 | recombine_magnitude_phase, 106 | Variable(self.inverse_basis, requires_grad=False), 107 | stride=self.hop_length, 108 | padding=0) 109 | 110 | if self.window is not None: 111 | window_sum = window_sumsquare( 112 | self.window, magnitude.size(-1), hop_length=self.hop_length, 113 | win_length=self.win_length, n_fft=self.filter_length, 114 | dtype=np.float32) 115 | # remove modulation effects 116 | approx_nonzero_indices = torch.from_numpy( 117 | np.where(window_sum > tiny(window_sum))[0]) 118 | window_sum = torch.autograd.Variable( 119 | torch.from_numpy(window_sum), requires_grad=False) 120 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 121 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 122 | 123 | # scale by hop ratio 124 | inverse_transform *= float(self.filter_length) / self.hop_length 125 | 126 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 127 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] 128 | 129 | return inverse_transform 130 | 131 | def forward(self, input_data): 132 | self.magnitude, self.phase = self.transform(input_data) 133 | reconstruction = self.inverse(self.magnitude, self.phase) 134 | return reconstruction -------------------------------------------------------------------------------- /ASR_model/GRID/src/data/stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | Copyright (c) 2017, Prem Seetharaman 4 | All rights reserved. 5 | * Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | * Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above copyright notice, this 10 | list of conditions and the following disclaimer in the 11 | documentation and/or other materials provided with the distribution. 12 | * Neither the name of the copyright holder nor the names of its 13 | contributors may be used to endorse or promote products derived from this 14 | software without specific prior written permission. 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 22 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | """ 26 | 27 | import torch 28 | import numpy as np 29 | import torch.nn.functional as F 30 | from torch.autograd import Variable 31 | from scipy.signal import get_window 32 | from librosa.util import pad_center, tiny 33 | from src.data.audio_processing import window_sumsquare 34 | 35 | class STFT(torch.nn.Module): 36 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 37 | def __init__(self, filter_length=800, hop_length=200, win_length=800, 38 | window='hann'): 39 | super(STFT, self).__init__() 40 | self.filter_length = filter_length 41 | self.hop_length = hop_length 42 | self.win_length = win_length 43 | self.window = window 44 | self.forward_transform = None 45 | scale = self.filter_length / self.hop_length 46 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 47 | 48 | cutoff = int((self.filter_length / 2 + 1)) 49 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 50 | np.imag(fourier_basis[:cutoff, :])]) 51 | 52 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 53 | inverse_basis = torch.FloatTensor( 54 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 55 | 56 | if window is not None: 57 | assert(filter_length >= win_length) 58 | # get window and zero center pad it to filter_length 59 | fft_window = get_window(window, win_length, fftbins=True) 60 | fft_window = pad_center(fft_window, filter_length) 61 | fft_window = torch.from_numpy(fft_window).float() 62 | 63 | # window the bases 64 | forward_basis *= fft_window 65 | inverse_basis *= fft_window 66 | 67 | self.register_buffer('forward_basis', forward_basis.float()) 68 | self.register_buffer('inverse_basis', inverse_basis.float()) 69 | 70 | def transform(self, input_data): 71 | num_batches = input_data.size(0) 72 | num_samples = input_data.size(1) 73 | 74 | self.num_samples = num_samples 75 | 76 | # similar to librosa, reflect-pad the input 77 | input_data = input_data.view(num_batches, 1, num_samples) 78 | input_data = F.pad( 79 | input_data.unsqueeze(1), 80 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 81 | mode='reflect') 82 | input_data = input_data.squeeze(1) 83 | 84 | forward_transform = F.conv1d( 85 | input_data, 86 | Variable(self.forward_basis, requires_grad=False), 87 | stride=self.hop_length, 88 | padding=0) 89 | 90 | cutoff = int((self.filter_length / 2) + 1) 91 | real_part = forward_transform[:, :cutoff, :] 92 | imag_part = forward_transform[:, cutoff:, :] 93 | 94 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 95 | phase = torch.autograd.Variable( 96 | torch.atan2(imag_part.data, real_part.data)) 97 | 98 | return magnitude, phase 99 | 100 | def inverse(self, magnitude, phase): 101 | recombine_magnitude_phase = torch.cat( 102 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 103 | 104 | inverse_transform = F.conv_transpose1d( 105 | recombine_magnitude_phase, 106 | Variable(self.inverse_basis, requires_grad=False), 107 | stride=self.hop_length, 108 | padding=0) 109 | 110 | if 
self.window is not None: 111 | window_sum = window_sumsquare( 112 | self.window, magnitude.size(-1), hop_length=self.hop_length, 113 | win_length=self.win_length, n_fft=self.filter_length, 114 | dtype=np.float32) 115 | # remove modulation effects 116 | approx_nonzero_indices = torch.from_numpy( 117 | np.where(window_sum > tiny(window_sum))[0]) 118 | window_sum = torch.autograd.Variable( 119 | torch.from_numpy(window_sum), requires_grad=False) 120 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 121 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 122 | 123 | # scale by hop ratio 124 | inverse_transform *= float(self.filter_length) / self.hop_length 125 | 126 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 127 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] 128 | 129 | return inverse_transform 130 | 131 | def forward(self, input_data): 132 | self.magnitude, self.phase = self.transform(input_data) 133 | reconstruction = self.inverse(self.magnitude, self.phase) 134 | return reconstruction -------------------------------------------------------------------------------- /ASR_model/LRW/src/data/stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | Copyright (c) 2017, Prem Seetharaman 4 | All rights reserved. 5 | * Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | * Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above copyright notice, this 10 | list of conditions and the following disclaimer in the 11 | documentation and/or other materials provided with the distribution. 12 | * Neither the name of the copyright holder nor the names of its 13 | contributors may be used to endorse or promote products derived from this 14 | software without specific prior written permission. 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 22 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | """ 26 | 27 | import torch 28 | import numpy as np 29 | import torch.nn.functional as F 30 | from torch.autograd import Variable 31 | from scipy.signal import get_window 32 | from librosa.util import pad_center, tiny 33 | from src.data.audio_processing import window_sumsquare 34 | 35 | class STFT(torch.nn.Module): 36 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 37 | def __init__(self, filter_length=800, hop_length=200, win_length=800, 38 | window='hann'): 39 | super(STFT, self).__init__() 40 | self.filter_length = filter_length 41 | self.hop_length = hop_length 42 | self.win_length = win_length 43 | self.window = window 44 | self.forward_transform = None 45 | scale = self.filter_length / self.hop_length 46 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 47 | 48 | cutoff = int((self.filter_length / 2 + 1)) 49 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 50 | np.imag(fourier_basis[:cutoff, :])]) 51 | 52 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 53 | inverse_basis = torch.FloatTensor( 54 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 55 | 56 | if window is not None: 57 | assert(filter_length >= win_length) 58 | # get window and zero center pad it to filter_length 59 | fft_window = get_window(window, win_length, fftbins=True) 60 | fft_window = pad_center(fft_window, filter_length) 61 | fft_window = torch.from_numpy(fft_window).float() 62 | 63 | # window the bases 64 | forward_basis *= fft_window 65 | inverse_basis *= fft_window 66 | 67 | self.register_buffer('forward_basis', forward_basis.float()) 68 | self.register_buffer('inverse_basis', inverse_basis.float()) 69 | 70 | def transform(self, input_data): 71 | num_batches = input_data.size(0) 72 | num_samples = input_data.size(1) 73 | 74 | self.num_samples = num_samples 75 | 76 | # similar to librosa, reflect-pad the input 77 | input_data = input_data.view(num_batches, 1, num_samples) 78 | input_data = F.pad( 79 | input_data.unsqueeze(1), 80 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 81 | mode='reflect') 82 | input_data = input_data.squeeze(1) 83 | 84 | forward_transform = F.conv1d( 85 | input_data, 86 | Variable(self.forward_basis, requires_grad=False), 87 | stride=self.hop_length, 88 | padding=0) 89 | 90 | cutoff = int((self.filter_length / 2) + 1) 91 | real_part = forward_transform[:, :cutoff, :] 92 | imag_part = forward_transform[:, cutoff:, :] 93 | 94 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 95 | phase = torch.autograd.Variable( 96 | torch.atan2(imag_part.data, real_part.data)) 97 | 98 | return magnitude, phase 99 | 100 | def inverse(self, magnitude, phase): 101 | recombine_magnitude_phase = torch.cat( 102 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 103 | 104 | inverse_transform = F.conv_transpose1d( 105 | recombine_magnitude_phase, 106 | Variable(self.inverse_basis, requires_grad=False), 107 | stride=self.hop_length, 108 | padding=0) 109 | 110 | if self.window is not None: 111 | window_sum = window_sumsquare( 112 | self.window, magnitude.size(-1), hop_length=self.hop_length, 113 | win_length=self.win_length, n_fft=self.filter_length, 114 | dtype=np.float32) 115 | # remove modulation effects 116 | approx_nonzero_indices = torch.from_numpy( 117 | np.where(window_sum > tiny(window_sum))[0]) 118 | window_sum = torch.autograd.Variable( 119 | torch.from_numpy(window_sum), requires_grad=False) 120 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 121 | 
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 122 | 123 | # scale by hop ratio 124 | inverse_transform *= float(self.filter_length) / self.hop_length 125 | 126 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 127 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] 128 | 129 | return inverse_transform 130 | 131 | def forward(self, input_data): 132 | self.magnitude, self.phase = self.transform(input_data) 133 | reconstruction = self.inverse(self.magnitude, self.phase) 134 | return reconstruction -------------------------------------------------------------------------------- /README_GRID.md: -------------------------------------------------------------------------------- 1 | ### Datasets 2 | #### Download 3 | GRID dataset (video normal) can be downloaded from the below link. 4 | - http://spandh.dcs.shef.ac.uk/gridcorpus/ 5 | 6 | For data preprocessing, download the face landmark of GRID from the below link. 7 | - https://drive.google.com/file/d/10upLpydfbqCJ7t64h210Xx-35B8qCcZr/view?usp=sharing 8 | 9 | #### Preprocessing 10 | After download the dataset, preprocess the dataset with the following scripts in `./preprocess`.
11 | We suppose the data directory is constructed as
12 | ```
13 | Data_dir
14 | ├── subject
15 | |   ├── video
16 | |   |   └── xxx.mpg
17 | ```
18 | 
19 | 1. Extract frames
20 | `Extract_frames.py` extracts images and audio from the video.
21 | ```shell 22 | python Extract_frames.py --Grid_dir "Data dir of GRID_corpus" --Out_dir "Output dir of images and audio of GRID_corpus" 23 | ``` 24 | 25 | 2. Align faces and audio processing
26 | `Preprocess.py` aligns the faces and generates videos, which enables lip-centered cropping of the video during training.
27 | ```shell
28 | python Preprocess.py \
29 | --Data_dir "Data dir of extracted images and audio of GRID_corpus" \
30 | --Landmark "Downloaded landmark dir of GRID" \
31 | --Output_dir "Output dir of processed data"
32 | ```
33 | 
34 | ## Training the Model
35 | The speaker setting (different subject) can be selected with the `--subject` argument. Please refer to the examples below.
36 | To train the model, run following command: 37 | 38 | ```shell 39 | # Data Parallel training example using 4 GPUs for multi-speaker setting in GRID 40 | python train.py \ 41 | --grid 'enter_the_processed_data_path' \ 42 | --checkpoint_dir 'enter_the_path_to_save' \ 43 | --batch_size 88 \ 44 | --epochs 500 \ 45 | --subject 'overlap' \ 46 | --eval_step 720 \ 47 | --dataparallel \ 48 | --gpu 0,1,2,3 49 | ``` 50 | 51 | ```shell 52 | # 1 GPU training example for GRID for unseen-speaker setting in GRID 53 | python train.py \ 54 | --grid 'enter_the_processed_data_path' \ 55 | --checkpoint_dir 'enter_the_path_to_save' \ 56 | --batch_size 22 \ 57 | --epochs 500 \ 58 | --subject 'unseen' \ 59 | --eval_step 1000 \ 60 | --gpu 0 61 | ``` 62 | 63 | Descriptions of training parameters are as follows: 64 | - `--grid`: Dataset location (grid) 65 | - `--checkpoint_dir`: directory for saving checkpoints 66 | - `--checkpoint` : saved checkpoint where the training is resumed from 67 | - `--batch_size`: batch size 68 | - `--epochs`: number of epochs 69 | - `--augmentations`: whether performing augmentation 70 | - `--dataparallel`: Use DataParallel 71 | - `--subject`: different speaker settings, `s#` is speaker specific training, `overlap` for multi-speaker setting, `unseen` for unseen-speaker setting, `four` for four speaker training 72 | - `--gpu`: gpu number for training 73 | - `--lr`: learning rate 74 | - `--eval_step`: steps for performing evaluation 75 | - `--window_size`: number of frames to be used for training 76 | - Refer to `train.py` for the other training parameters 77 | 78 | The evaluation during training is performed for a subset of the validation dataset due to the heavy time costs of waveform conversion (griffin-lim).
79 | In order to evaluate the entire performance of the trained model run the test code (refer to "Testing the Model" section). 80 | 81 | ### check the training logs 82 | ```shell 83 | tensorboard --logdir='./runs/logs to watch' --host='ip address of the server' 84 | ``` 85 | The tensorboard shows the training and validation loss, evaluation metrics, generated mel-spectrogram, and audio 86 | 87 | 88 | ## Testing the Model 89 | To test the model, run following command: 90 | ```shell 91 | # Dataparallel test example for multi-speaker setting in GRID 92 | python test.py \ 93 | --grid 'enter_the_processed_data_path' \ 94 | --checkpoint 'enter_the_checkpoint_path' \ 95 | --batch_size 100 \ 96 | --subject 'overlap' \ 97 | --save_mel \ 98 | --save_wav \ 99 | --dataparallel \ 100 | --gpu 0,1 101 | ``` 102 | 103 | Descriptions of training parameters are as follows: 104 | - `--grid`: Dataset location (grid) 105 | - `--checkpoint` : saved checkpoint where the training is resumed from 106 | - `--batch_size`: batch size 107 | - `--dataparallel`: Use DataParallel 108 | - `--subject`: different speaker settings, `s#` is speaker specific training, `overlap` for multi-speaker setting, `unseen` for unseen-speaker setting, `four` for four speaker training 109 | - `--save_mel`: whether to save the 'mel_spectrogram' and 'spectrogram' in `.npz` format 110 | - `--save_wav`: whether to save the 'waveform' in `.wav` format 111 | - `--gpu`: gpu number for training 112 | - Refer to `test.py` for the other parameters 113 | 114 | ## Test Automatic Speech Recognition (ASR) results of generated results: WER 115 | Transcription (Ground-truth) of GRID dataset can be downloaded from the below link. 116 | - https://drive.google.com/file/d/112UES-N0OKjj0xV0hCDO9ZaOyPsvnJzR/view?usp=sharing 117 | 118 | move to the ASR_model directory 119 | ```shell 120 | cd ASR_model/GRID 121 | ``` 122 | 123 | To evaluate the WER, run following command: 124 | ```shell 125 | # test example for multi-speaker setting in GRID 126 | python test.py \ 127 | --data 'enter_the_generated_data_dir (mel or wav) (ex. ./../../test/spec_mel)' \ 128 | --gtpath 'enter_the_downloaded_transcription_path' \ 129 | --subject 'overlap' \ 130 | --gpu 0 131 | ``` 132 | 133 | Descriptions of training parameters are as follows: 134 | - `--data`: Data for evaluation (wav or mel(.npz)) 135 | - `--wav` : whether the data is waveform or not 136 | - `--batch_size`: batch size 137 | - `--subject`: different speaker settings, `s#` is speaker specific training, `overlap` for multi-speaker setting, `unseen` for unseen-speaker setting, `four` for four speaker training 138 | - `--gpu`: gpu number for training 139 | - Refer to `./ASR_model/GRID/test.py` for the other parameters 140 | 141 | 142 | ### Pre-trained ASR model checkpoint 143 | Below lists are the pre-trained ASR model to evaluate the generated speech.
144 | WER shows the original performances of the model on ground-truth audio. 145 | 146 | | Setting | WER | 147 | |:-------------------:|:--------:| 148 | |GRID (constrained-speaker) | [0.83 %](https://drive.google.com/file/d/11OyjBnfLU7M3qt98udIkUiGZ6frHLwP1/view?usp=sharing) | 149 | |GRID (multi-speaker) | [0.37 %](https://drive.google.com/file/d/113XmPNjlY7c1fSjUgf5IzJi1us7qVeX1/view?usp=sharing) | 150 | |GRID (unseen-speaker) | [1.67 %](https://drive.google.com/file/d/11CAO_z_eUWcdGWPuVcCYy1mpEG2YghAT/view?usp=sharing) | 151 | |LRW | [1.57 %](https://drive.google.com/file/d/11erwY_Tf69OBfBSdZFwikmBXLI_G2YKb/view?usp=sharing) | 152 | 153 | Put the checkpoints in `./ASR_model/GRID/data` for GRID, and in `./ASR_model/LRW/data` for LRW.
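If you want a quick sanity check that a downloaded checkpoint is readable and has the structure the ASR test scripts expect (an `a_front_state_dict` and an `a_back_state_dict` entry), a minimal sketch is shown below; the file name is just one example of the provided GRID checkpoints.

```python
# Sketch only: verify a downloaded ASR checkpoint before running the evaluation.
import torch

ckpt = torch.load('./ASR_model/GRID/data/GRID_33_wer_0.00368_cer_0.00120.ckpt',
                  map_location='cpu')
# Expected to include 'a_front_state_dict' and 'a_back_state_dict'.
print(list(ckpt.keys()))
```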
154 | The LRW checkpoint is modified from originally trained using torchvision to that using librosa library. 155 | -------------------------------------------------------------------------------- /ASR_model/LRW/data/class.txt: -------------------------------------------------------------------------------- 1 | ABOUT 2 | ABSOLUTELY 3 | ABUSE 4 | ACCESS 5 | ACCORDING 6 | ACCUSED 7 | ACROSS 8 | ACTION 9 | ACTUALLY 10 | AFFAIRS 11 | AFFECTED 12 | AFRICA 13 | AFTER 14 | AFTERNOON 15 | AGAIN 16 | AGAINST 17 | AGREE 18 | AGREEMENT 19 | AHEAD 20 | ALLEGATIONS 21 | ALLOW 22 | ALLOWED 23 | ALMOST 24 | ALREADY 25 | ALWAYS 26 | AMERICA 27 | AMERICAN 28 | AMONG 29 | AMOUNT 30 | ANNOUNCED 31 | ANOTHER 32 | ANSWER 33 | ANYTHING 34 | AREAS 35 | AROUND 36 | ARRESTED 37 | ASKED 38 | ASKING 39 | ATTACK 40 | ATTACKS 41 | AUTHORITIES 42 | BANKS 43 | BECAUSE 44 | BECOME 45 | BEFORE 46 | BEHIND 47 | BEING 48 | BELIEVE 49 | BENEFIT 50 | BENEFITS 51 | BETTER 52 | BETWEEN 53 | BIGGEST 54 | BILLION 55 | BLACK 56 | BORDER 57 | BRING 58 | BRITAIN 59 | BRITISH 60 | BROUGHT 61 | BUDGET 62 | BUILD 63 | BUILDING 64 | BUSINESS 65 | BUSINESSES 66 | CALLED 67 | CAMERON 68 | CAMPAIGN 69 | CANCER 70 | CANNOT 71 | CAPITAL 72 | CASES 73 | CENTRAL 74 | CERTAINLY 75 | CHALLENGE 76 | CHANCE 77 | CHANGE 78 | CHANGES 79 | CHARGE 80 | CHARGES 81 | CHIEF 82 | CHILD 83 | CHILDREN 84 | CHINA 85 | CLAIMS 86 | CLEAR 87 | CLOSE 88 | CLOUD 89 | COMES 90 | COMING 91 | COMMUNITY 92 | COMPANIES 93 | COMPANY 94 | CONCERNS 95 | CONFERENCE 96 | CONFLICT 97 | CONSERVATIVE 98 | CONTINUE 99 | CONTROL 100 | COULD 101 | COUNCIL 102 | COUNTRIES 103 | COUNTRY 104 | COUPLE 105 | COURSE 106 | COURT 107 | CRIME 108 | CRISIS 109 | CURRENT 110 | CUSTOMERS 111 | DAVID 112 | DEATH 113 | DEBATE 114 | DECIDED 115 | DECISION 116 | DEFICIT 117 | DEGREES 118 | DESCRIBED 119 | DESPITE 120 | DETAILS 121 | DIFFERENCE 122 | DIFFERENT 123 | DIFFICULT 124 | DOING 125 | DURING 126 | EARLY 127 | EASTERN 128 | ECONOMIC 129 | ECONOMY 130 | EDITOR 131 | EDUCATION 132 | ELECTION 133 | EMERGENCY 134 | ENERGY 135 | ENGLAND 136 | ENOUGH 137 | EUROPE 138 | EUROPEAN 139 | EVENING 140 | EVENTS 141 | EVERY 142 | EVERYBODY 143 | EVERYONE 144 | EVERYTHING 145 | EVIDENCE 146 | EXACTLY 147 | EXAMPLE 148 | EXPECT 149 | EXPECTED 150 | EXTRA 151 | FACING 152 | FAMILIES 153 | FAMILY 154 | FIGHT 155 | FIGHTING 156 | FIGURES 157 | FINAL 158 | FINANCIAL 159 | FIRST 160 | FOCUS 161 | FOLLOWING 162 | FOOTBALL 163 | FORCE 164 | FORCES 165 | FOREIGN 166 | FORMER 167 | FORWARD 168 | FOUND 169 | FRANCE 170 | FRENCH 171 | FRIDAY 172 | FRONT 173 | FURTHER 174 | FUTURE 175 | GAMES 176 | GENERAL 177 | GEORGE 178 | GERMANY 179 | GETTING 180 | GIVEN 181 | GIVING 182 | GLOBAL 183 | GOING 184 | GOVERNMENT 185 | GREAT 186 | GREECE 187 | GROUND 188 | GROUP 189 | GROWING 190 | GROWTH 191 | GUILTY 192 | HAPPEN 193 | HAPPENED 194 | HAPPENING 195 | HAVING 196 | HEALTH 197 | HEARD 198 | HEART 199 | HEAVY 200 | HIGHER 201 | HISTORY 202 | HOMES 203 | HOSPITAL 204 | HOURS 205 | HOUSE 206 | HOUSING 207 | HUMAN 208 | HUNDREDS 209 | IMMIGRATION 210 | IMPACT 211 | IMPORTANT 212 | INCREASE 213 | INDEPENDENT 214 | INDUSTRY 215 | INFLATION 216 | INFORMATION 217 | INQUIRY 218 | INSIDE 219 | INTEREST 220 | INVESTMENT 221 | INVOLVED 222 | IRELAND 223 | ISLAMIC 224 | ISSUE 225 | ISSUES 226 | ITSELF 227 | JAMES 228 | JUDGE 229 | JUSTICE 230 | KILLED 231 | KNOWN 232 | LABOUR 233 | LARGE 234 | LATER 235 | LATEST 236 | LEADER 237 | LEADERS 238 | LEADERSHIP 239 | LEAST 240 | LEAVE 241 | LEGAL 242 | LEVEL 243 | LEVELS 244 | LIKELY 245 
| LITTLE 246 | LIVES 247 | LIVING 248 | LOCAL 249 | LONDON 250 | LONGER 251 | LOOKING 252 | MAJOR 253 | MAJORITY 254 | MAKES 255 | MAKING 256 | MANCHESTER 257 | MARKET 258 | MASSIVE 259 | MATTER 260 | MAYBE 261 | MEANS 262 | MEASURES 263 | MEDIA 264 | MEDICAL 265 | MEETING 266 | MEMBER 267 | MEMBERS 268 | MESSAGE 269 | MIDDLE 270 | MIGHT 271 | MIGRANTS 272 | MILITARY 273 | MILLION 274 | MILLIONS 275 | MINISTER 276 | MINISTERS 277 | MINUTES 278 | MISSING 279 | MOMENT 280 | MONEY 281 | MONTH 282 | MONTHS 283 | MORNING 284 | MOVING 285 | MURDER 286 | NATIONAL 287 | NEEDS 288 | NEVER 289 | NIGHT 290 | NORTH 291 | NORTHERN 292 | NOTHING 293 | NUMBER 294 | NUMBERS 295 | OBAMA 296 | OFFICE 297 | OFFICERS 298 | OFFICIALS 299 | OFTEN 300 | OPERATION 301 | OPPOSITION 302 | ORDER 303 | OTHER 304 | OTHERS 305 | OUTSIDE 306 | PARENTS 307 | PARLIAMENT 308 | PARTIES 309 | PARTS 310 | PARTY 311 | PATIENTS 312 | PAYING 313 | PEOPLE 314 | PERHAPS 315 | PERIOD 316 | PERSON 317 | PERSONAL 318 | PHONE 319 | PLACE 320 | PLACES 321 | PLANS 322 | POINT 323 | POLICE 324 | POLICY 325 | POLITICAL 326 | POLITICIANS 327 | POLITICS 328 | POSITION 329 | POSSIBLE 330 | POTENTIAL 331 | POWER 332 | POWERS 333 | PRESIDENT 334 | PRESS 335 | PRESSURE 336 | PRETTY 337 | PRICE 338 | PRICES 339 | PRIME 340 | PRISON 341 | PRIVATE 342 | PROBABLY 343 | PROBLEM 344 | PROBLEMS 345 | PROCESS 346 | PROTECT 347 | PROVIDE 348 | PUBLIC 349 | QUESTION 350 | QUESTIONS 351 | QUITE 352 | RATES 353 | RATHER 354 | REALLY 355 | REASON 356 | RECENT 357 | RECORD 358 | REFERENDUM 359 | REMEMBER 360 | REPORT 361 | REPORTS 362 | RESPONSE 363 | RESULT 364 | RETURN 365 | RIGHT 366 | RIGHTS 367 | RULES 368 | RUNNING 369 | RUSSIA 370 | RUSSIAN 371 | SAYING 372 | SCHOOL 373 | SCHOOLS 374 | SCOTLAND 375 | SCOTTISH 376 | SECOND 377 | SECRETARY 378 | SECTOR 379 | SECURITY 380 | SEEMS 381 | SENIOR 382 | SENSE 383 | SERIES 384 | SERIOUS 385 | SERVICE 386 | SERVICES 387 | SEVEN 388 | SEVERAL 389 | SHORT 390 | SHOULD 391 | SIDES 392 | SIGNIFICANT 393 | SIMPLY 394 | SINCE 395 | SINGLE 396 | SITUATION 397 | SMALL 398 | SOCIAL 399 | SOCIETY 400 | SOMEONE 401 | SOMETHING 402 | SOUTH 403 | SOUTHERN 404 | SPEAKING 405 | SPECIAL 406 | SPEECH 407 | SPEND 408 | SPENDING 409 | SPENT 410 | STAFF 411 | STAGE 412 | STAND 413 | START 414 | STARTED 415 | STATE 416 | STATEMENT 417 | STATES 418 | STILL 419 | STORY 420 | STREET 421 | STRONG 422 | SUNDAY 423 | SUNSHINE 424 | SUPPORT 425 | SYRIA 426 | SYRIAN 427 | SYSTEM 428 | TAKEN 429 | TAKING 430 | TALKING 431 | TALKS 432 | TEMPERATURES 433 | TERMS 434 | THEIR 435 | THEMSELVES 436 | THERE 437 | THESE 438 | THING 439 | THINGS 440 | THINK 441 | THIRD 442 | THOSE 443 | THOUGHT 444 | THOUSANDS 445 | THREAT 446 | THREE 447 | THROUGH 448 | TIMES 449 | TODAY 450 | TOGETHER 451 | TOMORROW 452 | TONIGHT 453 | TOWARDS 454 | TRADE 455 | TRIAL 456 | TRUST 457 | TRYING 458 | UNDER 459 | UNDERSTAND 460 | UNION 461 | UNITED 462 | UNTIL 463 | USING 464 | VICTIMS 465 | VIOLENCE 466 | VOTERS 467 | WAITING 468 | WALES 469 | WANTED 470 | WANTS 471 | WARNING 472 | WATCHING 473 | WATER 474 | WEAPONS 475 | WEATHER 476 | WEEKEND 477 | WEEKS 478 | WELCOME 479 | WELFARE 480 | WESTERN 481 | WESTMINSTER 482 | WHERE 483 | WHETHER 484 | WHICH 485 | WHILE 486 | WHOLE 487 | WINDS 488 | WITHIN 489 | WITHOUT 490 | WOMEN 491 | WORDS 492 | WORKERS 493 | WORKING 494 | WORLD 495 | WORST 496 | WOULD 497 | WRONG 498 | YEARS 499 | YESTERDAY 500 | YOUNG 501 | -------------------------------------------------------------------------------- /test.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import torch 4 | from torch import nn, optim 5 | from torch.utils.tensorboard import SummaryWriter 6 | import numpy as np 7 | from src.models.visual_front import Visual_front 8 | from src.models.generator import Decoder, Discriminator, gan_loss, sync_Discriminator, Postnet 9 | import os 10 | from torch.utils.data import DataLoader 11 | from torch.nn import functional as F 12 | from src.data.vid_aud_grid import MultiDataset 13 | from torch.nn import DataParallel as DP 14 | import torch.nn.parallel 15 | import time 16 | import glob 17 | from torch.autograd import grad 18 | import soundfile as sf 19 | from pesq import pesq 20 | from pystoi import stoi 21 | from matplotlib import pyplot as plt 22 | import copy, librosa 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--grid', default="Data_dir") 28 | parser.add_argument("--checkpoint_dir", type=str, default='./data/checkpoints/GRID') 29 | parser.add_argument("--checkpoint", type=str, default=None) 30 | parser.add_argument("--batch_size", type=int, default=100) 31 | parser.add_argument("--epochs", type=int, default=1000) 32 | parser.add_argument("--lr", type=float, default=0.0001) 33 | parser.add_argument("--weight_decay", type=float, default=0.00001) 34 | parser.add_argument("--workers", type=int, default=10) 35 | parser.add_argument("--seed", type=int, default=1) 36 | 37 | parser.add_argument("--subject", type=str, default='overlap') 38 | 39 | parser.add_argument("--start_epoch", type=int, default=0) 40 | parser.add_argument("--augmentations", default=True) 41 | 42 | parser.add_argument("--window_size", type=int, default=40) 43 | parser.add_argument("--max_timesteps", type=int, default=75) 44 | parser.add_argument("--temp", type=float, default=1.0) 45 | 46 | parser.add_argument("--dataparallel", default=False, action='store_true') 47 | parser.add_argument("--gpu", type=str, default='0,1') 48 | 49 | parser.add_argument("--save_mel", default=False, action='store_true') 50 | parser.add_argument("--save_wav", default=False, action='store_true') 51 | 52 | args = parser.parse_args() 53 | return args 54 | 55 | 56 | def train_net(args): 57 | torch.backends.cudnn.deterministic = False 58 | torch.backends.cudnn.benchmark = True 59 | torch.manual_seed(args.seed) 60 | torch.cuda.manual_seed_all(args.seed) 61 | random.seed(args.seed) 62 | os.environ['OMP_NUM_THREADS'] = '2' 63 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 64 | 65 | v_front = Visual_front(in_channels=1) 66 | gen = Decoder() 67 | post = Postnet() 68 | 69 | print(f"Loading checkpoint: {args.checkpoint}") 70 | checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage.cuda()) 71 | 72 | v_front.load_state_dict(checkpoint['v_front_state_dict']) 73 | gen.load_state_dict(checkpoint['gen_state_dict']) 74 | post.load_state_dict(checkpoint['post_state_dict']) 75 | del checkpoint 76 | 77 | v_front.cuda() 78 | gen.cuda() 79 | post.cuda() 80 | 81 | if args.dataparallel: 82 | v_front = DP(v_front) 83 | gen = DP(gen) 84 | post = DP(post) 85 | 86 | _ = test(v_front, gen, post) 87 | 88 | def test(v_front, gen, post, fast_validate=False): 89 | with torch.no_grad(): 90 | v_front.eval() 91 | gen.eval() 92 | post.eval() 93 | 94 | val_data = MultiDataset( 95 | grid=args.grid, 96 | mode='test', 97 | subject=args.subject, 98 | window_size=args.window_size, 99 | max_v_timesteps=args.max_timesteps, 100 | 
augmentations=False, 101 | fast_validate=fast_validate 102 | ) 103 | 104 | dataloader = DataLoader( 105 | val_data, 106 | shuffle=False, 107 | batch_size=args.batch_size * 2, 108 | num_workers=args.workers, 109 | drop_last=False 110 | ) 111 | 112 | stft = copy.deepcopy(val_data.stft).cuda() 113 | stoi_spec_list = [] 114 | estoi_spec_list = [] 115 | pesq_spec_list = [] 116 | batch_size = dataloader.batch_size 117 | if fast_validate: 118 | samples = min(2 * batch_size, int(len(dataloader.dataset))) 119 | max_batches = 2 120 | else: 121 | samples = int(len(dataloader.dataset)) 122 | max_batches = int(len(dataloader)) 123 | 124 | description = 'Check validation step' if fast_validate else 'Validation' 125 | print(description) 126 | for i, batch in enumerate(dataloader): 127 | if i % 10 == 0: 128 | if not fast_validate: 129 | print("******** Validation : %d / %d ********" % ((i + 1) * batch_size, samples)) 130 | mel, spec, vid, vid_len, wav_tr, mel_len, f_name = batch 131 | vid = vid.cuda() 132 | phon, sent = v_front(vid) # S,B,512 133 | g1, g2, g3 = gen(sent, phon, vid_len) 134 | g3_temp = g3.clone() 135 | 136 | vid = vid.flip(4) 137 | phon, sent = v_front(vid) # S,B,512 138 | g1, g2, g3 = gen(sent, phon, vid_len) 139 | 140 | g3 = (g3_temp + g3) / 2. 141 | gs = post(g3) 142 | 143 | wav_spec = val_data.inverse_spec(gs[:, :, :, :mel_len[0]].detach(), stft) 144 | 145 | for b in range(g3.size(0)): 146 | stoi_spec_list.append(stoi(wav_tr[b][:len(wav_spec[b])], wav_spec[b], 16000, extended=False)) 147 | estoi_spec_list.append(stoi(wav_tr[b][:len(wav_spec[b])], wav_spec[b], 16000, extended=True)) 148 | pesq_spec_list.append(pesq(8000, librosa.resample(wav_tr[b][:len(wav_spec[b])].numpy(), 16000, 8000), librosa.resample(wav_spec[b], 16000, 8000), 'nb')) 149 | 150 | sub_name, _, file_name = f_name[b].split('/') 151 | if not os.path.exists(f'./test/spec_mel/{sub_name}'): 152 | os.makedirs(f'./test/spec_mel/{sub_name}') 153 | np.savez(f'./test/spec_mel/{sub_name}/{file_name}.npz', 154 | mel=g3[b, :, :, :mel_len[b]].detach().cpu().numpy(), 155 | spec=gs[b, :, :, :mel_len[b]].detach().cpu().numpy()) 156 | 157 | if not os.path.exists(f'./test/wav/{sub_name}'): 158 | os.makedirs(f'./test/wav/{sub_name}') 159 | sf.write(f'./test/wav/{sub_name}/{file_name}.wav', wav_spec[b], 16000, subtype='PCM_16') 160 | 161 | if i >= max_batches: 162 | break 163 | 164 | print('STOI: ', np.mean(stoi_spec_list)) 165 | print('ESTOI: ', np.mean(estoi_spec_list)) 166 | print('PESQ: ', np.mean(pesq_spec_list)) 167 | with open(f'./test/metric.txt', 'w') as f: 168 | f.write(f'STOI : {np.mean(stoi_spec_list)}') 169 | f.write(f'ESTOI : {np.mean(estoi_spec_list)}') 170 | f.write(f'PESQ : {np.mean(pesq_spec_list)}') 171 | 172 | if __name__ == "__main__": 173 | args = parse_args() 174 | train_net(args) 175 | 176 | -------------------------------------------------------------------------------- /ASR_model/GRID/src/data/vid_aud_GRID_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchaudio 8 | import torchvision 9 | from torchvision import transforms 10 | from torch.utils.data import DataLoader, Dataset 11 | import random 12 | from librosa.filters import mel as librosa_mel_fn 13 | from src.data.audio_processing import dynamic_range_compression, dynamic_range_decompression 14 | from src.data.stft import STFT 15 | import math, glob 16 | from scipy import 
signal 17 | 18 | log1e5 = math.log(1e-5) 19 | 20 | letters = ['_', ' ', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 21 | 'U', 'V', 'W', 'X', 'Y', 'Z'] 22 | 23 | 24 | class MultiDataset(Dataset): 25 | def __init__(self, grid, mode, gtpath, subject, max_v_timesteps=155, max_text_len=150, augmentations=False, num_mel_bins=80, wav=False): 26 | self.wav = wav 27 | self.gtpath = gtpath 28 | self.max_v_timesteps = max_v_timesteps 29 | self.max_text_len = max_text_len 30 | self.augmentations = augmentations if mode == 'train' else False 31 | self.file_paths = self.build_file_list(grid, subject) 32 | self.int2char = dict(enumerate(letters)) 33 | self.char2int = {char: index for index, char in self.int2char.items()} 34 | self.stft = TacotronSTFT(filter_length=640, hop_length=160, win_length=640, n_mel_channels=num_mel_bins, sampling_rate=16000, mel_fmin=55., mel_fmax=7500.) 35 | 36 | def build_file_list(self, grid, subject): 37 | check_list = [] 38 | if subject == 'overlap': 39 | with open('./../../data/overlap_val.txt', 'r') as f: 40 | lines = f.readlines() 41 | for l in lines: 42 | file = l.strip().replace('mpg_6000/', '') + '.mp4' 43 | check_list.append(os.path.join(grid, file)) 44 | elif subject == 'unseen': 45 | with open('./../../data/unseen_splits.txt', 'r') as f: 46 | lines = f.readlines() 47 | for l in lines: 48 | if 'test' in l.strip(): 49 | _, sub, fname = l.strip().split('/') 50 | file = f'{sub}/video/{fname}.mp4' 51 | if os.path.exists(os.path.join(grid, file)): 52 | check_list.append(os.path.join(grid, file)) 53 | else: 54 | with open('./../../data/test_4.txt', 'r') as f: 55 | lines = f.readlines() 56 | for l in lines: 57 | file = l.strip() 58 | if subject == 'four': 59 | check_list.append(os.path.join(grid, file)) 60 | elif file.split('/')[0] == subject: 61 | check_list.append(os.path.join(grid, file)) 62 | 63 | if self.wav: 64 | file_list = sorted(glob.glob(os.path.join(grid, '*', '*.wav'))) 65 | else: 66 | file_list = sorted(glob.glob(os.path.join(grid, '*', '*.npz'))) 67 | 68 | assert len(check_list) == len(file_list), 'The data for testing is not full' 69 | return file_list 70 | 71 | def __len__(self): 72 | return len(self.file_paths) 73 | 74 | def build_content(self, content): 75 | words = [] 76 | with open(content, 'r') as f: 77 | lines = f.readlines() 78 | for l in lines: 79 | word = l.strip().split()[2] 80 | if not word in ['SIL', 'SP', 'sil', 'sp']: 81 | words.append(word) 82 | return words 83 | 84 | def __getitem__(self, idx): 85 | file_path = self.file_paths[idx] 86 | t, f_name = os.path.split(file_path) 87 | _, sub = os.path.split(t) 88 | 89 | words = self.build_content(os.path.join(self.gtpath, sub.split('_')[0], 'align', f_name.split('.')[0] + '.align')) 90 | content = ' '.join(words).upper() 91 | 92 | if self.wav: 93 | aud, sr = torchaudio.load(file_path) 94 | 95 | if round(sr) != 16000: 96 | aud = torch.tensor(librosa.resample(aud.squeeze(0).numpy(), sr, 16000)).unsqueeze(0) 97 | 98 | aud = aud / torch.abs(aud).max() * 0.9 99 | aud = torch.FloatTensor(self.preemphasize(aud.squeeze(0))).unsqueeze(0) 100 | aud = torch.clamp(aud, min=-1, max=1) 101 | 102 | spec = self.stft.mel_spectrogram(aud) 103 | num_a_frames = spec.size(2) 104 | 105 | else: 106 | data = np.load(file_path) 107 | mel = data['mel'] 108 | data.close() 109 | 110 | spec = torch.FloatTensor(self.denormalize(mel)) 111 | 112 | num_a_frames = spec.size(2) 113 | 114 | target, txt_len = self.encode(content) 115 | 116 | spec = nn.ConstantPad2d((0, 
self.max_v_timesteps * 4 - num_a_frames, 0, 0), 0.0)(spec) 117 | 118 | return spec, target, num_a_frames, txt_len 119 | 120 | def encode(self, content): 121 | encoded = [self.char2int[c] for c in content] 122 | if len(encoded) > self.max_text_len: 123 | print(f"Max output length too short. Required {len(encoded)}") 124 | encoded = encoded[:self.max_text_len] 125 | num_txt = len(encoded) 126 | encoded += [self.char2int['_'] for _ in range(self.max_text_len - len(encoded))] 127 | return torch.Tensor(encoded), num_txt 128 | 129 | def preemphasize(self, aud): 130 | aud = signal.lfilter([1, -0.97], [1], aud) 131 | return aud 132 | 133 | def denormalize(self, melspec): 134 | melspec = ((melspec + 1) * (-log1e5 / 2)) + log1e5 135 | return melspec 136 | 137 | class TacotronSTFT(torch.nn.Module): 138 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 139 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 140 | mel_fmax=8000.0): 141 | super(TacotronSTFT, self).__init__() 142 | self.n_mel_channels = n_mel_channels 143 | self.sampling_rate = sampling_rate 144 | self.stft_fn = STFT(filter_length, hop_length, win_length) 145 | mel_basis = librosa_mel_fn( 146 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 147 | mel_basis = torch.from_numpy(mel_basis).float() 148 | self.register_buffer('mel_basis', mel_basis) 149 | 150 | def spectral_normalize(self, magnitudes): 151 | output = dynamic_range_compression(magnitudes) 152 | return output 153 | 154 | def spectral_de_normalize(self, magnitudes): 155 | output = dynamic_range_decompression(magnitudes) 156 | return output 157 | 158 | def mel_spectrogram(self, y): 159 | """Computes mel-spectrograms from a batch of waves 160 | PARAMS 161 | ------ 162 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 163 | RETURNS 164 | ------- 165 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 166 | """ 167 | assert(torch.min(y.data) >= -1) 168 | assert(torch.max(y.data) <= 1) 169 | 170 | magnitudes, phases = self.stft_fn.transform(y) 171 | magnitudes = magnitudes.data 172 | mel_output = torch.matmul(self.mel_basis, magnitudes) 173 | mel_output = self.spectral_normalize(mel_output) 174 | return mel_output 175 | -------------------------------------------------------------------------------- /ASR_model/GRID/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import torch 4 | from torch import nn, optim 5 | from torch.utils.tensorboard import SummaryWriter 6 | import numpy as np 7 | from src.models.audio_front import Audio_front 8 | from src.models.classifier import Backend 9 | import os 10 | from torch.utils.data import DataLoader 11 | from torch.nn import functional as F 12 | from src.data.vid_aud_GRID_test import MultiDataset 13 | from torch.nn import DataParallel as DP 14 | import torch.nn.parallel 15 | import math 16 | import editdistance 17 | import re 18 | from matplotlib import pyplot as plt 19 | import time 20 | import glob 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--data', default="TEST_DIR", help='./../../test/spec_mel') 25 | parser.add_argument('--wav', default=False, action='store_true', help='Is waveform or Mel(.npz) form') 26 | parser.add_argument('--gtpath', default="GT_path", help='GT transcription path') 27 | parser.add_argument('--model', default="GRID_CTC") 28 | parser.add_argument("--checkpoint_dir", type=str, default='./data') 29 | 
parser.add_argument("--checkpoint", type=str, default='./data/GRID_4_wer_0.00833_cer_0.00252.ckpt') 30 | parser.add_argument("--batch_size", type=int, default=160) 31 | parser.add_argument("--epochs", type=int, default=150) 32 | parser.add_argument("--lr", type=float, default=0.001) 33 | parser.add_argument("--weight_decay", type=float, default=0.00001) 34 | parser.add_argument("--workers", type=int, default=4) 35 | parser.add_argument("--resnet", type=int, default=18) 36 | parser.add_argument("--seed", type=int, default=1) 37 | 38 | parser.add_argument("--subject", default='overlap', help=['overlap', 'unseen', 's1', 's2', 's4', 's29', 'four']) 39 | 40 | parser.add_argument("--max_timesteps", type=int, default=75) 41 | parser.add_argument("--max_text_len", type=int, default=75) 42 | 43 | parser.add_argument("--dataparallel", default=False, action='store_true') 44 | parser.add_argument("--gpu", type=str, default='0') 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | def train_net(args): 50 | torch.backends.cudnn.deterministic = True 51 | torch.backends.cudnn.benchmark = False 52 | torch.manual_seed(args.seed) 53 | torch.cuda.manual_seed_all(args.seed) 54 | random.seed(args.seed) 55 | os.environ['OMP_NUM_THREADS'] = '2' 56 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 57 | 58 | a_front = Audio_front() 59 | a_back = Backend() 60 | 61 | if args.subject == 'unseen': 62 | args.checkpoint = './data/GRID_unseen_wer_0.01676_cer_0.00896.ckpt' 63 | elif args.subject == 'overlap': 64 | args.checkpoint = './data/GRID_33_wer_0.00368_cer_0.00120.ckpt' 65 | else: 66 | args.checkpoint = './data/GRID_4_wer_0.00833_cer_0.00252.ckpt' 67 | 68 | if args.checkpoint is not None: 69 | print(f"Loading checkpoint: {args.checkpoint}") 70 | checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage.cuda()) 71 | a_front.load_state_dict(checkpoint['a_front_state_dict']) 72 | a_back.load_state_dict(checkpoint['a_back_state_dict']) 73 | del checkpoint 74 | 75 | a_front.cuda() 76 | a_back.cuda() 77 | 78 | if args.dataparallel: 79 | a_front = DP(a_front) 80 | a_back = DP(a_back) 81 | 82 | wer, cer = validate(a_front, a_back) 83 | 84 | def validate(a_front, a_back, fast_validate=False): 85 | with torch.no_grad(): 86 | a_front.eval() 87 | a_back.eval() 88 | 89 | val_data = MultiDataset( 90 | grid=args.data, 91 | mode='test', 92 | gtpath=args.gtpath, 93 | subject=args.subject, 94 | max_v_timesteps=args.max_timesteps, 95 | max_text_len=args.max_text_len, 96 | wav=args.wav, 97 | augmentations=False 98 | ) 99 | 100 | dataloader = DataLoader( 101 | val_data, 102 | shuffle=False, 103 | batch_size=args.batch_size * 2, 104 | num_workers=args.workers, 105 | drop_last=False 106 | ) 107 | 108 | batch_size = dataloader.batch_size 109 | if fast_validate: 110 | samples = min(2 * batch_size, int(len(dataloader.dataset))) 111 | max_batches = 2 112 | else: 113 | samples = int(len(dataloader.dataset)) 114 | max_batches = int(len(dataloader)) 115 | 116 | wer_sum, cer_sum, tot_num = 0, 0, 0 117 | 118 | description = 'Check validation step' if fast_validate else 'Validation' 119 | print(description) 120 | for i, batch in enumerate(dataloader): 121 | if i % 50 == 0: 122 | if not fast_validate: 123 | print("******** Validation : %d / %d ********" % ((i + 1) * batch_size, samples)) 124 | a_in, target, aud_len, txt_len = batch 125 | 126 | a_feat = a_front(a_in.cuda()) # B,S,512 127 | pred = a_back(a_feat) 128 | 129 | cer, wer, sentences = greedy_decode(val_data, F.softmax(pred, dim=2).cpu(), target) 130 | 131 | 
B, S, _ = a_feat.size() 132 | 133 | tot_num += B 134 | wer_sum += B * wer 135 | cer_sum += B * cer 136 | batch_size = B 137 | 138 | if i % 50 == 0: 139 | for j in range(2): 140 | print('label: ', sentences[j][0]) 141 | print('prediction: ', sentences[j][1]) 142 | 143 | if i >= max_batches: 144 | break 145 | 146 | a_front.train() 147 | a_back.train() 148 | 149 | wer = wer_sum / tot_num 150 | cer = cer_sum / tot_num 151 | 152 | print('test_cer:', cer) 153 | print('test_wer:', wer) 154 | 155 | if fast_validate: 156 | return {} 157 | else: 158 | return wer, cer 159 | 160 | def decode(dataset, label_tokens, pred_tokens): 161 | label, output = '', '' 162 | for index in range(len(label_tokens)): 163 | label += dataset.int2char[int(label_tokens[index])] 164 | for index in range(len(pred_tokens)): 165 | output += dataset.int2char[int(pred_tokens[index])] 166 | 167 | output = re.sub(' +', ' ', output) 168 | pattern = re.compile(r"(.)\1{1,}", re.DOTALL) # remove characters that are repeated more than 2 times 169 | output = pattern.sub(r"\1", output) 170 | 171 | label = label.replace('_', '') 172 | output = output.replace('_', '') 173 | 174 | output_words, label_words = output.split(" "), label.split(" ") 175 | 176 | cer = editdistance.eval(output, label) / len(label) 177 | wer = editdistance.eval(output_words, label_words) / len(label_words) 178 | 179 | return label, output, cer, wer 180 | 181 | def greedy_decode(dataset, results, target): 182 | _, results = results.topk(1, dim=2) 183 | results = results.squeeze(dim=2) 184 | cer_sum, wer_sum = 0, 0 185 | batch_size = results.size(0) 186 | sentences = [] 187 | for batch in range(batch_size): 188 | label, output, cer, wer = decode(dataset, target[batch], results[batch]) 189 | sentences.append([label, output]) 190 | cer_sum += cer 191 | wer_sum += wer 192 | 193 | return cer_sum / batch_size, wer_sum / batch_size, sentences 194 | 195 | if __name__ == "__main__": 196 | args = parse_args() 197 | train_net(args) 198 | 199 | -------------------------------------------------------------------------------- /test_LRS.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import torch 4 | from torch import nn, optim 5 | from torch.utils.tensorboard import SummaryWriter 6 | import numpy as np 7 | from src.models.visual_front import Visual_front 8 | from src.models.generator import Decoder, Discriminator, gan_loss, sync_Discriminator, Postnet 9 | import os 10 | from torch.utils.data import DataLoader 11 | from torch.nn import functional as F 12 | from src.data.vid_aud_lrs2 import MultiDataset as LRS2_dataset 13 | from src.data.vid_aud_lrs3 import MultiDataset as LRS3_dataset 14 | from torch.nn import DataParallel as DP 15 | import torch.nn.parallel 16 | import time 17 | import glob 18 | from torch.autograd import grad 19 | import soundfile as sf 20 | from pesq import pesq 21 | from pystoi import stoi 22 | from matplotlib import pyplot as plt 23 | import copy, librosa 24 | 25 | 26 | def parse_args(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--data', default="Data_dir") 29 | parser.add_argument('--data_name', type=str, default="LRS2") 30 | parser.add_argument("--checkpoint_dir", type=str, default='./data/checkpoints/') 31 | parser.add_argument("--checkpoint", type=str, default='checkpoint') 32 | parser.add_argument("--batch_size", type=int, default=16) 33 | parser.add_argument("--epochs", type=int, default=1000) 34 | parser.add_argument("--lr", type=float, 
default=0.0001) 35 | parser.add_argument("--weight_decay", type=float, default=0.00001) 36 | parser.add_argument("--workers", type=int, default=10) 37 | parser.add_argument("--seed", type=int, default=1) 38 | 39 | parser.add_argument("--subject", type=str, default='overlap') 40 | 41 | parser.add_argument("--f_min", type=int, default=55.) 42 | parser.add_argument("--f_max", type=int, default=7600.) 43 | 44 | parser.add_argument("--start_epoch", type=int, default=0) 45 | parser.add_argument("--augmentations", default=True) 46 | 47 | parser.add_argument("--window_size", type=int, default=50) 48 | parser.add_argument("--max_timesteps", type=int, default=160) 49 | 50 | parser.add_argument("--dataparallel", default=False, action='store_true') 51 | parser.add_argument("--gpu", type=str, default='0') 52 | 53 | parser.add_argument("--save_mel", default=True, action='store_true') 54 | parser.add_argument("--save_wav", default=True, action='store_true') 55 | 56 | args = parser.parse_args() 57 | return args 58 | 59 | 60 | def train_net(args): 61 | torch.backends.cudnn.deterministic = False 62 | torch.backends.cudnn.benchmark = True 63 | torch.manual_seed(args.seed) 64 | torch.cuda.manual_seed_all(args.seed) 65 | random.seed(args.seed) 66 | os.environ['OMP_NUM_THREADS'] = '2' 67 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 68 | 69 | v_front = Visual_front(in_channels=1) 70 | gen = Decoder() 71 | post = Postnet() 72 | 73 | print(f"Loading checkpoint: {args.checkpoint}") 74 | checkpoint = torch.load(args.checkpoint, map_location=lambda storage, loc: storage.cuda()) 75 | 76 | v_front.load_state_dict(checkpoint['v_front_state_dict']) 77 | gen.load_state_dict(checkpoint['gen_state_dict']) 78 | post.load_state_dict(checkpoint['post_state_dict']) 79 | del checkpoint 80 | 81 | v_front.cuda() 82 | gen.cuda() 83 | post.cuda() 84 | 85 | if args.dataparallel: 86 | v_front = DP(v_front) 87 | gen = DP(gen) 88 | post = DP(post) 89 | 90 | _ = test(v_front, gen, post) 91 | 92 | def test(v_front, gen, post, fast_validate=False): 93 | with torch.no_grad(): 94 | v_front.eval() 95 | gen.eval() 96 | post.eval() 97 | 98 | if args.data_name == 'LRS2': 99 | val_data = LRS2_dataset( 100 | data=args.data, 101 | mode='test', 102 | max_v_timesteps=args.max_timesteps, 103 | window_size=args.window_size, 104 | augmentations=False, 105 | f_min=args.f_min, 106 | f_max=args.f_max 107 | ) 108 | elif args.data_name == 'LRS3': 109 | val_data = LRS3_dataset( 110 | data=args.data, 111 | mode='test', 112 | max_v_timesteps=args.max_timesteps, 113 | window_size=args.window_size, 114 | augmentations=False, 115 | f_min=args.f_min, 116 | f_max=args.f_max 117 | ) 118 | 119 | dataloader = DataLoader( 120 | val_data, 121 | shuffle=True if fast_validate else False, 122 | batch_size=args.batch_size, 123 | num_workers=args.workers, 124 | drop_last=False, 125 | collate_fn=lambda x: val_data.collate_fn(x) 126 | ) 127 | 128 | stft = copy.deepcopy(val_data.stft).cuda() 129 | stoi_spec_list = [] 130 | estoi_spec_list = [] 131 | pesq_spec_list = [] 132 | batch_size = dataloader.batch_size 133 | if fast_validate: 134 | samples = min(2 * batch_size, int(len(dataloader.dataset))) 135 | max_batches = 2 136 | else: 137 | samples = int(len(dataloader.dataset)) 138 | max_batches = int(len(dataloader)) 139 | 140 | description = 'Check validation step' if fast_validate else 'Validation' 141 | print(description) 142 | for i, batch in enumerate(dataloader): 143 | if i % 10 == 0: 144 | if not fast_validate: 145 | print("******** Validation : %d / %d ********" % ((i + 
1) * batch_size, samples)) 146 | mel, spec, vid, vid_len, wav_tr, mel_len, f_name = batch 147 | 148 | vid = vid.cuda() 149 | phon, sent = v_front(vid) # S,B,512 150 | g1, g2, g3 = gen(sent, phon, vid_len) 151 | g3_temp = g3.clone() 152 | 153 | vid = vid.flip(4) 154 | phon, sent = v_front(vid) # S,B,512 155 | g1, g2, g3 = gen(sent, phon, vid_len) 156 | 157 | g3 = (g3_temp + g3) / 2. 158 | gs = post(g3) 159 | 160 | for b in range(g3.size(0)): 161 | wav_spec = val_data.inverse_spec(gs[b, :, :, :mel_len[b]].detach(), stft)[0] 162 | min_length = min(len(wav_spec), len(wav_tr[b])) 163 | stoi_spec_list.append(stoi(wav_tr[b][:min_length], wav_spec[:min_length], 16000, extended=False)) 164 | estoi_spec_list.append(stoi(wav_tr[b][:min_length], wav_spec[:min_length], 16000, extended=True)) 165 | pesq_spec_list.append(pesq(8000, librosa.resample(wav_tr[b][:min_length].numpy(), 16000, 8000), librosa.resample(wav_spec, 16000, 8000), 'nb')) 166 | 167 | m_name, v_name, file_name = f_name[b].split('/') 168 | if not os.path.exists(f'./test/{args.data_name}/mel/{m_name}/{v_name}'): 169 | os.makedirs(f'./test/{args.data_name}/mel/{m_name}/{v_name}') 170 | np.savez(f'./test/{args.data_name}/mel/{m_name}/{v_name}/{file_name}.npz', 171 | mel=g3[b, :, :, :mel_len[b]].detach().cpu().numpy(), 172 | spec=gs[b, :, :, :mel_len[b]].detach().cpu().numpy()) 173 | 174 | if not os.path.exists(f'./test/{args.data_name}/wav/{m_name}/{v_name}'): 175 | os.makedirs(f'./test/{args.data_name}/wav/{m_name}/{v_name}') 176 | sf.write(f'./test/{args.data_name}/wav/{m_name}/{v_name}/{file_name}.wav', wav_spec, 16000, subtype='PCM_16') 177 | 178 | 179 | if i >= max_batches: 180 | break 181 | 182 | print('STOI: ', np.mean(stoi_spec_list)) 183 | print('ESTOI: ', np.mean(estoi_spec_list)) 184 | print('PESQ: ', np.mean(pesq_spec_list)) 185 | with open(f'./test/{args.data_name}/metric.txt', 'w') as f: 186 | f.write(f'STOI : {np.mean(stoi_spec_list)}') 187 | f.write(f'ESTOI : {np.mean(estoi_spec_list)}') 188 | f.write(f'PESQ : {np.mean(pesq_spec_list)}') 189 | 190 | if __name__ == "__main__": 191 | args = parse_args() 192 | train_net(args) 193 | 194 | -------------------------------------------------------------------------------- /src/data/vid_aud_grid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchaudio 8 | from torchvision import transforms 9 | from torch.utils.data import DataLoader, Dataset 10 | from src.data.transforms import Crop, StatefulRandomHorizontalFlip 11 | from PIL import Image 12 | import librosa 13 | from matplotlib import pyplot as plt 14 | import glob 15 | from scipy import signal 16 | import torchvision 17 | from torch.autograd import Variable 18 | from librosa.filters import mel as librosa_mel_fn 19 | from src.data.audio_processing import dynamic_range_compression, dynamic_range_decompression, griffin_lim 20 | from src.data.stft import STFT 21 | import math 22 | log1e5 = math.log(1e-5) 23 | 24 | class MultiDataset(Dataset): 25 | def __init__(self, grid, mode, max_v_timesteps=155, window_size=40, subject=None, augmentations=False, num_mel_bins=80, fast_validate=False): 26 | assert mode in ['train', 'test', 'val'] 27 | self.grid = grid 28 | self.mode = mode 29 | self.sample_window = True if mode == 'train' else False 30 | self.fast_validate = fast_validate 31 | self.max_v_timesteps = window_size if self.sample_window else max_v_timesteps 
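# Note (behavior evident from the line above): when a training window is sampled, the effective clip length is capped at window_size; in test/val mode the full max_v_timesteps is kept.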
32 | self.window_size = window_size 33 | self.augmentations = augmentations if mode == 'train' else False 34 | self.num_mel_bins = num_mel_bins 35 | self.file_paths = self.build_file_list(grid, mode, subject) 36 | self.f_min = 55. 37 | self.f_max = 7500. 38 | self.stft = TacotronSTFT(filter_length=640, hop_length=160, win_length=640, n_mel_channels=80, sampling_rate=16000, mel_fmin=self.f_min, mel_fmax=self.f_max) 39 | 40 | def build_file_list(self, grid, mode, subject): 41 | file_list = [] 42 | if subject == 'overlap': 43 | if mode == 'train': 44 | with open('./data/overlap_train.txt', 'r') as f: 45 | lines = f.readlines() 46 | for l in lines: 47 | file = l.strip().replace('mpg_6000/', '') + '.mp4' 48 | file_list.append(os.path.join(grid, file)) 49 | else: 50 | with open('./data/overlap_val.txt', 'r') as f: 51 | lines = f.readlines() 52 | for l in lines: 53 | file = l.strip().replace('mpg_6000/', '') + '.mp4' 54 | file_list.append(os.path.join(grid, file)) 55 | elif subject == 'unseen': 56 | with open('./data/unseen_splits.txt', 'r') as f: 57 | lines = f.readlines() 58 | for l in lines: 59 | if mode in l.strip(): 60 | _, sub, fname = l.strip().split('/') 61 | file = f'{sub}/video/{fname}.mp4' 62 | if os.path.exists(os.path.join(grid, file)): 63 | file_list.append(os.path.join(grid, file)) 64 | else: 65 | if mode == 'train': 66 | with open('./data/train_4.txt', 'r') as f: 67 | lines = f.readlines() 68 | for l in lines: 69 | file = l.strip() 70 | if subject == 'four': 71 | file_list.append(os.path.join(grid, file)) 72 | elif file.split('/')[0] == subject: 73 | file_list.append(os.path.join(grid, file)) 74 | elif mode == 'val': 75 | with open('./data/val_4.txt', 'r') as f: 76 | lines = f.readlines() 77 | for l in lines: 78 | file = l.strip() 79 | if subject == 'four': 80 | file_list.append(os.path.join(grid, file)) 81 | elif file.split('/')[0] == subject: 82 | file_list.append(os.path.join(grid, file)) 83 | else: 84 | with open('./data/test_4.txt', 'r') as f: 85 | lines = f.readlines() 86 | for l in lines: 87 | file = l.strip() 88 | if subject == 'four': 89 | file_list.append(os.path.join(grid, file)) 90 | elif file.split('/')[0] == subject: 91 | file_list.append(os.path.join(grid, file)) 92 | return file_list 93 | 94 | def build_tensor(self, frames): 95 | if self.augmentations: 96 | augmentations1 = transforms.Compose([StatefulRandomHorizontalFlip(0.5)]) 97 | else: 98 | augmentations1 = transforms.Compose([]) 99 | crop = [59, 95, 195, 231] 100 | 101 | transform = transforms.Compose([ 102 | transforms.ToPILImage(), 103 | Crop(crop), 104 | transforms.Resize([112, 112]), 105 | augmentations1, 106 | transforms.Grayscale(num_output_channels=1), 107 | transforms.ToTensor(), 108 | transforms.Normalize(0.4136, 0.1700) 109 | ]) 110 | 111 | temporalVolume = torch.zeros(self.max_v_timesteps, 1, 112, 112) 112 | for i, frame in enumerate(frames): 113 | temporalVolume[i] = transform(frame) 114 | 115 | ### Random Erasing ### 116 | if self.augmentations: 117 | x_s, y_s = [random.randint(-10, 66) for _ in range(2)] # starting point 118 | temporalVolume[:, :, np.maximum(0, y_s):np.minimum(112, y_s + 56), np.maximum(0, x_s):np.minimum(112, x_s + 56)] = 0. 
119 | 120 | temporalVolume = temporalVolume.transpose(1, 0) # (C, T, H, W) 121 | return temporalVolume 122 | 123 | def __len__(self): 124 | return len(self.file_paths) 125 | 126 | def __getitem__(self, idx): 127 | file_path = self.file_paths[idx] 128 | 129 | vid, _, info = torchvision.io.read_video(file_path, pts_unit='sec') 130 | audio, info['audio_fps'] = librosa.load(file_path.replace('video', 'audio')[:-4] + '.flac', sr=16000) 131 | audio = torch.FloatTensor(audio).unsqueeze(0) 132 | 133 | if not 'video_fps' in info: 134 | info['video_fps'] = 25 135 | info['audio_fps'] = 16000 136 | 137 | if vid.size(0) < 5 or audio.size(1) < 5: 138 | vid = torch.zeros([1, 112, 112, 3]) 139 | audio = torch.zeros([1, 16000//25]) 140 | 141 | ## Audio ## 142 | aud = audio / torch.abs(audio).max() * 0.9 143 | aud = torch.FloatTensor(self.preemphasize(aud.squeeze(0))).unsqueeze(0) 144 | aud = torch.clamp(aud, min=-1, max=1) 145 | 146 | melspec, spec = self.stft.mel_spectrogram(aud) 147 | 148 | ## Video ## 149 | vid = vid.permute(0, 3, 1, 2) # T C H W 150 | 151 | if self.sample_window: 152 | vid, melspec, spec, audio = self.extract_window(vid, melspec, spec, audio, info) 153 | 154 | num_v_frames = vid.size(0) 155 | vid = self.build_tensor(vid) 156 | 157 | melspec = self.normalize(melspec) 158 | 159 | num_a_frames = melspec.size(2) 160 | melspec = nn.ConstantPad2d((0, self.max_v_timesteps * 4 - num_a_frames, 0, 0), 0.0)(melspec) 161 | spec = nn.ConstantPad2d((0, self.max_v_timesteps * 4 - num_a_frames, 0, 0), 0.0)(spec) 162 | 163 | if not self.sample_window: 164 | audio = audio[:, :self.max_v_timesteps * 4 * 160] 165 | audio = torch.cat([audio, torch.zeros([1, int(self.max_v_timesteps / info['video_fps'] * info['audio_fps'] - aud.size(1))])], 1) 166 | 167 | if self.mode == 'test': 168 | return melspec, spec, vid, num_v_frames, audio.squeeze(0), num_a_frames, file_path.replace(self.grid, '')[1:-4] 169 | else: 170 | return melspec, spec, vid, num_v_frames, audio.squeeze(0), num_a_frames 171 | 172 | def extract_window(self, vid, mel, spec, aud, info): 173 | # vid : T,C,H,W 174 | vid_2_aud = info['audio_fps'] / info['video_fps'] / 160 175 | 176 | st_fr = random.randint(0, vid.size(0) - self.window_size) 177 | vid = vid[st_fr:st_fr + self.window_size] 178 | 179 | st_mel_fr = int(st_fr * vid_2_aud) 180 | mel_window_size = int(self.window_size * vid_2_aud) 181 | 182 | mel = mel[:, :, st_mel_fr:st_mel_fr + mel_window_size] 183 | spec = spec[:, :, st_mel_fr:st_mel_fr + mel_window_size] 184 | 185 | aud = aud[:, st_mel_fr*160:st_mel_fr*160 + mel_window_size*160] 186 | aud = torch.cat([aud, torch.zeros([1, int(self.window_size / info['video_fps'] * info['audio_fps'] - aud.size(1))])], 1) 187 | 188 | return vid, mel, spec, aud 189 | 190 | def inverse_mel(self, mel, stft): 191 | if len(mel.size()) < 4: 192 | mel = mel.unsqueeze(0) # B,1,80,T 193 | 194 | mel = self.denormalize(mel) 195 | mel = stft.spectral_de_normalize(mel) 196 | mel = mel.transpose(2, 3).contiguous() # B,80,T --> B,T,80 197 | spec_from_mel_scaling = 1000 198 | spec_from_mel = torch.matmul(mel, stft.mel_basis) 199 | spec_from_mel = spec_from_mel.transpose(2, 3).squeeze(1) # B,1,F,T 200 | spec_from_mel = spec_from_mel * spec_from_mel_scaling 201 | 202 | wav = griffin_lim(spec_from_mel, stft.stft_fn, 60).squeeze(1) # B,L 203 | wav = wav.cpu().numpy() if wav.is_cuda else wav.numpy() 204 | wavs = [] 205 | for w in wav: 206 | w = self.deemphasize(w) 207 | wavs += [w] 208 | wavs = np.array(wavs) 209 | wavs = np.clip(wavs, -1, 1) 210 | return wavs 211 | 212 | def 
inverse_spec(self, spec, stft): 213 | if len(spec.size()) < 4: 214 | spec = spec.unsqueeze(0) # B,1,321,T 215 | 216 | wav = griffin_lim(spec.squeeze(1), stft.stft_fn, 60).squeeze(1) # B,L 217 | wav = wav.cpu().numpy() if wav.is_cuda else wav.numpy() 218 | wavs = [] 219 | for w in wav: 220 | w = self.deemphasize(w) 221 | wavs += [w] 222 | wavs = np.array(wavs) 223 | wavs = np.clip(wavs, -1, 1) 224 | return wavs 225 | 226 | def preemphasize(self, aud): 227 | aud = signal.lfilter([1, -0.97], [1], aud) 228 | return aud 229 | 230 | def deemphasize(self, aud): 231 | aud = signal.lfilter([1], [1, -0.97], aud) 232 | return aud 233 | 234 | def normalize(self, melspec): 235 | melspec = ((melspec - log1e5) / (-log1e5 / 2)) - 1 #0~2 --> -1~1 236 | return melspec 237 | 238 | def denormalize(self, melspec): 239 | melspec = ((melspec + 1) * (-log1e5 / 2)) + log1e5 240 | return melspec 241 | 242 | def audio_preprocessing(self, aud): 243 | fc = self.f_min 244 | w = fc / (16000 / 2) 245 | b, a = signal.butter(7, w, 'high') 246 | aud = aud.squeeze(0).numpy() 247 | aud = signal.filtfilt(b, a, aud) 248 | return torch.tensor(aud.copy()).unsqueeze(0) 249 | 250 | def plot_spectrogram_to_numpy(self, mels): 251 | fig, ax = plt.subplots(figsize=(15, 4)) 252 | im = ax.imshow(np.squeeze(mels, 0), aspect="auto", origin="lower", 253 | interpolation='none') 254 | plt.colorbar(im, ax=ax) 255 | plt.xlabel("Frames") 256 | plt.ylabel("Channels") 257 | plt.tight_layout() 258 | 259 | fig.canvas.draw() 260 | data = self.save_figure_to_numpy(fig) 261 | plt.close() 262 | return data 263 | 264 | def save_figure_to_numpy(self, fig): 265 | # save it to a numpy array. 266 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 267 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 268 | return data.transpose(2, 0, 1) 269 | 270 | class TacotronSTFT(torch.nn.Module): 271 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 272 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 273 | mel_fmax=8000.0): 274 | super(TacotronSTFT, self).__init__() 275 | self.n_mel_channels = n_mel_channels 276 | self.sampling_rate = sampling_rate 277 | self.stft_fn = STFT(filter_length, hop_length, win_length) 278 | mel_basis = librosa_mel_fn( 279 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 280 | mel_basis = torch.from_numpy(mel_basis).float() 281 | self.register_buffer('mel_basis', mel_basis) 282 | 283 | def spectral_normalize(self, magnitudes): 284 | output = dynamic_range_compression(magnitudes) 285 | return output 286 | 287 | def spectral_de_normalize(self, magnitudes): 288 | output = dynamic_range_decompression(magnitudes) 289 | return output 290 | 291 | def mel_spectrogram(self, y): 292 | """Computes mel-spectrograms from a batch of waves 293 | PARAMS 294 | ------ 295 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 296 | RETURNS 297 | ------- 298 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 299 | """ 300 | assert(torch.min(y.data) >= -1) 301 | assert(torch.max(y.data) <= 1) 302 | 303 | magnitudes, phases = self.stft_fn.transform(y) 304 | magnitudes = magnitudes.data 305 | mel_output = torch.matmul(self.mel_basis, magnitudes) 306 | mel_output = self.spectral_normalize(mel_output) 307 | return mel_output, magnitudes 308 | -------------------------------------------------------------------------------- /src/models/generator.py: -------------------------------------------------------------------------------- 1 | import math 2 | import 
numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from src.models.resnet import BasicBlock 7 | 8 | class ResBlk1D(nn.Module): 9 | def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), 10 | normalize=False, downsample=False): 11 | super().__init__() 12 | self.actv = actv 13 | self.normalize = normalize 14 | self.downsample = downsample 15 | self.learned_sc = dim_in != dim_out 16 | self._build_weights(dim_in, dim_out) 17 | 18 | def _build_weights(self, dim_in, dim_out): 19 | self.conv1 = nn.Conv1d(dim_in, dim_in, 5, 1, 2) 20 | self.conv2 = nn.Conv1d(dim_in, dim_out, 5, 1, 2) 21 | if self.normalize: 22 | self.norm1 = nn.BatchNorm1d(dim_in) 23 | self.norm2 = nn.BatchNorm1d(dim_in) 24 | if self.learned_sc: 25 | self.conv1x1 = nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False) 26 | 27 | def _shortcut(self, x): 28 | if self.learned_sc: 29 | x = self.conv1x1(x) 30 | if self.downsample: 31 | x = F.avg_pool1d(x, 2) 32 | return x 33 | 34 | def _residual(self, x): 35 | if self.normalize: 36 | x = self.norm1(x) 37 | x = self.actv(x) 38 | x = self.conv1(x) 39 | if self.downsample: 40 | x = F.avg_pool1d(x, 2) 41 | if self.normalize: 42 | x = self.norm2(x) 43 | x = self.actv(x) 44 | x = self.conv2(x) 45 | return x 46 | 47 | def forward(self, x): 48 | x = self._shortcut(x) + self._residual(x) 49 | return x / math.sqrt(2) # unit variance 50 | 51 | class ResBlk(nn.Module): 52 | def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), 53 | normalize=False, downsample=False): 54 | super().__init__() 55 | self.actv = actv 56 | self.normalize = normalize 57 | self.downsample = downsample 58 | self.learned_sc = dim_in != dim_out 59 | self._build_weights(dim_in, dim_out) 60 | 61 | def _build_weights(self, dim_in, dim_out): 62 | self.conv1 = nn.Conv2d(dim_in, dim_in, 5, 1, 2) 63 | self.conv2 = nn.Conv2d(dim_in, dim_out, 5, 1, 2) 64 | if self.normalize: 65 | self.norm1 = nn.BatchNorm2d(dim_in) 66 | self.norm2 = nn.BatchNorm2d(dim_in) 67 | if self.learned_sc: 68 | self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) 69 | 70 | def _shortcut(self, x): 71 | if self.learned_sc: 72 | x = self.conv1x1(x) 73 | if self.downsample: 74 | x = F.avg_pool2d(x, 2) 75 | return x 76 | 77 | def _residual(self, x): 78 | if self.normalize: 79 | x = self.norm1(x) 80 | x = self.actv(x) 81 | x = self.conv1(x) 82 | if self.downsample: 83 | x = F.avg_pool2d(x, 2) 84 | if self.normalize: 85 | x = self.norm2(x) 86 | x = self.actv(x) 87 | x = self.conv2(x) 88 | return x 89 | 90 | def forward(self, x): 91 | x = self._shortcut(x) + self._residual(x) 92 | return x / math.sqrt(2) # unit variance 93 | 94 | class GenResBlk(nn.Module): 95 | def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), upsample=False): 96 | super().__init__() 97 | self.actv = actv 98 | self.upsample = upsample 99 | self.learned_sc = dim_in != dim_out 100 | self._build_weights(dim_in, dim_out) 101 | 102 | def _build_weights(self, dim_in, dim_out): 103 | self.conv1 = nn.Conv2d(dim_in, dim_out, 5, 1, 2) 104 | self.conv2 = nn.Conv2d(dim_out, dim_out, 5, 1, 2) 105 | self.norm1 = nn.BatchNorm2d(dim_in) 106 | self.norm2 = nn.BatchNorm2d(dim_out) 107 | if self.learned_sc: 108 | self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) 109 | 110 | def _shortcut(self, x): 111 | if self.upsample: 112 | x = F.interpolate(x, scale_factor=2, mode='nearest') 113 | if self.learned_sc: 114 | x = self.conv1x1(x) 115 | return x 116 | 117 | def _residual(self, x): 118 | x = self.norm1(x) 119 | x = self.actv(x) 120 | if 
self.upsample: 121 | x = F.interpolate(x, scale_factor=2, mode='nearest') 122 | x = self.conv1(x) 123 | x = self.norm2(x) 124 | x = self.actv(x) 125 | x = self.conv2(x) 126 | return x 127 | 128 | def forward(self, x): 129 | out = self._residual(x) 130 | out = (out + self._shortcut(x)) / math.sqrt(2) 131 | return out 132 | 133 | class Flatten(nn.Module): 134 | def forward(self, input): 135 | return input.view(input.size(0), -1) 136 | 137 | class Avgpool(nn.Module): 138 | def forward(self, input): 139 | #input:B,C,H,W 140 | return input.mean([2, 3]) 141 | 142 | class AVAttention(nn.Module): 143 | def __init__(self, out_dim): 144 | super().__init__() 145 | 146 | self.softmax = nn.Softmax(2) 147 | self.k = nn.Linear(512, out_dim) 148 | self.v = nn.Linear(512, out_dim) 149 | self.q = nn.Linear(2560, out_dim) 150 | self.out_dim = out_dim 151 | dim = 20 * 64 152 | self.mel = nn.Linear(out_dim, dim) 153 | 154 | def forward(self, ph, g, len): 155 | #ph: B,S,512 156 | #g: B,C,F,T 157 | B, C, F, T = g.size() 158 | k = self.k(ph).transpose(1, 2).contiguous() # B,256,S 159 | q = self.q(g.view(B, C * F, T).transpose(1, 2).contiguous()) # B,T,256 160 | 161 | att = torch.bmm(q, k) / math.sqrt(self.out_dim) # B,T,S 162 | for i in range(att.size(0)): 163 | att[i, :, len[i]:] = float('-inf') 164 | att = self.softmax(att) # B,T,S 165 | 166 | v = self.v(ph) # B,S,256 167 | value = torch.bmm(att, v) # B,T,256 168 | out = self.mel(value) # B, T, 20*64 169 | out = out.view(B, T, F, -1).permute(0, 3, 2, 1) 170 | 171 | return out #B,C,F,T 172 | 173 | class Postnet(nn.Module): 174 | def __init__(self): 175 | super().__init__() 176 | 177 | self.postnet = nn.Sequential( 178 | nn.Conv1d(80, 128, 7, 1, 3), 179 | nn.BatchNorm1d(128), 180 | nn.LeakyReLU(0.2), 181 | ResBlk1D(128, 256), 182 | ResBlk1D(256, 256), 183 | ResBlk1D(256, 256), 184 | nn.Conv1d(256, 321, 1, 1, 0, bias=False) 185 | ) 186 | 187 | def forward(self, x): 188 | # x: B,1,80,T 189 | x = x.squeeze(1) # B, 80, t 190 | x = self.postnet(x) # B, 321, T 191 | x = x.unsqueeze(1) # B, 1, 321, T 192 | return x 193 | 194 | class Decoder(nn.Module): 195 | def __init__(self): 196 | super().__init__() 197 | 198 | self.decode = nn.ModuleList() 199 | self.g1 = nn.ModuleList() 200 | self.g2 = nn.ModuleList() 201 | self.g3 = nn.ModuleList() 202 | 203 | self.att1 = AVAttention(256) 204 | self.attconv1 = nn.Conv2d(128 + 64, 128, 5, 1, 2) 205 | self.att2 = AVAttention(256) 206 | self.attconv2 = nn.Conv2d(64 + 32, 64, 5, 1, 2) 207 | 208 | self.to_mel1 = nn.Sequential( 209 | nn.BatchNorm2d(128), 210 | nn.LeakyReLU(0.2), 211 | nn.Conv2d(128, 1, 1, 1, 0), 212 | nn.Tanh() 213 | ) 214 | self.to_mel2 = nn.Sequential( 215 | nn.BatchNorm2d(64), 216 | nn.LeakyReLU(0.2), 217 | nn.Conv2d(64, 1, 1, 1, 0), 218 | nn.Tanh() 219 | ) 220 | self.to_mel3 = nn.Sequential( 221 | nn.BatchNorm2d(32), 222 | nn.LeakyReLU(0.2), 223 | nn.Conv2d(32, 1, 1, 1, 0), 224 | nn.Tanh() 225 | ) 226 | 227 | # bottleneck blocks 228 | self.decode.append(GenResBlk(512 + 128, 512)) # 20,T 229 | self.decode.append(GenResBlk(512, 256)) 230 | self.decode.append(GenResBlk(256, 256)) 231 | 232 | # up-sampling blocks 233 | self.g1.append(GenResBlk(256, 128)) # 20,T 234 | self.g1.append(GenResBlk(128, 128)) 235 | self.g1.append(GenResBlk(128, 128)) 236 | 237 | self.g2.append(GenResBlk(128, 64, upsample=True)) # 40,2T 238 | self.g2.append(GenResBlk(64, 64)) 239 | self.g2.append(GenResBlk(64, 64)) 240 | 241 | self.g3.append(GenResBlk(64, 32, upsample=True)) # 80,4T 242 | self.g3.append(GenResBlk(32, 32)) 243 | 
self.g3.append(GenResBlk(32, 32)) 244 | 245 | def forward(self, s, x, len): 246 | # s: B,512,T x: B,T,512 247 | s = s.transpose(1, 2).contiguous() 248 | n = torch.randn([x.size(0), 128, 20, x.size(1)]).cuda() # B,128,20,T 249 | x = x.transpose(1, 2).contiguous().unsqueeze(2).repeat(1, 1, 20, 1) # B, 512, 20, T 250 | x = torch.cat([x, n], 1) 251 | for block in self.decode: 252 | x = block(x) 253 | for block in self.g1: 254 | x = block(x) 255 | g1 = x.clone() 256 | c1 = self.att1(s, g1, len) 257 | x = self.attconv1(torch.cat([x, c1], 1)) 258 | for block in self.g2: 259 | x = block(x) 260 | g2 = x.clone() 261 | c2 = self.att2(s, g2, len) 262 | x = self.attconv2(torch.cat([x, c2], 1)) 263 | for block in self.g3: 264 | x = block(x) 265 | return self.to_mel1(g1), self.to_mel2(g2), self.to_mel3(x) 266 | 267 | class Discriminator(nn.Module): 268 | def __init__(self, num_class=1, max_conv_dim=512, phase='1'): 269 | super().__init__() 270 | dim_in = 32 271 | blocks = [] 272 | blocks += [nn.Conv2d(1, dim_in, 5, 1, 2)] 273 | 274 | if phase == '1': 275 | repeat_num = 2 276 | elif phase == '2': 277 | repeat_num = 3 278 | else: 279 | repeat_num = 4 280 | 281 | for _ in range(repeat_num): # 80,4T --> 40,2T --> 20,T --> 10,T/2 --> 5,T/4 282 | dim_out = min(dim_in*2, max_conv_dim) 283 | blocks += [ResBlk(dim_in, dim_out, downsample=True)] 284 | dim_in = dim_out 285 | 286 | self.main = nn.Sequential(*blocks) 287 | 288 | uncond = [] 289 | uncond += [nn.LeakyReLU(0.2)] 290 | uncond += [nn.Conv2d(dim_out, dim_out, 5, 1, 0)] 291 | uncond += [nn.LeakyReLU(0.2)] 292 | uncond += [Avgpool()] 293 | uncond += [nn.Linear(dim_out, num_class)] 294 | self.uncond = nn.Sequential(*uncond) 295 | 296 | cond = [] 297 | cond += [nn.LeakyReLU(0.2)] 298 | cond += [nn.Conv2d(dim_out + 512, dim_out, 5, 1, 2)] 299 | cond += [nn.LeakyReLU(0.2)] 300 | cond += [nn.Conv2d(dim_out, dim_out, 5, 1, 0)] 301 | cond += [nn.LeakyReLU(0.2)] 302 | cond += [Avgpool()] 303 | cond += [nn.Linear(dim_out, num_class)] 304 | self.cond = nn.Sequential(*cond) 305 | 306 | def forward(self, x, c, vid_max_length): 307 | # c: B,C,T 308 | f_len = final_length(vid_max_length) 309 | c = c.mean(2) #B,C 310 | c = c.unsqueeze(2).unsqueeze(2).repeat(1, 1, 5, f_len) 311 | out = self.main(x).clone() 312 | uout = self.uncond(out) 313 | out = torch.cat([out, c], dim=1) 314 | cout = self.cond(out) 315 | uout = uout.view(uout.size(0), -1) # (batch, num_domains) 316 | cout = cout.view(cout.size(0), -1) # (batch, num_domains) 317 | return uout, cout 318 | 319 | class sync_Discriminator(nn.Module): 320 | def __init__(self, temp=1.0): 321 | super().__init__() 322 | 323 | self.frontend = nn.Sequential( 324 | nn.Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), 325 | nn.BatchNorm2d(128), 326 | nn.PReLU(128), 327 | nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), 328 | nn.BatchNorm2d(256), 329 | nn.PReLU(256) 330 | ) 331 | 332 | self.Res_block = nn.Sequential( 333 | BasicBlock(256, 256) 334 | ) 335 | 336 | self.Linear = nn.Linear(256 * 20, 512) 337 | self.temp = temp 338 | 339 | def forward(self, v_feat, aud, gen=False): 340 | # v_feat: B, S, 512 341 | a_feat = self.frontend(aud) 342 | a_feat = self.Res_block(a_feat) 343 | b, c, f, t = a_feat.size() 344 | a_feat = a_feat.view(b, c * f, t).transpose(1, 2).contiguous() # B, T/4, 256 * F/4 345 | a_feat = self.Linear(a_feat) # B, S, 512 346 | 347 | if gen: 348 | sim = torch.abs(F.cosine_similarity(v_feat, a_feat, 2)).mean(1) #B, S 349 | loss = 5 * torch.ones_like(sim) - sim 350 | else: 351 | 
v_feat_norm = F.normalize(v_feat, dim=2) #B,S,512 352 | a_feat_norm = F.normalize(a_feat, dim=2) #B,S,512 353 | 354 | sim = torch.bmm(v_feat_norm, a_feat_norm.transpose(1, 2)) / self.temp #B,v_S,a_S 355 | 356 | nce_va = torch.mean(torch.diagonal(F.log_softmax(sim, dim=2), dim1=-2, dim2=-1), dim=1) 357 | nce_av = torch.mean(torch.diagonal(F.log_softmax(sim, dim=1), dim1=-2, dim2=-1), dim=1) 358 | 359 | loss = -1/2 * (nce_va + nce_av) 360 | 361 | return loss 362 | 363 | def gan_loss(inputs, label=None): 364 | # non-saturating loss with R1 regularization 365 | l = -1 if label else 1 366 | return F.softplus(l*inputs).mean() 367 | 368 | def final_length(vid_length): 369 | half = (vid_length // 2) 370 | quad = (half // 2) 371 | return quad -------------------------------------------------------------------------------- /src/data/vid_aud_lrs3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchaudio 8 | from torchvision import transforms 9 | from torch.utils.data import DataLoader, Dataset 10 | from src.data.transforms import Crop, StatefulRandomHorizontalFlip 11 | from PIL import Image 12 | import librosa 13 | import cv2 14 | from matplotlib import pyplot as plt 15 | import glob 16 | from scipy import signal 17 | import torchvision 18 | from torch.autograd import Variable 19 | from librosa.filters import mel as librosa_mel_fn 20 | from src.data.audio_processing import dynamic_range_compression, dynamic_range_decompression, griffin_lim 21 | from src.data.stft import STFT 22 | import math 23 | log1e5 = math.log(1e-5) 24 | 25 | class MultiDataset(Dataset): 26 | def __init__(self, data, mode, max_v_timesteps=155, window_size=40, augmentations=False, num_mel_bins=80, fast_validate=False, f_min=55., f_max=7600.): 27 | assert mode in ['pretrain', 'train', 'test', 'val'] 28 | self.data = data 29 | self.sample_window = True if (mode == 'pretrain' or mode == 'train') else False 30 | self.fast_validate = fast_validate 31 | self.max_v_timesteps = window_size if self.sample_window else max_v_timesteps 32 | self.window_size = window_size 33 | self.augmentations = augmentations if mode == 'train' else False 34 | self.num_mel_bins = num_mel_bins 35 | self.file_paths, self.file_names, self.crops = self.build_file_list(data, mode) 36 | self.stft = TacotronSTFT(filter_length=640, hop_length=160, win_length=640, n_mel_channels=num_mel_bins, sampling_rate=16000, mel_fmin=f_min, mel_fmax=f_max) 37 | 38 | def build_file_list(self, lrs3, mode): 39 | file_list, paths = [], [] 40 | crops = {} 41 | 42 | ## LRS3 crop (lip centered axis) load ## 43 | file = open(f"./data/LRS3/LRS3_crop/preprocess_pretrain.txt", "r") 44 | content = file.read() 45 | file.close() 46 | for i, line in enumerate(content.splitlines()): 47 | split = line.split(".") 48 | file = split[0] 49 | crop_str = split[1][4:] 50 | crops['pretrain/' + file] = crop_str 51 | file = open(f"./data/LRS3/LRS3_crop/preprocess_test.txt", "r") 52 | content = file.read() 53 | file.close() 54 | for i, line in enumerate(content.splitlines()): 55 | split = line.split(".") 56 | file = split[0] 57 | crop_str = split[1][4:] 58 | crops['test/' + file] = crop_str 59 | file = open(f"./data/LRS3/LRS3_crop/preprocess_trainval.txt", "r") 60 | content = file.read() 61 | file.close() 62 | for i, line in enumerate(content.splitlines()): 63 | split = line.split(".") 64 | file = split[0] 65 | crop_str = 
split[1][4:] 66 | crops['trainval/' + file] = crop_str 67 | 68 | ## LRS3 file lists## 69 | file = open(f"./data/LRS3/lrs3_unseen_{mode}.txt", "r") 70 | content = file.read() 71 | file.close() 72 | for file in content.splitlines(): 73 | if file in crops: 74 | file_list.append(file) 75 | paths.append(f"{lrs3}/{file}") 76 | 77 | print(f'Mode: {mode}, File Num: {len(file_list)}') 78 | return paths, file_list, crops 79 | 80 | def build_tensor(self, frames, crops): 81 | if self.augmentations: 82 | s = random.randint(-5, 5) 83 | else: 84 | s = 0 85 | crop = [] 86 | for i in range(0, len(crops), 2): 87 | left = int(crops[i]) - 40 + s 88 | upper = int(crops[i + 1]) - 40 + s 89 | right = int(crops[i]) + 40 + s 90 | bottom = int(crops[i + 1]) + 40 + s 91 | crop.append([left, upper, right, bottom]) 92 | crops = crop 93 | 94 | if self.augmentations: 95 | augmentations1 = transforms.Compose([StatefulRandomHorizontalFlip(0.5)]) 96 | else: 97 | augmentations1 = transforms.Compose([]) 98 | 99 | temporalVolume = torch.zeros(self.max_v_timesteps, 1, 112, 112) 100 | for i, frame in enumerate(frames): 101 | transform = transforms.Compose([ 102 | transforms.ToPILImage(), 103 | Crop(crops[i]), 104 | transforms.Resize([112, 112]), 105 | augmentations1, 106 | transforms.Grayscale(num_output_channels=1), 107 | transforms.ToTensor(), 108 | transforms.Normalize(0.4136, 0.1700), 109 | ]) 110 | temporalVolume[i] = transform(frame) 111 | 112 | temporalVolume = temporalVolume.transpose(1, 0) # (C, T, H, W) 113 | return temporalVolume 114 | 115 | def __len__(self): 116 | return len(self.file_paths) 117 | 118 | def __getitem__(self, idx): 119 | file = self.file_names[idx] 120 | file_path = self.file_paths[idx] 121 | crops = self.crops[file].split("/") 122 | 123 | info = {} 124 | info['video_fps'] = 25 125 | cap = cv2.VideoCapture(file_path + '.mp4') 126 | frames = [] 127 | while (cap.isOpened()): 128 | ret, frame = cap.read() 129 | if ret: 130 | frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 131 | else: 132 | break 133 | cap.release() 134 | audio, info['audio_fps'] = librosa.load(file_path.replace('LRS3-TED', 'LRS3-TED_audio') + '.wav', sr=16000) 135 | vid = torch.tensor(np.stack(frames, 0)) 136 | audio = torch.tensor(audio).unsqueeze(0) 137 | 138 | if not 'video_fps' in info: 139 | info['video_fps'] = 25 140 | info['audio_fps'] = 16000 141 | 142 | assert vid.size(0) > 5 or audio.size(1) > 5 143 | 144 | ## Audio ## 145 | audio = audio / torch.abs(audio).max() * 0.9 146 | aud = torch.FloatTensor(self.preemphasize(audio.squeeze(0))).unsqueeze(0) 147 | aud = torch.clamp(aud, min=-1, max=1) 148 | 149 | melspec, spec = self.stft.mel_spectrogram(aud) 150 | 151 | ## Video ## 152 | vid = vid.permute(0, 3, 1, 2) # T C H W 153 | 154 | if self.sample_window: 155 | vid, melspec, spec, audio, crops = self.extract_window(vid, melspec, spec, audio, info, crops) 156 | elif vid.size(0) > self.max_v_timesteps: 157 | print('Sample is longer than Max video frames! 
Trimming to the length of ', self.max_v_timesteps) 158 | vid = vid[:self.max_v_timesteps] 159 | melspec = melspec[:, :, :int(self.max_v_timesteps * info['audio_fps'] / info['video_fps'] / 160)] 160 | spec = spec[:, :, :int(self.max_v_timesteps * info['audio_fps'] / info['video_fps'] / 160)] 161 | audio = audio[:, :int(self.max_v_timesteps * info['audio_fps'] / info['video_fps'])] 162 | crops = crops[:self.max_v_timesteps * 2] 163 | 164 | num_v_frames = vid.size(0) 165 | vid = self.build_tensor(vid, crops) 166 | 167 | melspec = self.normalize(melspec) #0~2 --> -1~1 168 | 169 | spec = self.normalize_spec(spec) # 0 ~ 1 170 | spec = self.stft.spectral_normalize(spec) # log(1e-5) ~ 0 # in log scale 171 | spec = self.normalize(spec) # -1 ~ 1 172 | 173 | num_a_frames = melspec.size(2) 174 | melspec = nn.ConstantPad2d((0, self.max_v_timesteps * 4 - num_a_frames, 0, 0), -1.0)(melspec) 175 | spec = nn.ConstantPad2d((0, self.max_v_timesteps * 4 - num_a_frames, 0, 0), -1.0)(spec) 176 | 177 | return melspec, spec, vid, num_v_frames, audio.squeeze(0), num_a_frames, file_path.replace(self.data, '')[1:] 178 | 179 | def extract_window(self, vid, mel, spec, aud, info, crops): 180 | # vid : T,C,H,W 181 | st_fr = random.randint(0, max(0, vid.size(0) - self.window_size)) 182 | vid = vid[st_fr:st_fr + self.window_size] 183 | crops = crops[st_fr * 2: st_fr * 2 + self.window_size * 2] 184 | 185 | assert vid.size(0) * 2 == len(crops), f'vid length: {vid.size(0)}, crop length: {len(crops)}' 186 | 187 | st_mel_fr = int(st_fr * info['audio_fps'] / info['video_fps'] / 160) 188 | mel_window_size = int(self.window_size * info['audio_fps'] / info['video_fps'] / 160) 189 | mel = mel[:, :, st_mel_fr:st_mel_fr + mel_window_size] 190 | spec = spec[:, :, st_mel_fr:st_mel_fr + mel_window_size] 191 | aud = aud[:, st_mel_fr*160:st_mel_fr*160 + mel_window_size*160] 192 | aud = torch.cat([aud, torch.zeros([1, int(self.window_size / info['video_fps'] * info['audio_fps'] - aud.size(1))])], 1) 193 | 194 | return vid, mel, spec, aud, crops 195 | 196 | def collate_fn(self, batch): 197 | vid_lengths, spec_lengths, padded_spec_lengths, aud_lengths = [], [], [], [] 198 | for data in batch: 199 | vid_lengths.append(data[3]) 200 | spec_lengths.append(data[5]) 201 | padded_spec_lengths.append(data[0].size(2)) 202 | aud_lengths.append(data[4].size(0)) 203 | 204 | max_aud_length = max(aud_lengths) 205 | max_spec_length = max(padded_spec_lengths) 206 | padded_vid = [] 207 | padded_melspec = [] 208 | padded_spec = [] 209 | padded_audio = [] 210 | f_names = [] 211 | 212 | for i, (melspec, spec, vid, num_v_frames, audio, spec_len, f_name) in enumerate(batch): 213 | padded_vid.append(vid) # B, C, T, H, W 214 | padded_melspec.append(nn.ConstantPad2d((0, max_spec_length - melspec.size(2), 0, 0), -1.0)(melspec)) 215 | padded_spec.append(nn.ConstantPad2d((0, max_spec_length - spec.size(2), 0, 0), -1.0)(spec)) 216 | padded_audio.append(torch.cat([audio, torch.zeros([max_aud_length - audio.size(0)])], 0)) 217 | f_names.append(f_name) 218 | 219 | vid = torch.stack(padded_vid, 0).float() 220 | vid_length = torch.IntTensor(vid_lengths) 221 | melspec = torch.stack(padded_melspec, 0).float() 222 | spec = torch.stack(padded_spec, 0).float() 223 | spec_length = torch.IntTensor(spec_lengths) 224 | audio = torch.stack(padded_audio, 0).float() 225 | 226 | return melspec, spec, vid, vid_length, audio, spec_length, f_names 227 | 228 | def inverse_mel(self, mel, stft): 229 | if len(mel.size()) < 4: 230 | mel = mel.unsqueeze(0) #B,1,80,T 231 | 232 | mel = 
self.denormalize(mel) 233 | mel = stft.spectral_de_normalize(mel) 234 | mel = mel.transpose(2, 3).contiguous() #B,80,T --> B,T,80 235 | spec_from_mel_scaling = 1000 236 | spec_from_mel = torch.matmul(mel, stft.mel_basis) 237 | spec_from_mel = spec_from_mel.transpose(2, 3).squeeze(1) # B,1,F,T 238 | spec_from_mel = spec_from_mel * spec_from_mel_scaling 239 | 240 | wav = griffin_lim(spec_from_mel, stft.stft_fn, 60).squeeze(1) #B,L 241 | wav = wav.cpu().numpy() if wav.is_cuda else wav.numpy() 242 | wavs = [] 243 | for w in wav: 244 | w = self.deemphasize(w) 245 | wavs += [w] 246 | wavs = np.array(wavs) 247 | wavs = np.clip(wavs, -1, 1) 248 | return wavs 249 | 250 | def inverse_spec(self, spec, stft): 251 | if len(spec.size()) < 4: 252 | spec = spec.unsqueeze(0) #B,1,321,T 253 | 254 | spec = self.denormalize(spec) # log1e5 ~ 0 255 | spec = stft.spectral_de_normalize(spec) # 0 ~ 1 256 | spec = self.denormalize_spec(spec) # 0 ~ 14 257 | wav = griffin_lim(spec.squeeze(1), stft.stft_fn, 60).squeeze(1) #B,L 258 | wav = wav.cpu().numpy() if wav.is_cuda else wav.numpy() 259 | wavs = [] 260 | for w in wav: 261 | w = self.deemphasize(w) 262 | wavs += [w] 263 | wavs = np.array(wavs) 264 | wavs = np.clip(wavs, -1, 1) 265 | return wavs 266 | 267 | def preemphasize(self, aud): 268 | aud = signal.lfilter([1, -0.97], [1], aud) 269 | return aud 270 | 271 | def deemphasize(self, aud): 272 | aud = signal.lfilter([1], [1, -0.97], aud) 273 | return aud 274 | 275 | def normalize(self, melspec): 276 | melspec = ((melspec - log1e5) / (-log1e5 / 2)) - 1 #0~2 --> -1~1 277 | return melspec 278 | 279 | def denormalize(self, melspec): 280 | melspec = ((melspec + 1) * (-log1e5 / 2)) + log1e5 281 | return melspec 282 | 283 | def normalize_spec(self, spec): 284 | spec = (spec - spec.min()) / (spec.max() - spec.min()) # 0 ~ 1 285 | return spec 286 | 287 | def denormalize_spec(self, spec): 288 | spec = spec * 14. # 0 ~ 14 289 | return spec 290 | 291 | def plot_spectrogram_to_numpy(self, mels): 292 | fig, ax = plt.subplots(figsize=(15, 4)) 293 | im = ax.imshow(np.squeeze(mels, 0), aspect="auto", origin="lower", 294 | interpolation='none') 295 | plt.colorbar(im, ax=ax) 296 | plt.xlabel("Frames") 297 | plt.ylabel("Channels") 298 | plt.tight_layout() 299 | 300 | fig.canvas.draw() 301 | data = self.save_figure_to_numpy(fig) 302 | plt.close() 303 | return data 304 | 305 | def save_figure_to_numpy(self, fig): 306 | # save it to a numpy array. 
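# Note: np.fromstring(..., sep='') on the next line is deprecated in recent NumPy releases; np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) is the drop-in equivalent if this starts warning or failing.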
307 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 308 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 309 | return data.transpose(2, 0, 1) 310 | 311 | class TacotronSTFT(torch.nn.Module): 312 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 313 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 314 | mel_fmax=8000.0): 315 | super(TacotronSTFT, self).__init__() 316 | self.n_mel_channels = n_mel_channels 317 | self.sampling_rate = sampling_rate 318 | self.stft_fn = STFT(filter_length, hop_length, win_length) 319 | mel_basis = librosa_mel_fn( 320 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 321 | mel_basis = torch.from_numpy(mel_basis).float() 322 | self.register_buffer('mel_basis', mel_basis) 323 | 324 | def spectral_normalize(self, magnitudes): 325 | output = dynamic_range_compression(magnitudes) 326 | return output 327 | 328 | def spectral_de_normalize(self, magnitudes): 329 | output = dynamic_range_decompression(magnitudes) 330 | return output 331 | 332 | def mel_spectrogram(self, y): 333 | """Computes mel-spectrograms from a batch of waves 334 | PARAMS 335 | ------ 336 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 337 | RETURNS 338 | ------- 339 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 340 | """ 341 | assert(torch.min(y.data) >= -1) 342 | assert(torch.max(y.data) <= 1) 343 | 344 | magnitudes, phases = self.stft_fn.transform(y) 345 | magnitudes = magnitudes.data 346 | mel_output = torch.matmul(self.mel_basis, magnitudes) 347 | mel_output = self.spectral_normalize(mel_output) 348 | return mel_output, magnitudes 349 | -------------------------------------------------------------------------------- /src/data/vid_aud_lrs2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchaudio 8 | from torchvision import transforms 9 | from torch.utils.data import DataLoader, Dataset 10 | from src.data.transforms import Crop, StatefulRandomHorizontalFlip 11 | from PIL import Image 12 | import librosa 13 | import cv2 14 | import matplotlib 15 | matplotlib.use('Agg') 16 | from matplotlib import pyplot as plt 17 | import glob 18 | from scipy import signal 19 | import torchvision 20 | from torch.autograd import Variable 21 | from librosa.filters import mel as librosa_mel_fn 22 | from src.data.audio_processing import dynamic_range_compression, dynamic_range_decompression, griffin_lim 23 | from src.data.stft import STFT 24 | import math 25 | log1e5 = math.log(1e-5) 26 | 27 | class MultiDataset(Dataset): 28 | def __init__(self, data, mode, max_v_timesteps=155, window_size=40, augmentations=False, num_mel_bins=80, fast_validate=False, f_min=55., f_max=7600.): 29 | assert mode in ['train', 'test', 'val'] 30 | self.data = data 31 | self.sample_window = True if (mode == 'train') else False 32 | self.fast_validate = fast_validate 33 | self.max_v_timesteps = window_size if self.sample_window else max_v_timesteps 34 | self.window_size = window_size 35 | self.augmentations = augmentations if mode == 'train' else False 36 | self.num_mel_bins = num_mel_bins 37 | self.file_paths, self.file_names, self.crops = self.build_file_list(data, mode) 38 | self.stft = TacotronSTFT(filter_length=640, hop_length=160, win_length=640, n_mel_channels=num_mel_bins, sampling_rate=16000, mel_fmin=f_min, mel_fmax=f_max) 39 
| 40 | def build_file_list(self, lrs2, mode): 41 | file_list, paths = [], [] 42 | crops = {} 43 | lrs2_mode = 'main' 44 | 45 | ## LRS2 crop (lip centered axis) load ## 46 | file = open(f"./data/LRS2/LRS2_crop/preprocess_{lrs2_mode}.txt", "r") 47 | content = file.read() 48 | file.close() 49 | for i, line in enumerate(content.splitlines()): 50 | split = line.split(".") 51 | file = split[0] 52 | crop_str = split[1][4:] 53 | crops[f'{lrs2_mode}/{file}'] = crop_str 54 | ## LRS2 file in crop ## 55 | file = open(f"./data/LRS2/{mode}.txt", "r") 56 | content = file.readlines() 57 | file.close() 58 | for file in content: 59 | file = file.strip().split()[0] 60 | file = f'{lrs2_mode}/{file}' 61 | if file in crops: 62 | file_list.append(file) 63 | paths.append(f"{lrs2}/{file}") 64 | if mode == 'train': 65 | file = open(f"./data/LRS2/LRS2_crop/preprocess_pretrain.txt", "r") 66 | content = file.read() 67 | file.close() 68 | for i, line in enumerate(content.splitlines()): 69 | split = line.split(".") 70 | file = split[0] 71 | crop_str = split[1][4:] 72 | crops[f'pretrain/{file}'] = crop_str 73 | ## LRS2 file in crop ## 74 | file = open(f"./data/LRS2/pretrain.txt", "r") 75 | content = file.readlines() 76 | file.close() 77 | for file in content: 78 | file = file.strip().split()[0] 79 | file = f'pretrain/{file}' 80 | if file in crops: 81 | file_list.append(file) 82 | paths.append(f"{lrs2}/{file}") 83 | 84 | print(f'Mode: {mode}, File Num: {len(file_list)}') 85 | return paths, file_list, crops 86 | 87 | def build_tensor(self, frames, crops): 88 | if self.augmentations: 89 | s = random.randint(-5, 5) 90 | else: 91 | s = 0 92 | crop = [] 93 | for i in range(0, len(crops), 2): 94 | left = int(crops[i]) - 40 + s 95 | upper = int(crops[i + 1]) - 40 + s 96 | right = int(crops[i]) + 40 + s 97 | bottom = int(crops[i + 1]) + 40 + s 98 | crop.append([left, upper, right, bottom]) 99 | crops = crop 100 | 101 | if self.augmentations: 102 | augmentations1 = transforms.Compose([StatefulRandomHorizontalFlip(0.5)]) 103 | else: 104 | augmentations1 = transforms.Compose([]) 105 | 106 | temporalVolume = torch.zeros(self.max_v_timesteps, 1, 112, 112) 107 | for i, frame in enumerate(frames): 108 | transform = transforms.Compose([ 109 | transforms.ToPILImage(), 110 | Crop(crops[i]), 111 | transforms.Resize([112, 112]), 112 | augmentations1, 113 | transforms.Grayscale(num_output_channels=1), 114 | transforms.ToTensor(), 115 | transforms.Normalize(0.4136, 0.1700), 116 | ]) 117 | temporalVolume[i] = transform(frame) 118 | 119 | temporalVolume = temporalVolume.transpose(1, 0) # (C, T, H, W) 120 | return temporalVolume 121 | 122 | def __len__(self): 123 | return len(self.file_paths) 124 | 125 | def __getitem__(self, idx): 126 | file = self.file_names[idx] 127 | file_path = self.file_paths[idx] 128 | crops = self.crops[file].split("/") 129 | 130 | info = {} 131 | info['video_fps'] = 25 132 | cap = cv2.VideoCapture(file_path + '.mp4') 133 | frames = [] 134 | while (cap.isOpened()): 135 | ret, frame = cap.read() 136 | if ret: 137 | frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 138 | else: 139 | break 140 | cap.release() 141 | audio, info['audio_fps'] = librosa.load(file_path.replace('LRS2-BBC', 'LRS2-BBC_audio') + '.wav', sr=16000) 142 | vid = torch.tensor(np.stack(frames, 0)) 143 | audio = torch.tensor(audio).unsqueeze(0) 144 | 145 | if not 'video_fps' in info: 146 | info['video_fps'] = 25 147 | info['audio_fps'] = 16000 148 | 149 | assert vid.size(0) > 5 or audio.size(1) > 5 150 | 151 | ## Audio ## 152 | audio = audio / 
torch.abs(audio).max() * 0.9 153 | aud = torch.FloatTensor(self.preemphasize(audio.squeeze(0))).unsqueeze(0) 154 | aud = torch.clamp(aud, min=-1, max=1) 155 | 156 | melspec, spec = self.stft.mel_spectrogram(aud) 157 | 158 | ## Video ## 159 | vid = vid.permute(0, 3, 1, 2) # T C H W 160 | 161 | if self.sample_window: 162 | vid, melspec, spec, audio, crops = self.extract_window(vid, melspec, spec, audio, info, crops) 163 | elif vid.size(0) > self.max_v_timesteps: 164 | print('Sample is longer than Max video frames! Trimming to the length of ', self.max_v_timesteps) 165 | vid = vid[:self.max_v_timesteps] 166 | melspec = melspec[:, :, :int(self.max_v_timesteps * info['audio_fps'] / info['video_fps'] / 160)] 167 | spec = spec[:, :, :int(self.max_v_timesteps * info['audio_fps'] / info['video_fps'] / 160)] 168 | audio = audio[:, :int(self.max_v_timesteps * info['audio_fps'] / info['video_fps'])] 169 | crops = crops[:self.max_v_timesteps * 2] 170 | 171 | num_v_frames = vid.size(0) 172 | vid = self.build_tensor(vid, crops) 173 | 174 | melspec = self.normalize(melspec) #0~2 --> -1~1 175 | 176 | spec = self.normalize_spec(spec) # 0 ~ 1 177 | spec = self.stft.spectral_normalize(spec) # log(1e-5) ~ 0 # in log scale 178 | spec = self.normalize(spec) # -1 ~ 1 179 | 180 | num_a_frames = melspec.size(2) 181 | melspec = nn.ConstantPad2d((0, self.max_v_timesteps * 4 - num_a_frames, 0, 0), -1.0)(melspec) 182 | spec = nn.ConstantPad2d((0, self.max_v_timesteps * 4 - num_a_frames, 0, 0), -1.0)(spec) 183 | 184 | return melspec, spec, vid, num_v_frames, audio.squeeze(0), num_a_frames, file_path.replace(self.data, '')[1:] 185 | 186 | def extract_window(self, vid, mel, spec, aud, info, crops): 187 | # vid : T,C,H,W 188 | st_fr = random.randint(0, max(0, vid.size(0) - self.window_size)) 189 | vid = vid[st_fr:st_fr + self.window_size] 190 | crops = crops[st_fr * 2: st_fr * 2 + self.window_size * 2] 191 | 192 | assert vid.size(0) * 2 == len(crops), f'vid length: {vid.size(0)}, crop length: {len(crops)}' 193 | 194 | st_mel_fr = int(st_fr * info['audio_fps'] / info['video_fps'] / 160) 195 | mel_window_size = int(self.window_size * info['audio_fps'] / info['video_fps'] / 160) 196 | mel = mel[:, :, st_mel_fr:st_mel_fr + mel_window_size] 197 | spec = spec[:, :, st_mel_fr:st_mel_fr + mel_window_size] 198 | aud = aud[:, st_mel_fr*160:st_mel_fr*160 + mel_window_size*160] 199 | aud = torch.cat([aud, torch.zeros([1, int(self.window_size / info['video_fps'] * info['audio_fps'] - aud.size(1))])], 1) 200 | 201 | return vid, mel, spec, aud, crops 202 | 203 | def collate_fn(self, batch): 204 | vid_lengths, spec_lengths, padded_spec_lengths, aud_lengths = [], [], [], [] 205 | for data in batch: 206 | vid_lengths.append(data[3]) 207 | spec_lengths.append(data[5]) 208 | padded_spec_lengths.append(data[0].size(2)) 209 | aud_lengths.append(data[4].size(0)) 210 | 211 | max_aud_length = max(aud_lengths) 212 | max_spec_length = max(padded_spec_lengths) 213 | padded_vid = [] 214 | padded_melspec = [] 215 | padded_spec = [] 216 | padded_audio = [] 217 | f_names = [] 218 | 219 | for i, (melspec, spec, vid, num_v_frames, audio, spec_len, f_name) in enumerate(batch): 220 | padded_vid.append(vid) # B, C, T, H, W 221 | padded_melspec.append(nn.ConstantPad2d((0, max_spec_length - melspec.size(2), 0, 0), -1.0)(melspec)) 222 | padded_spec.append(nn.ConstantPad2d((0, max_spec_length - spec.size(2), 0, 0), -1.0)(spec)) 223 | padded_audio.append(torch.cat([audio, torch.zeros([max_aud_length - audio.size(0)])], 0)) 224 | f_names.append(f_name) 225 | 226 | 
vid = torch.stack(padded_vid, 0).float() 227 | vid_length = torch.IntTensor(vid_lengths) 228 | melspec = torch.stack(padded_melspec, 0).float() 229 | spec = torch.stack(padded_spec, 0).float() 230 | spec_length = torch.IntTensor(spec_lengths) 231 | audio = torch.stack(padded_audio, 0).float() 232 | 233 | return melspec, spec, vid, vid_length, audio, spec_length, f_names 234 | 235 | def inverse_mel(self, mel, stft): 236 | if len(mel.size()) < 4: 237 | mel = mel.unsqueeze(0) #B,1,80,T 238 | 239 | mel = self.denormalize(mel) 240 | mel = stft.spectral_de_normalize(mel) 241 | mel = mel.transpose(2, 3).contiguous() #B,80,T --> B,T,80 242 | spec_from_mel_scaling = 1000 243 | spec_from_mel = torch.matmul(mel, stft.mel_basis) 244 | spec_from_mel = spec_from_mel.transpose(2, 3).squeeze(1) # B,1,F,T 245 | spec_from_mel = spec_from_mel * spec_from_mel_scaling 246 | 247 | wav = griffin_lim(spec_from_mel, stft.stft_fn, 60).squeeze(1) #B,L 248 | wav = wav.cpu().numpy() if wav.is_cuda else wav.numpy() 249 | wavs = [] 250 | for w in wav: 251 | w = self.deemphasize(w) 252 | wavs += [w] 253 | wavs = np.array(wavs) 254 | wavs = np.clip(wavs, -1, 1) 255 | return wavs 256 | 257 | def inverse_spec(self, spec, stft): 258 | if len(spec.size()) < 4: 259 | spec = spec.unsqueeze(0) #B,1,321,T 260 | 261 | spec = self.denormalize(spec) # log1e5 ~ 0 262 | spec = stft.spectral_de_normalize(spec) # 0 ~ 1 263 | spec = self.denormalize_spec(spec) # 0 ~ 14 264 | wav = griffin_lim(spec.squeeze(1), stft.stft_fn, 60).squeeze(1) #B,L 265 | wav = wav.cpu().numpy() if wav.is_cuda else wav.numpy() 266 | wavs = [] 267 | for w in wav: 268 | w = self.deemphasize(w) 269 | wavs += [w] 270 | wavs = np.array(wavs) 271 | wavs = np.clip(wavs, -1, 1) 272 | return wavs 273 | 274 | def preemphasize(self, aud): 275 | aud = signal.lfilter([1, -0.97], [1], aud) 276 | return aud 277 | 278 | def deemphasize(self, aud): 279 | aud = signal.lfilter([1], [1, -0.97], aud) 280 | return aud 281 | 282 | def normalize(self, melspec): 283 | melspec = ((melspec - log1e5) / (-log1e5 / 2)) - 1 #0~2 --> -1~1 284 | return melspec 285 | 286 | def denormalize(self, melspec): 287 | melspec = ((melspec + 1) * (-log1e5 / 2)) + log1e5 288 | return melspec 289 | 290 | def normalize_spec(self, spec): 291 | spec = (spec - spec.min()) / (spec.max() - spec.min()) # 0 ~ 1 292 | return spec 293 | 294 | def denormalize_spec(self, spec): 295 | spec = spec * 14. # 0 ~ 14 296 | return spec 297 | 298 | def plot_spectrogram_to_numpy(self, mels): 299 | fig, ax = plt.subplots(figsize=(15, 4)) 300 | im = ax.imshow(np.squeeze(mels, 0), aspect="auto", origin="lower", 301 | interpolation='none') 302 | plt.colorbar(im, ax=ax) 303 | plt.xlabel("Frames") 304 | plt.ylabel("Channels") 305 | plt.tight_layout() 306 | 307 | fig.canvas.draw() 308 | data = self.save_figure_to_numpy(fig) 309 | plt.close() 310 | return data 311 | 312 | def save_figure_to_numpy(self, fig): 313 | # save it to a numpy array. 
314 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 315 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 316 | return data.transpose(2, 0, 1) 317 | 318 | class TacotronSTFT(torch.nn.Module): 319 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 320 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 321 | mel_fmax=8000.0): 322 | super(TacotronSTFT, self).__init__() 323 | self.n_mel_channels = n_mel_channels 324 | self.sampling_rate = sampling_rate 325 | self.stft_fn = STFT(filter_length, hop_length, win_length) 326 | mel_basis = librosa_mel_fn( 327 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 328 | mel_basis = torch.from_numpy(mel_basis).float() 329 | self.register_buffer('mel_basis', mel_basis) 330 | 331 | def spectral_normalize(self, magnitudes): 332 | output = dynamic_range_compression(magnitudes) 333 | return output 334 | 335 | def spectral_de_normalize(self, magnitudes): 336 | output = dynamic_range_decompression(magnitudes) 337 | return output 338 | 339 | def mel_spectrogram(self, y): 340 | """Computes mel-spectrograms from a batch of waves 341 | PARAMS 342 | ------ 343 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 344 | RETURNS 345 | ------- 346 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 347 | """ 348 | assert(torch.min(y.data) >= -1) 349 | assert(torch.max(y.data) <= 1) 350 | 351 | magnitudes, phases = self.stft_fn.transform(y) 352 | magnitudes = magnitudes.data 353 | mel_output = torch.matmul(self.mel_basis, magnitudes) 354 | mel_output = self.spectral_normalize(mel_output) 355 | return mel_output, magnitudes 356 | -------------------------------------------------------------------------------- /preprocess/Ref_face.txt: -------------------------------------------------------------------------------- 1 | ABOUT/test/ABOUT_00001.mp4:71 98,71 106,71 115,72 124,72 132,73 141,74 149,76 158,79 166,82 174,87 181,92 188,97 194,103 199,111 202,118 205,126 207,134 206,142 204,150 201,157 196,163 190,169 184,174 177,178 169,181 161,184 153,185 144,187 135,188 126,189 117,190 108,190 99,81 82,89 74,97 73,106 73,113 76,113 81,105 80,97 80,89 80,144 75,153 71,161 71,171 72,178 81,170 78,162 77,153 79,144 80,129 94,129 106,128 116,129 127,114 133,121 135,129 136,135 135,143 134,89 95,94 93,99 92,106 93,111 96,105 97,100 98,94 97,146 96,152 93,159 91,164 93,169 95,164 98,158 98,152 97,103 155,113 151,124 149,129 150,133 151,145 153,154 157,147 165,137 169,127 170,118 169,110 163,105 156,117 154,129 155,141 154,152 157,141 162,128 164,116 161,100 95,158 95|71 96,71 104,71 114,71 122,72 131,73 140,74 148,75 157,78 166,82 173,86 181,91 188,96 193,103 199,111 202,118 205,126 207,134 206,142 204,149 200,157 195,163 189,168 183,174 176,178 168,181 160,183 152,185 143,186 134,188 125,188 116,189 107,189 98,81 82,89 74,97 72,105 73,113 75,113 81,105 80,96 80,89 80,144 75,153 73,161 73,170 73,177 82,169 79,161 78,152 80,144 81,128 94,129 106,128 116,129 127,114 133,121 135,128 136,135 135,143 134,89 95,94 93,99 91,106 92,111 96,106 97,100 98,94 97,145 96,151 93,158 91,164 93,168 96,163 98,158 99,152 98,102 155,112 151,123 149,129 151,133 151,144 152,154 157,146 164,137 168,127 169,117 168,109 162,105 155,117 154,128 155,140 154,152 157,140 161,127 163,116 160,100 95,157 95|70 96,70 104,70 114,70 122,71 131,72 140,73 148,75 157,78 166,81 174,86 181,91 188,96 194,103 199,110 203,117 205,126 207,134 206,142 204,149 200,156 195,163 190,168 184,173 177,177 169,180 160,183 
152,184 144,186 135,188 125,188 117,189 108,189 99,80 79,88 71,96 70,105 71,113 74,113 80,104 78,96 77,88 77,144 74,153 71,161 71,170 72,177 81,169 78,161 76,153 78,144 80,128 94,128 106,128 117,128 128,113 134,121 136,128 137,135 136,142 135,89 94,94 92,99 91,106 92,111 96,105 97,99 97,94 97,145 96,151 94,158 92,164 94,168 96,163 98,158 99,152 98,102 155,112 152,123 149,128 151,133 151,144 153,154 157,146 165,137 169,127 171,117 169,109 163,104 156,116 154,128 156,140 155,151 157,140 162,127 164,115 161,100 95,157 95|69 96,69 104,69 114,70 122,70 131,71 140,72 148,74 157,77 166,80 174,85 181,90 188,96 194,102 199,110 203,117 205,126 207,133 207,141 204,149 200,156 195,162 190,167 183,173 177,177 168,180 160,183 152,184 143,186 134,187 125,188 116,189 107,189 98,79 81,87 72,96 71,105 72,113 75,112 81,104 79,96 78,88 79,144 75,153 72,161 72,170 73,178 82,170 79,161 78,152 79,143 81,128 95,128 106,128 118,128 129,113 134,120 136,127 137,134 136,142 135,89 95,93 93,99 91,106 92,111 96,105 98,99 98,93 97,144 97,151 94,157 92,163 94,168 96,163 99,157 99,151 98,102 155,112 152,123 150,128 151,133 151,143 153,153 157,146 164,137 169,126 170,117 168,109 163,104 156,116 155,128 156,139 155,151 157,140 162,127 164,115 161,100 95,157 95|71 97,71 105,71 115,71 123,72 132,73 140,74 148,76 157,79 166,82 173,87 181,92 187,97 193,104 199,111 202,118 204,127 206,135 206,142 203,149 200,157 195,163 189,168 183,174 177,178 169,181 161,184 152,185 144,187 135,188 126,189 118,190 108,190 99,81 82,89 73,97 72,106 73,114 75,114 81,105 80,97 79,89 80,145 75,154 72,162 72,171 73,179 82,171 79,162 78,154 79,145 81,129 95,129 107,129 118,129 129,114 134,121 136,128 137,136 136,143 135,90 96,95 93,100 92,107 93,112 97,106 98,100 99,95 98,146 97,152 94,159 92,165 94,169 97,164 99,159 100,152 98,102 155,113 152,124 150,129 151,134 151,145 153,155 157,147 164,138 168,127 169,118 168,110 162,105 155,117 154,129 156,141 155,153 156,141 162,128 163,116 160,101 95,158 96|70 96,70 104,70 114,71 122,72 131,73 139,73 147,75 156,78 165,81 172,86 180,91 186,96 192,103 197,111 201,118 203,126 205,134 205,142 202,150 199,157 194,163 188,169 182,174 176,178 167,181 159,184 151,185 143,187 134,188 125,188 116,190 107,190 98,80 84,88 75,97 73,106 74,114 76,114 82,105 81,96 81,88 81,145 76,154 73,163 73,172 74,179 84,171 80,163 78,154 80,144 82,129 96,129 108,129 119,129 130,114 135,121 137,128 138,136 137,143 136,90 97,95 94,100 93,107 94,113 98,106 99,101 100,95 99,145 98,151 94,158 93,164 94,169 97,164 100,158 101,152 99,103 155,113 153,124 151,129 152,134 153,144 154,154 157,146 164,137 167,127 169,118 167,110 162,105 156,117 156,129 157,141 156,152 157,140 161,128 162,116 160,101 96,158 96|70 96,70 104,70 113,71 122,71 130,72 139,73 147,74 156,77 165,81 172,85 179,91 186,96 192,103 197,111 200,118 203,127 203,135 204,143 201,150 198,157 193,163 188,169 182,174 176,178 168,181 160,184 152,185 143,187 134,188 125,189 117,190 108,190 99,79 81,87 73,96 72,105 73,113 76,113 82,104 80,96 79,87 79,145 76,154 73,162 73,172 74,179 84,171 80,162 79,154 80,144 82,129 97,129 109,129 120,129 131,114 137,121 139,129 140,136 139,143 138,88 97,93 94,99 93,106 94,112 99,106 100,99 101,93 100,145 100,151 96,159 94,165 96,169 99,164 102,159 103,152 101,103 157,113 154,124 152,129 154,134 154,144 155,154 158,146 164,137 166,128 168,119 166,110 162,105 157,117 157,129 159,140 158,152 158,140 161,128 161,117 159,100 97,158 98|69 100,69 108,69 117,70 125,70 134,71 142,72 150,73 159,76 167,80 174,85 181,90 187,96 193,102 198,110 200,118 203,126 204,134 
204,142 201,149 198,156 194,162 189,168 183,173 177,177 169,180 161,183 154,184 146,186 137,187 128,188 120,189 111,189 103,79 83,87 75,95 74,104 76,112 79,112 84,104 82,95 81,87 81,144 79,153 76,161 75,170 76,178 85,170 82,161 81,152 83,143 85,128 98,127 110,127 122,128 133,113 138,120 140,127 141,134 140,142 139,87 98,92 95,98 94,105 95,110 100,104 101,98 102,92 100,145 101,151 97,158 95,164 97,168 100,163 103,158 103,151 102,102 157,112 155,123 153,128 154,133 154,143 155,153 159,145 164,136 167,126 168,117 167,109 163,104 158,116 158,127 159,139 158,151 159,139 161,127 162,115 160,99 98,157 99|69 100,69 108,69 117,70 125,70 134,71 142,72 150,73 159,76 167,80 174,85 181,90 187,96 193,103 197,110 200,118 203,126 203,134 204,142 201,148 198,156 194,162 189,168 183,173 177,177 169,180 162,183 154,184 146,186 137,187 128,188 121,188 111,189 103,79 82,87 74,95 73,104 75,112 78,112 84,104 82,95 81,87 81,144 79,153 75,161 75,170 75,177 85,169 81,161 80,152 82,143 84,128 98,128 110,128 121,128 132,113 137,121 139,128 140,135 139,142 138,88 97,93 95,98 94,106 95,110 99,104 100,98 101,93 100,145 100,150 97,157 95,163 97,168 99,163 102,157 103,151 101,102 156,113 153,123 151,128 153,133 153,143 154,152 158,145 164,136 166,127 168,118 166,110 162,105 156,117 156,128 158,139 157,150 158,139 161,127 161,116 159,99 97,157 98|70 100,69 108,70 117,70 125,71 133,71 142,72 149,74 159,77 167,80 174,85 181,90 187,96 192,103 197,110 200,118 202,126 203,134 204,142 201,148 198,156 194,162 189,168 183,173 177,177 169,181 161,183 154,184 146,186 137,188 128,188 120,189 111,189 102,79 82,87 74,95 73,104 75,112 78,112 83,104 81,95 80,87 80,144 79,153 75,161 75,170 75,177 84,169 81,161 80,152 82,143 84,128 97,128 109,128 120,129 131,114 136,121 138,128 139,134 138,142 137,88 97,93 95,99 94,106 95,111 99,105 100,99 101,93 99,144 99,150 96,157 95,163 96,167 99,162 101,157 102,151 101,103 156,113 153,123 151,128 152,133 152,143 154,152 158,144 164,136 167,127 168,118 167,110 163,105 156,117 156,128 157,139 156,149 158,139 161,127 162,116 160,99 97,156 98|69 100,69 108,70 117,71 125,71 133,72 142,73 149,74 159,77 167,80 174,85 181,91 187,96 192,103 197,111 200,118 202,126 203,134 203,142 201,149 198,156 193,163 188,168 182,173 176,178 168,181 160,183 153,185 145,186 136,188 127,188 119,189 110,189 102,79 82,87 74,96 73,104 75,112 77,112 83,104 81,95 80,87 80,144 78,153 75,161 74,170 75,177 84,169 80,161 80,152 82,143 84,128 96,128 108,128 119,128 130,114 135,121 137,128 138,134 137,141 136,89 96,93 93,99 92,106 94,111 98,105 99,99 100,93 98,144 98,150 95,157 93,163 95,167 98,162 100,157 101,151 100,102 156,113 153,123 151,128 152,133 152,143 153,152 157,144 164,136 167,127 168,118 167,110 163,105 156,116 155,128 157,139 156,150 157,139 161,127 162,116 160,100 96,156 97|69 97,69 105,70 114,70 123,71 132,72 140,73 148,74 157,77 166,80 173,85 180,91 187,96 193,103 198,111 201,118 203,127 205,135 205,143 202,149 199,157 193,163 188,168 182,174 176,178 167,181 159,183 152,185 143,186 134,188 126,188 117,189 108,189 99,79 83,87 75,96 74,104 75,112 78,112 84,104 82,96 80,87 80,144 79,152 75,161 75,169 75,177 84,169 81,161 80,152 82,143 84,128 96,128 107,128 118,129 129,114 135,121 137,128 138,135 137,142 136,88 96,93 93,99 92,106 93,111 97,105 99,99 99,93 98,145 98,151 94,158 93,163 94,168 97,163 100,158 101,151 99,103 156,113 153,123 151,128 152,133 152,143 154,152 158,145 164,136 167,127 168,118 167,110 163,105 157,117 156,128 157,139 156,150 158,139 161,127 162,116 160,100 96,157 96|70 98,69 106,70 116,71 124,71 133,72 
141,73 149,74 159,77 167,81 174,85 181,91 188,96 193,102 199,110 202,118 204,126 205,134 205,142 202,149 199,157 194,163 189,168 183,173 177,177 169,181 160,183 153,184 145,186 136,188 127,188 119,189 110,189 100,80 83,88 76,96 76,104 77,112 81,113 86,104 83,96 82,88 82,144 81,152 78,160 77,169 77,176 86,168 83,160 82,152 84,143 87,128 96,128 108,128 119,128 130,114 136,121 138,127 139,134 138,141 137,89 96,94 93,100 92,106 94,111 98,105 99,100 100,94 98,144 98,150 95,157 93,163 95,167 98,162 100,157 101,151 100,102 157,112 155,123 153,128 154,133 154,142 156,152 158,144 163,135 165,127 166,118 165,110 162,105 157,116 158,128 159,139 159,150 159,139 160,127 160,115 159,100 96,157 96|70 98,70 107,70 116,71 124,71 133,72 141,73 149,75 159,78 167,81 174,86 182,91 188,96 194,103 199,111 202,118 205,127 206,134 206,143 203,149 200,157 195,163 190,168 184,173 177,177 169,181 161,183 153,184 145,186 136,188 128,188 119,189 110,189 100,80 85,88 78,97 78,105 79,113 82,113 88,105 85,97 84,89 84,143 83,152 79,159 79,168 79,176 87,168 85,160 84,151 86,143 88,128 97,128 109,128 120,128 131,114 136,121 138,128 139,134 138,141 137,89 97,94 94,100 93,106 94,111 99,106 100,100 100,94 99,145 99,151 96,157 94,163 96,167 99,162 101,157 102,151 101,102 157,113 156,123 154,128 155,133 155,143 157,152 159,144 164,135 165,127 166,118 165,110 162,105 158,117 159,128 160,139 160,150 160,139 160,127 159,116 158,100 96,157 97|72 99,71 107,72 116,72 125,73 133,74 141,75 149,77 158,80 166,83 174,87 181,92 188,98 194,104 200,111 203,118 205,127 207,135 207,143 204,150 201,158 196,164 191,169 185,174 178,178 171,181 163,184 155,185 146,187 138,189 130,189 120,190 112,190 102,82 84,90 76,98 76,107 78,114 81,114 86,106 84,98 83,90 82,144 82,153 78,161 78,169 78,177 86,169 84,161 83,153 85,144 87,129 96,129 108,129 118,129 129,115 135,122 136,129 137,135 137,142 136,90 95,96 92,101 92,108 93,113 97,107 98,101 99,95 98,146 98,152 94,159 93,165 95,169 98,164 100,159 101,153 100,104 156,114 154,124 152,129 153,134 153,144 155,153 159,145 165,136 167,128 168,119 166,111 162,106 157,118 156,129 158,140 158,151 159,140 161,128 162,117 160,102 95,159 96|71 99,70 107,71 117,72 125,72 134,73 142,75 150,77 160,80 168,83 175,88 183,94 189,99 195,105 201,113 205,121 206,129 208,137 207,145 204,151 200,159 195,165 189,169 183,174 176,178 168,181 160,184 152,185 144,186 135,188 127,189 118,189 109,189 100,82 83,90 75,98 75,107 77,115 80,115 86,106 83,98 82,90 82,144 80,152 77,160 76,168 76,176 85,168 82,160 82,152 84,143 86,129 96,129 107,129 118,130 128,115 134,122 135,129 137,136 135,143 135,90 96,95 94,101 93,107 94,112 98,107 98,101 99,95 98,146 97,152 95,158 93,164 95,168 98,163 99,159 100,152 99,103 156,114 152,125 150,130 151,134 151,145 153,154 157,147 165,138 169,129 170,119 169,110 164,106 156,118 154,130 156,141 155,152 157,141 163,129 164,116 162,101 96,158 96|71 94,70 103,70 113,71 122,72 131,73 140,74 148,76 158,79 166,82 175,87 182,92 189,98 196,105 201,113 206,120 208,129 209,137 208,146 205,152 202,160 196,165 190,171 182,176 176,180 167,183 159,185 151,187 142,188 133,190 124,191 115,190 106,191 97,80 81,88 74,97 74,106 77,115 80,115 86,106 83,97 81,88 80,144 81,153 77,161 75,170 76,179 83,170 81,161 81,153 84,143 86,130 94,130 106,130 117,130 127,116 132,122 135,129 136,136 134,143 133,90 94,96 93,101 93,107 93,112 95,106 96,101 96,95 95,147 95,154 95,160 94,165 95,170 96,165 97,160 97,154 96,103 156,114 151,125 149,130 151,135 150,145 153,154 158,147 165,139 170,129 172,118 170,110 165,105 156,118 154,130 155,141 
155,152 158,141 164,129 164,116 163,102 95,160 95|69 94,68 103,69 112,70 121,70 130,71 139,73 147,75 158,77 166,81 175,85 182,91 190,96 196,103 202,111 206,119 208,127 209,136 209,144 206,150 202,158 196,163 190,168 183,173 176,178 168,181 160,184 152,185 143,187 134,189 125,189 116,189 107,190 98,78 79,87 72,96 73,105 76,114 79,114 85,105 82,96 80,87 78,141 81,151 77,159 75,168 76,177 84,168 81,159 81,150 84,141 86,128 94,128 105,128 116,129 126,114 131,121 134,128 135,135 134,142 133,89 93,94 92,100 92,106 93,111 94,105 95,99 95,94 94,146 96,152 95,159 95,164 96,168 97,163 97,159 98,152 97,101 155,112 151,123 149,128 150,133 150,144 152,153 158,146 165,137 171,127 173,117 170,108 164,104 155,116 153,128 154,140 154,151 158,140 164,127 165,115 163,100 94,159 96|69 95,69 103,69 113,70 122,70 131,71 140,73 148,75 158,77 166,81 175,86 182,91 190,96 196,103 201,111 205,118 208,127 209,135 208,143 206,150 202,157 196,163 190,168 183,173 176,178 168,181 160,184 152,185 144,187 134,188 126,189 117,189 108,190 98,79 80,88 73,97 73,106 76,114 79,114 84,105 82,97 80,88 79,142 80,151 76,160 75,169 76,177 84,168 81,160 81,151 83,142 85,128 93,128 105,128 116,129 126,114 131,121 134,128 135,135 134,142 133,89 93,94 92,100 92,106 92,111 94,105 94,100 95,94 94,146 95,153 94,159 94,164 95,168 96,163 97,158 97,153 96,101 155,112 150,123 148,128 150,133 150,144 152,153 157,146 165,137 171,127 173,116 170,108 164,104 156,116 153,128 154,140 154,151 157,140 164,127 165,114 163,100 93,159 96|71 96,70 104,71 114,71 123,72 132,73 140,74 148,76 158,79 167,82 175,87 182,92 190,98 196,104 201,112 205,120 207,128 208,137 208,145 206,151 202,159 196,164 190,170 183,175 177,179 168,182 161,185 152,186 144,188 135,190 126,190 118,191 108,191 99,80 80,89 73,98 72,107 75,115 78,115 84,106 81,97 80,89 79,144 79,153 76,161 75,171 75,179 83,170 81,161 80,153 83,143 85,129 93,130 105,129 116,130 127,115 131,122 134,129 135,136 134,144 133,90 92,95 91,100 90,106 91,112 94,106 94,100 94,95 93,147 95,154 94,160 93,166 94,170 96,165 97,160 97,154 96,102 155,113 150,124 148,130 150,134 150,146 152,155 157,147 165,139 170,128 172,118 170,109 164,105 156,117 153,129 154,141 154,152 157,141 164,129 165,116 163,101 93,160 95|71 99,70 107,71 117,72 125,72 134,73 143,74 150,76 160,79 169,82 176,87 183,92 190,98 196,104 201,112 205,120 207,128 208,136 207,144 205,151 201,158 196,164 190,169 184,174 177,179 169,182 161,184 154,186 145,187 136,189 127,190 119,190 110,190 101,81 82,89 74,97 73,106 76,114 78,114 84,105 82,97 80,89 80,145 80,154 76,162 75,170 76,178 84,170 82,162 81,153 83,145 85,129 95,129 106,129 117,130 127,115 133,122 135,129 136,136 135,143 134,90 94,95 93,100 93,106 93,111 96,105 96,100 97,94 96,147 96,153 95,159 94,165 95,169 97,164 98,159 98,153 97,102 156,113 152,124 149,129 151,134 151,145 153,154 157,147 165,138 170,128 172,118 170,110 164,105 156,117 154,129 155,141 155,152 157,141 164,128 165,116 163,100 95,159 96|70 101,70 109,71 119,71 127,71 136,72 144,74 152,75 161,78 170,82 177,87 184,92 191,98 196,104 201,113 204,120 206,128 207,136 206,144 203,151 200,158 195,164 189,169 183,174 177,178 169,181 161,184 153,185 145,186 136,188 127,188 119,188 110,189 102,81 83,89 76,97 75,106 77,114 80,114 86,105 83,97 82,89 82,144 80,153 77,160 76,169 77,176 85,168 82,160 82,152 84,144 86,128 96,129 107,128 118,129 127,115 134,121 135,128 136,135 135,142 135,90 97,94 95,100 94,106 95,111 98,105 99,100 99,94 98,146 97,152 95,158 94,163 96,167 98,162 99,158 100,152 99,103 157,113 152,124 150,129 151,133 151,144 153,153 
158,146 164,137 168,128 170,119 168,110 164,106 157,117 154,129 155,140 155,151 158,140 162,128 163,116 162,100 97,157 97|72 102,71 110,72 119,72 128,72 136,73 145,74 152,76 162,79 170,82 177,87 184,92 190,98 196,104 201,112 203,120 205,128 206,136 206,144 203,151 200,158 196,164 191,169 185,175 179,179 171,182 163,184 155,185 147,187 138,188 129,189 121,189 112,189 104,82 86,90 78,99 77,107 79,115 82,115 88,107 86,98 84,90 84,144 83,153 79,161 78,169 79,177 87,169 85,160 84,152 86,144 88,129 98,129 109,129 120,129 130,115 136,122 138,129 139,136 138,142 137,91 98,96 96,101 95,107 96,112 99,106 100,101 101,95 100,146 99,152 97,158 95,164 97,168 99,163 101,158 101,152 100,103 157,114 154,124 152,129 153,134 152,144 154,153 159,146 164,137 166,128 167,119 166,111 163,106 158,118 156,129 157,140 157,152 159,140 161,128 161,117 160,101 98,158 98|71 105,70 112,71 122,71 129,71 138,72 146,73 154,75 163,78 171,81 178,86 185,92 190,97 196,104 200,112 202,119 204,127 205,135 205,143 203,150 199,157 195,163 191,168 185,174 179,178 171,181 164,184 156,185 148,186 140,188 131,188 123,188 115,188 106,82 86,90 78,98 78,107 80,114 84,114 89,106 86,98 85,90 84,144 84,152 80,160 79,168 80,175 88,168 85,160 85,152 87,144 89,129 100,129 111,129 121,129 132,115 137,122 139,128 140,135 139,142 138,90 99,94 97,100 96,106 97,111 101,105 102,100 102,94 101,146 101,151 98,158 96,164 98,167 101,163 102,158 103,152 102,104 158,114 154,124 152,129 154,133 153,143 155,152 159,144 164,136 166,127 167,119 166,111 163,106 158,118 157,129 158,139 158,150 159,139 161,128 161,117 160,100 99,157 99|70 105,70 112,70 122,71 130,71 138,72 146,73 154,75 163,78 171,81 178,86 184,92 190,97 195,104 200,112 202,119 203,127 204,135 204,143 202,150 199,157 195,164 190,168 185,174 179,178 171,181 164,184 156,185 148,186 139,188 131,188 123,189 114,189 106,82 86,89 79,97 78,106 80,113 83,113 88,105 86,97 85,89 85,145 83,153 80,160 79,169 80,176 88,168 85,160 84,153 86,144 88,129 100,129 111,129 121,129 131,115 137,122 139,128 140,135 139,142 138,90 100,94 97,100 96,106 97,111 101,106 102,100 103,94 102,146 101,151 98,158 96,164 98,168 101,163 102,158 103,152 102,104 158,114 154,124 152,129 153,134 153,144 155,152 160,144 164,136 166,128 167,119 167,111 163,107 159,118 157,129 158,140 157,150 159,139 161,128 161,117 160,100 99,157 99|70 104,70 112,70 121,71 129,71 138,72 146,73 154,75 163,78 171,81 177,86 184,91 190,97 195,104 200,112 202,119 204,127 205,135 204,143 202,150 199,157 195,163 191,168 185,174 179,178 171,181 164,183 156,185 148,186 139,187 131,188 123,188 114,188 105,82 86,89 79,97 78,106 79,113 83,113 88,105 86,97 85,89 85,144 83,152 79,160 79,168 79,175 88,168 85,160 84,152 86,144 88,129 99,129 110,129 121,129 131,115 137,122 139,128 140,136 139,142 138,89 99,94 96,100 95,106 96,111 100,106 101,100 102,94 101,146 100,151 97,158 96,164 97,168 100,163 102,158 103,152 102,104 158,114 154,124 152,129 153,134 152,144 155,152 160,144 164,136 166,128 167,119 166,111 163,107 159,118 157,129 158,139 157,150 159,139 161,128 161,117 160,100 98,157 99|70 105,70 113,71 122,71 130,72 139,72 147,74 155,75 164,78 172,82 179,87 185,92 191,98 196,104 201,112 203,119 205,128 205,135 205,143 203,150 200,157 196,163 192,168 186,174 180,177 172,180 165,183 157,185 149,186 140,187 132,188 124,188 115,188 107,82 87,90 79,98 78,106 80,114 83,114 88,106 86,98 85,90 85,145 83,153 80,160 79,169 79,176 88,168 85,160 84,153 86,145 88,129 99,129 111,129 122,130 132,115 138,122 140,129 141,136 140,142 139,90 99,95 97,101 96,107 97,112 101,106 102,101 
103,95 102,146 101,152 98,158 96,164 98,168 101,163 102,158 103,152 102,104 159,114 155,124 153,130 154,134 154,144 156,152 160,144 164,136 166,128 167,120 166,112 163,107 159,118 158,129 159,140 158,150 160,139 161,128 161,117 160,101 99,158 99|70 101,70 109,71 119,71 127,72 136,72 144,74 152,75 162,78 170,82 177,86 184,92 190,98 196,104 201,112 204,120 205,128 206,136 206,144 204,150 200,157 196,164 191,168 185,174 178,178 170,181 163,183 155,185 146,186 137,188 129,188 120,189 112,189 103,82 84,90 77,98 76,106 78,114 82,114 87,106 85,98 83,90 83,144 82,152 79,160 78,168 78,176 87,168 84,160 83,152 86,144 88,129 97,129 109,129 120,130 130,116 136,122 138,129 139,136 138,142 137,90 97,95 95,101 93,107 94,112 99,106 100,101 101,95 99,146 99,151 96,158 94,164 96,168 99,163 100,158 101,152 100,105 158,115 154,125 152,130 153,134 153,143 155,151 159,144 164,136 166,128 167,120 166,112 163,107 158,118 157,129 158,139 157,150 159,139 161,128 160,118 160,101 97,157 97|70 98,70 106,70 116,71 124,71 133,73 142,74 150,75 159,78 168,82 176,86 183,92 190,97 196,104 201,112 205,119 206,128 208,136 207,144 204,150 201,158 196,164 190,169 184,174 177,178 169,181 161,184 153,185 145,187 135,188 127,189 118,189 109,189 99,82 82,89 74,98 74,106 75,114 79,114 85,106 82,97 81,90 81,145 79,153 76,161 75,169 76,177 84,169 82,161 81,153 83,145 85,129 96,129 107,129 117,130 128,116 134,122 136,129 137,136 136,143 135,90 96,94 93,100 92,107 93,112 97,106 98,100 99,95 98,146 97,152 94,158 92,164 94,168 97,163 99,159 100,152 99,105 158,114 153,125 151,129 152,134 151,144 154,151 159,144 165,137 168,128 169,119 168,112 164,108 158,118 156,129 156,139 156,150 159,139 162,128 163,118 161,100 95,158 96 --------------------------------------------------------------------------------
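Note on preprocess/Ref_face.txt (editorial addition, not a file in the repository): each line appears to store one record per clip in the form `<video path>:<frame 1>|<frame 2>|...`, where every frame is a comma-separated list of integer "x y" facial-landmark coordinates. The following is a minimal, hypothetical parsing sketch under that assumption; the function and script names are illustrative only.

    # parse_ref_face.py -- illustrative helper, not part of the original repo.
    # Assumes each line of preprocess/Ref_face.txt has the form
    #   <video path>:<frame 1>|<frame 2>|...
    # with every frame a comma-separated list of "x y" landmark pairs.

    def parse_ref_face_line(line):
        """Return (video_path, list of per-frame landmark lists [(x, y), ...])."""
        path, coords = line.rstrip('\n').split(':', 1)
        frames = []
        for frame in coords.split('|'):
            points = []
            for pair in frame.split(','):
                x, y = pair.split()
                points.append((int(x), int(y)))
            frames.append(points)
        return path, frames

    if __name__ == '__main__':
        with open('preprocess/Ref_face.txt') as f:
            for line in f:
                video, frames = parse_ref_face_line(line)
                print(video, len(frames), 'frames,', len(frames[0]), 'points per frame')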
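Note on the audio/video frame arithmetic in the dataset code above (editorial addition): the slicing expressions of the form `int(T * info['audio_fps'] / info['video_fps'] / 160)` convert a count of video frames into a count of spectrogram frames using an STFT hop of 160 samples. Assuming 16 kHz audio and 25 fps video (typical for LRS2/LRS3; not stated explicitly in this file), each video frame spans 640 audio samples, i.e. 4 hops, which is consistent with the padding target `max_v_timesteps * 4`. A small sanity-check sketch under those assumptions:

    # Illustrative only: video-frame to spectrogram-frame conversion assumed above.
    audio_fps = 16000   # assumption: 16 kHz audio
    video_fps = 25      # assumption: 25 fps video
    hop = 160           # hop length used in the slicing code (.../ 160)

    samples_per_video_frame = audio_fps / video_fps            # 640 samples
    spec_frames_per_video_frame = samples_per_video_frame / hop  # 4 frames
    print(spec_frames_per_video_frame)  # 4.0, matching the max_v_timesteps * 4 padding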