├── .gitignore ├── src ├── run.sh ├── main.py ├── utils.py ├── solver.py ├── model.py └── evaluate.py ├── README.md ├── data └── voxceleb2-800 │ ├── preprocess.sh │ ├── 2_create_mixture.py │ ├── 3_create_lip_embedding.py │ └── 1_create_mixture_list.py └── pretrain_networks └── visual_frontend.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/*.pt 2 | **/__pycache__/ 3 | **/logs/ 4 | -------------------------------------------------------------------------------- /src/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | gpu_id=3 4 | continue_from= 5 | 6 | if [ -z ${continue_from} ]; then 7 | log_name='avaNet_'$(date '+%Y-%m-%d(%H:%M:%S)') 8 | mkdir logs/$log_name 9 | else 10 | log_name=${continue_from} 11 | fi 12 | 13 | CUDA_VISIBLE_DEVICES="$gpu_id" \ 14 | python -W ignore \ 15 | -m torch.distributed.launch \ 16 | --nproc_per_node=1 \ 17 | --master_port=1236 \ 18 | main.py \ 19 | \ 20 | --log_name $log_name \ 21 | \ 22 | --batch_size 8 \ 23 | --audio_direc '/home/panzexu/datasets/voxceleb2/audio_clean/' \ 24 | --visual_direc '/home/panzexu/datasets/voxceleb2/visual_embedding/lip/' \ 25 | --mix_lst_path '/home/panzexu/datasets/voxceleb2/audio_mixture/2_mix_min_800/mixture_data_list_2mix.csv' \ 26 | --mixture_direc '/home/panzexu/datasets/voxceleb2/audio_mixture/2_mix_min_800/' \ 27 | --C 2 \ 28 | --epochs 100 \ 29 | \ 30 | --use_tensorboard 1 \ 31 | >logs/$log_name/console.txt 2>&1 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## MuSE 2 | 3 | A PyTorch implementation of the [Muse: Multi-modal target speaker extraction with visual cues](https://arxiv.org/abs/2010.07775) 4 | 5 | ## Update: 6 | * A new version of this code is scheduled to be released [here (ClearVoice repo)](https://github.com/modelscope/ClearerVoice-Studio). 7 | * The dataset can be found [here](https://huggingface.co/datasets/alibabasglab/KUL-mix). 8 | 9 | ## Project Structure 10 | 11 | `/data/voxceleb2-800`: Scripts to preprocess the voxceleb2 datasets. 12 | 13 | `/pretrain_networks`: The visual front-end network 14 | 15 | `/src`: The training scripts 16 | 17 | ## Pre-trained Weights 18 | Download the pre-trained weights for the Visual Frontend and place it in the ./pretrain_networks folder using the following command: 19 | 20 | wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1k0Zk90ASft89-xAEUbu5CmZWih_u_lRN' -O visual_frontend.pt 21 | 22 | 23 | ## References 24 | 1. The pre-trained weights of the Visual Frontend have been obtained from [Afouras T. and Chung J, Deep Audio-Visual Speech Recognition](https://github.com/lordmartian/deep_avsr) GitHub repository. 25 | 26 | 2. The model is adapted from [Conv-TasNet](https://github.com/kaituoxu/Conv-TasNet) GitHub repository. 27 | -------------------------------------------------------------------------------- /data/voxceleb2-800/preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | direc=/home/panzexu/datasets/voxceleb2/ 4 | 5 | data_direc=${direc}orig/ 6 | 7 | train_samples=20000 # no. of train mixture samples simulated 8 | val_samples=5000 # no. of validation mixture samples simulated 9 | test_samples=3000 # no. of test mixture samples simulated 10 | C=2 # no. 
of speakers in the mixture 11 | mix_db=10 # random db ratio from -10 to 10db 12 | mixture_data_list=mixture_data_list_${C}mix.csv #mixture datalist 13 | sampling_rate=16000 # audio sampling rate 14 | min_length=4 # minimum length of audio 15 | 16 | audio_data_direc=${direc}audio_clean/ # Target audio saved directory 17 | mixture_audio_direc=${direc}audio_mixture/${C}_mix_min_800/ # Audio mixture saved directory 18 | visual_frame_direc=${direc}face/ # The visual saved directory 19 | lip_embedding_direc=${direc}lip/ # The lip embedding saved directory 20 | 21 | # # stage 1: Remove repeated datas in pretrain and train set, extract audio from mp4, create mixture list 22 | # echo 'stage 1: create mixture list' 23 | # python 1_create_mixture_list.py \ 24 | # --data_direc $data_direc \ 25 | # --C $C \ 26 | # --mix_db $mix_db \ 27 | # --train_samples $train_samples \ 28 | # --val_samples $val_samples \ 29 | # --test_samples $test_samples \ 30 | # --audio_data_direc $audio_data_direc \ 31 | # --min_length $min_length \ 32 | # --sampling_rate $sampling_rate \ 33 | # --mixture_data_list $mixture_data_list \ 34 | 35 | # # stage 2: create audio mixture from list 36 | # echo 'stage 2: create mixture audios' 37 | # python 2_create_mixture.py \ 38 | # --C $C \ 39 | # --audio_data_direc $audio_data_direc \ 40 | # --mixture_audio_direc $mixture_audio_direc \ 41 | # --mixture_data_list $mixture_data_list \ 42 | 43 | # # stage 3: create lip embedding 44 | # echo 'stage 3: create lip embedding' 45 | # python 3_create_lip_embedding.py \ 46 | # --video_data_direc $data_direc \ 47 | # --visual_frame_direc $visual_frame_direc \ 48 | # --lip_embedding_direc $lip_embedding_direc 49 | -------------------------------------------------------------------------------- /data/voxceleb2-800/2_create_mixture.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import tqdm 5 | import scipy.io.wavfile as wavfile 6 | 7 | MAX_INT16 = np.iinfo(np.int16).max 8 | 9 | def write_wav(fname, samps, sampling_rate=16000, normalize=True): 10 | """ 11 | Write wav files in int16, support single/multi-channel 12 | """ 13 | # for multi-channel, accept ndarray [Nsamples, Nchannels] 14 | if samps.ndim != 1 and samps.shape[0] < samps.shape[1]: 15 | samps = np.transpose(samps) 16 | samps = np.squeeze(samps) 17 | # same as MATLAB and kaldi 18 | if normalize: 19 | samps = samps * MAX_INT16 20 | samps = samps.astype(np.int16) 21 | fdir = os.path.dirname(fname) 22 | if fdir and not os.path.exists(fdir): 23 | os.makedirs(fdir) 24 | # NOTE: librosa 0.6.0 seems could not write non-float narray 25 | # so use scipy.io.wavfile instead 26 | wavfile.write(fname, sampling_rate, samps) 27 | 28 | 29 | def read_wav(fname, normalize=True): 30 | """ 31 | Read wave files using scipy.io.wavfile(support multi-channel) 32 | """ 33 | # samps_int16: N x C or N 34 | # N: number of samples 35 | # C: number of channels 36 | sampling_rate, samps_int16 = wavfile.read(fname) 37 | # N x C => C x N 38 | samps = samps_int16.astype(np.float) 39 | # tranpose because I used to put channel axis first 40 | if samps.ndim != 1: 41 | samps = np.transpose(samps) 42 | # normalize like MATLAB and librosa 43 | if normalize: 44 | samps = samps / MAX_INT16 45 | return sampling_rate, samps 46 | 47 | def main(args): 48 | # create mixture 49 | mixture_data_list = open(args.mixture_data_list).read().splitlines() 50 | print(len(mixture_data_list)) 51 | for line in tqdm.tqdm(mixture_data_list,desc = 
"Generating audio mixtures"): 52 | data = line.split(',') 53 | save_direc=args.mixture_audio_direc+data[0]+'/' 54 | if not os.path.exists(save_direc): 55 | os.makedirs(save_direc) 56 | 57 | mixture_save_path=save_direc+line.replace(',','_').replace('/','_') +'.wav' 58 | if os.path.exists(mixture_save_path): 59 | continue 60 | 61 | # read target audio 62 | _, audio_mix=read_wav(args.audio_data_direc+data[1]+'/'+data[2]+'/'+data[3]+'.wav') 63 | target_power = np.linalg.norm(audio_mix, 2)**2 / audio_mix.size 64 | 65 | # read inteference audio 66 | for c in range(1, args.C): 67 | audio_path=args.audio_data_direc+data[c*4+1]+'/'+data[c*4+2]+'/'+data[c*4+3]+'.wav' 68 | _, audio = read_wav(audio_path) 69 | intef_power = np.linalg.norm(audio, 2)**2 / audio.size 70 | 71 | # audio = audio_norm(audio) 72 | scalar = (10**(float(data[c*4+4])/20)) * np.sqrt(target_power/intef_power) 73 | audio = audio * scalar 74 | 75 | # truncate long audio with short audio in the mixture 76 | if audio_mix.shape[0] > audio.shape[0]: 77 | audio_mix = audio_mix[:audio.shape[0]] + audio 78 | else: audio_mix = audio_mix + audio[:audio_mix.shape[0]] 79 | 80 | audio_mix = np.divide(audio_mix, np.max(np.abs(audio_mix))) 81 | write_wav(mixture_save_path, audio_mix) 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser(description='LRS2 dataset') 86 | parser.add_argument('--C', type=int) 87 | parser.add_argument('--audio_data_direc', type=str) 88 | parser.add_argument('--mixture_audio_direc', type=str) 89 | parser.add_argument('--mixture_data_list', type=str) 90 | args = parser.parse_args() 91 | main(args) -------------------------------------------------------------------------------- /data/voxceleb2-800/3_create_lip_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from tqdm import tqdm 4 | import numpy as np 5 | import argparse 6 | import cv2 as cv 7 | 8 | import sys 9 | sys.path.append('../../') 10 | from pretrain_networks.visual_frontend import VisualFrontend 11 | 12 | def preprocess_sample(file, params, args): 13 | 14 | """ 15 | Function to preprocess each data sample. 
16 | """ 17 | 18 | videoFile = args.video_data_direc + file + ".mp4" 19 | roiFile = args.visual_frame_direc + file +".png" 20 | visualFeaturesFile = args.lip_embedding_direc + file + ".npy" 21 | 22 | if os.path.exists(visualFeaturesFile): 23 | return 24 | 25 | if not os.path.exists(roiFile[:-9]): 26 | os.makedirs(roiFile[:-9]) 27 | 28 | if not os.path.exists(visualFeaturesFile[:-9]): 29 | os.makedirs(visualFeaturesFile[:-9]) 30 | 31 | 32 | roiSize = params["roiSize"] 33 | normMean = params["normMean"] 34 | normStd = params["normStd"] 35 | vf = params["vf"] 36 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 37 | 38 | 39 | #for each frame, resize to 224x224 and crop the central 112x112 region 40 | captureObj = cv.VideoCapture(videoFile) 41 | roiSequence = list() 42 | while (captureObj.isOpened()): 43 | ret, frame = captureObj.read() 44 | if ret == True: 45 | grayed = cv.cvtColor(frame, cv.COLOR_BGR2GRAY) 46 | grayed = grayed/255 47 | grayed = cv.resize(grayed, (roiSize*2,roiSize*2)) 48 | roi = grayed[int(roiSize-(roiSize/2)):int(roiSize+(roiSize/2)), int(roiSize-(roiSize/2)):int(roiSize+(roiSize/2))] 49 | roiSequence.append(roi) 50 | else: 51 | break 52 | captureObj.release() 53 | # cv.imwrite(roiFile, np.floor(255*np.concatenate(roiSequence, axis=1)).astype(np.int)) 54 | 55 | 56 | #normalise the frames and extract features for each frame using the visual frontend 57 | #save the visual features to a .npy file 58 | inp = np.stack(roiSequence, axis=0) 59 | inp = np.expand_dims(inp, axis=[1,2]) 60 | inp = (inp - normMean)/normStd 61 | inputBatch = torch.from_numpy(inp) 62 | inputBatch = (inputBatch.float()).to(device) 63 | vf.eval() 64 | with torch.no_grad(): 65 | outputBatch = vf(inputBatch) 66 | out = torch.squeeze(outputBatch, dim=1) 67 | out = out.cpu().numpy() 68 | np.save(visualFeaturesFile, out) 69 | return 70 | 71 | def main(args): 72 | gpuAvailable = torch.cuda.is_available() 73 | device = torch.device("cuda" if gpuAvailable else "cpu") 74 | 75 | #declaring the visual frontend module 76 | vf = VisualFrontend() 77 | vf.load_state_dict(torch.load('../../pretrain_networks/visual_frontend.pt', map_location=device)) 78 | vf.to(device) 79 | 80 | #walking through the data directory and obtaining a list of all files in the dataset 81 | filesList = list() 82 | for root, dirs, files in os.walk(args.video_data_direc+'train/'): 83 | for file in files: 84 | if file.endswith(".mp4"): 85 | path = root.split('/') 86 | filesList.append((path[-3]+'/'+path[-2]+'/'+path[-1]+'/'+file[:-4])) 87 | print(len(filesList)) 88 | 89 | params = {"roiSize":112, "normMean":0.4161, "normStd":0.1688, "vf":vf} 90 | for file in tqdm(filesList, leave=True, desc="Preprocess", ncols=75): 91 | preprocess_sample(file, params, args) 92 | 93 | 94 | 95 | if __name__ == '__main__': 96 | parser = argparse.ArgumentParser(description='LRS3 dataset') 97 | parser.add_argument('--video_data_direc', type=str) 98 | parser.add_argument('--lip_embedding_direc', type=str) 99 | parser.add_argument('--visual_frame_direc', type=str) 100 | args = parser.parse_args() 101 | main(args) -------------------------------------------------------------------------------- /pretrain_networks/visual_frontend.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | 6 | class ResNetLayer(nn.Module): 7 | 8 | """ 9 | A ResNet layer used to build the ResNet network. 
10 | Architecture: 11 | --> conv-bn-relu -> conv -> + -> bn-relu -> conv-bn-relu -> conv -> + -> bn-relu --> 12 | | | | | 13 | -----> downsample ------> -------------------------------------> 14 | """ 15 | 16 | def __init__(self, inplanes, outplanes, stride): 17 | super(ResNetLayer, self).__init__() 18 | self.conv1a = nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False) 19 | self.bn1a = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001) 20 | self.conv2a = nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1, padding=1, bias=False) 21 | self.stride = stride 22 | self.downsample = nn.Conv2d(inplanes, outplanes, kernel_size=(1,1), stride=stride, bias=False) 23 | self.outbna = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001) 24 | 25 | self.conv1b = nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1, padding=1, bias=False) 26 | self.bn1b = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001) 27 | self.conv2b = nn.Conv2d(outplanes, outplanes, kernel_size=3, stride=1, padding=1, bias=False) 28 | self.outbnb = nn.BatchNorm2d(outplanes, momentum=0.01, eps=0.001) 29 | return 30 | 31 | 32 | def forward(self, inputBatch): 33 | batch = F.relu(self.bn1a(self.conv1a(inputBatch))) 34 | batch = self.conv2a(batch) 35 | if self.stride == 1: 36 | residualBatch = inputBatch 37 | else: 38 | residualBatch = self.downsample(inputBatch) 39 | batch = batch + residualBatch 40 | intermediateBatch = batch 41 | batch = F.relu(self.outbna(batch)) 42 | 43 | batch = F.relu(self.bn1b(self.conv1b(batch))) 44 | batch = self.conv2b(batch) 45 | residualBatch = intermediateBatch 46 | batch = batch + residualBatch 47 | outputBatch = F.relu(self.outbnb(batch)) 48 | return outputBatch 49 | 50 | 51 | 52 | class ResNet(nn.Module): 53 | 54 | """ 55 | An 18-layer ResNet architecture. 56 | """ 57 | 58 | def __init__(self): 59 | super(ResNet, self).__init__() 60 | self.layer1 = ResNetLayer(64, 64, stride=1) 61 | self.layer2 = ResNetLayer(64, 128, stride=2) 62 | self.layer3 = ResNetLayer(128, 256, stride=2) 63 | self.layer4 = ResNetLayer(256, 512, stride=2) 64 | self.avgpool = nn.AvgPool2d(kernel_size=(4,4), stride=(1,1)) 65 | return 66 | 67 | 68 | def forward(self, inputBatch): 69 | batch = self.layer1(inputBatch) 70 | batch = self.layer2(batch) 71 | batch = self.layer3(batch) 72 | batch = self.layer4(batch) 73 | outputBatch = self.avgpool(batch) 74 | return outputBatch 75 | 76 | 77 | 78 | class VisualFrontend(nn.Module): 79 | 80 | """ 81 | A visual feature extraction module. Generates a 512-dim feature vector per video frame. 82 | Architecture: A 3D convolution block followed by an 18-layer ResNet. 
83 | """ 84 | 85 | def __init__(self): 86 | super(VisualFrontend, self).__init__() 87 | self.frontend3D = nn.Sequential( 88 | nn.Conv3d(1, 64, kernel_size=(5,7,7), stride=(1,2,2), padding=(2,3,3), bias=False), 89 | nn.BatchNorm3d(64, momentum=0.01, eps=0.001), 90 | nn.ReLU(), 91 | nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)) 92 | ) 93 | self.resnet = ResNet() 94 | return 95 | 96 | 97 | def forward(self, inputBatch): 98 | inputBatch = inputBatch.transpose(0, 1).transpose(1, 2) 99 | batchsize = inputBatch.shape[0] 100 | batch = self.frontend3D(inputBatch) 101 | 102 | batch = batch.transpose(1, 2) 103 | batch = batch.reshape(batch.shape[0]*batch.shape[1], batch.shape[2], batch.shape[3], batch.shape[4]) 104 | outputBatch = self.resnet(batch) 105 | outputBatch = outputBatch.reshape(batchsize, -1, 512) 106 | outputBatch = outputBatch.transpose(1 ,2) 107 | outputBatch = outputBatch.transpose(1, 2).transpose(0, 1) 108 | return outputBatch -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from utils import * 4 | import os 5 | from model import muse 6 | from solver import Solver 7 | 8 | 9 | def main(args): 10 | if args.distributed: 11 | torch.manual_seed(0) 12 | torch.cuda.set_device(args.local_rank) 13 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 14 | 15 | # speaker id assignment 16 | mix_lst=open(args.mix_lst_path).read().splitlines() 17 | train_lst=list(filter(lambda x: x.split(',')[0]=='train', mix_lst)) 18 | IDs = 0 19 | speaker_dict={} 20 | for line in train_lst: 21 | for i in range(2): 22 | ID = line.split(',')[i*4+2] 23 | if ID not in speaker_dict: 24 | speaker_dict[ID]=IDs 25 | IDs += 1 26 | args.speaker_dict=speaker_dict 27 | args.speakers=len(speaker_dict) 28 | 29 | # Model 30 | model = muse(args.N, args.L, args.B, args.H, args.P, args.X, args.R, 31 | args.C, args.speakers) 32 | 33 | if (args.distributed and args.local_rank ==0) or args.distributed == False: 34 | print("started on " + args.log_name + '\n') 35 | print(args) 36 | print("\nTotal number of parameters: {} \n".format(sum(p.numel() for p in model.parameters()))) 37 | print(model) 38 | 39 | model = model.cuda() 40 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 41 | 42 | train_sampler, train_generator = get_dataloader(args,'train') 43 | _, val_generator = get_dataloader(args, 'val') 44 | _, test_generator = get_dataloader(args, 'test') 45 | args.train_sampler=train_sampler 46 | 47 | solver = Solver(args=args, 48 | model = model, 49 | optimizer = optimizer, 50 | train_data = train_generator, 51 | validation_data = val_generator, 52 | test_data = test_generator) 53 | solver.train() 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser("avConv-tasnet") 57 | 58 | # Dataloader 59 | parser.add_argument('--mix_lst_path', type=str, default='/home/panzexu/datasets/LRS2/audio/2_mix_min/mixture_data_list_2mix.csv', 60 | help='directory including train data') 61 | parser.add_argument('--audio_direc', type=str, default='/home/panzexu/datasets/LRS2/audio/Audio/', 62 | help='directory including validation data') 63 | parser.add_argument('--visual_direc', type=str, default='/home/panzexu/datasets/LRS2/lip/', 64 | help='directory including test data') 65 | parser.add_argument('--mixture_direc', type=str, default='/home/panzexu/datasets/LRS2/audio/2_mix_min/', 66 | help='directory of audio') 67 
| 68 | 69 | # Training 70 | parser.add_argument('--batch_size', default=8, type=int, 71 | help='Batch size') 72 | parser.add_argument('--max_length', default=6, type=int, 73 | help='max_length of mixture in training') 74 | parser.add_argument('--num_workers', default=4, type=int, 75 | help='Number of workers to generate minibatch') 76 | parser.add_argument('--epochs', default=100, type=int, 77 | help='Number of maximum epochs') 78 | 79 | # Model hyperparameters 80 | parser.add_argument('--L', default=40, type=int, 81 | help='Length of the filters in samples (80=5ms at 16kHZ)') 82 | parser.add_argument('--N', default=256, type=int, 83 | help='Number of filters in autoencoder') 84 | parser.add_argument('--B', default=256, type=int, 85 | help='Number of channels in bottleneck 1 × 1-conv block') 86 | parser.add_argument('--C', type=int, default=2, 87 | help='number of speakers to mix') 88 | parser.add_argument('--H', default=512, type=int, 89 | help='Number of channels in convolutional blocks') 90 | parser.add_argument('--P', default=3, type=int, 91 | help='Kernel size in convolutional blocks') 92 | parser.add_argument('--X', default=8, type=int, 93 | help='Number of convolutional blocks in each repeat') 94 | parser.add_argument('--R', default=4, type=int, 95 | help='Number of repeats') 96 | 97 | # optimizer 98 | parser.add_argument('--lr', default=1e-3, type=float, 99 | help='Init learning rate') 100 | parser.add_argument('--max_norm', default=5, type=float, 101 | help='Gradient norm threshold to clip') 102 | 103 | 104 | # Log and Visulization 105 | parser.add_argument('--log_name', type=str, default=None, 106 | help='the name of the log') 107 | parser.add_argument('--use_tensorboard', type=int, default=0, 108 | help='Whether to use use_tensorboard') 109 | parser.add_argument('--continue_from', type=str, default='', 110 | help='Whether to use use_tensorboard') 111 | 112 | # Distributed training 113 | parser.add_argument('--opt-level', default='O0', type=str) 114 | parser.add_argument("--local_rank", default=0, type=int) 115 | parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) 116 | parser.add_argument('--patch_torch_functions', type=str, default=None) 117 | 118 | args = parser.parse_args() 119 | 120 | args.distributed = False 121 | args.world_size = 1 122 | if 'WORLD_SIZE' in os.environ: 123 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 124 | args.world_size = int(os.environ['WORLD_SIZE']) 125 | 126 | assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." 
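# Note: when main.py is launched through `python -m torch.distributed.launch` (as in
# src/run.sh), the launcher sets WORLD_SIZE in the environment and passes --local_rank
# to every spawned process, which is how args.distributed and args.local_rank above get
# populated; run without the launcher, the script falls back to single-process training.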
127 | 128 | main(args) 129 | -------------------------------------------------------------------------------- /data/voxceleb2-800/1_create_mixture_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import csv 5 | import tqdm 6 | import librosa 7 | import scipy.io.wavfile as wavfile 8 | from multiprocessing import Pool 9 | 10 | np.random.seed(0) 11 | 12 | def extract_wav_from_mp4(line): 13 | # Extract .wav file from mp4 14 | video_from_path=args.data_direc + line[0]+'/'+line[1]+'/'+line[2]+'.mp4' 15 | audio_save_path=args.audio_data_direc + line[0]+'/'+line[1]+'/'+line[2]+'.wav' 16 | if not os.path.exists(audio_save_path.rsplit('/', 1)[0]): 17 | os.makedirs(audio_save_path.rsplit('/', 1)[0]) 18 | 19 | if not os.path.exists(audio_save_path): 20 | os.system("ffmpeg -i %s %s"%(video_from_path, audio_save_path)) 21 | sr, audio = wavfile.read(audio_save_path) 22 | assert sr==args.sampling_rate , "sampling_rate mismatch" 23 | sample_length = audio.shape[0] 24 | return sample_length # In seconds 25 | 26 | 27 | def main(args): 28 | # read the datalist and separate into train, val and test set 29 | train_list=[] 30 | val_list=[] 31 | test_list=[] 32 | tmp_list=[] 33 | 34 | print("Gathering file names") 35 | 36 | # Get test set list of audios 37 | for path, dirs ,files in os.walk(args.data_direc + 'test/'): 38 | for filename in files: 39 | if filename[-4:] =='.mp4': 40 | ln = [path.split('/')[-3], path.split('/')[-2], path.split('/')[-1] +'/'+ filename.split('.')[0]] 41 | sample_length = extract_wav_from_mp4(ln) 42 | if sample_length < args.min_length*args.sampling_rate: continue 43 | ln += [sample_length/args.sampling_rate] 44 | test_list.append(ln) 45 | 46 | # Get train set list of audios 47 | for path, dirs ,files in os.walk(args.data_direc + 'train/'): 48 | for filename in files: 49 | if filename[-4:] =='.mp4': 50 | ln = [path.split('/')[-3], path.split('/')[-2], path.split('/')[-1] +'/'+ filename.split('.')[0]] 51 | sample_length = extract_wav_from_mp4(ln) 52 | if sample_length < args.min_length*args.sampling_rate: continue 53 | ln += [sample_length/args.sampling_rate] 54 | tmp_list.append(ln) 55 | # print(len(tmp_list)) 56 | 57 | # Sort the speakers with the number of utterances in pretrain set 58 | speakers = {} 59 | for ln in tmp_list: 60 | ID = ln[1] 61 | if ID not in speakers: 62 | speakers[ID] = 1 63 | else: speakers[ID] +=1 64 | sort_speakers = sorted(speakers.items(), key=lambda x: x[1], reverse=True) 65 | 66 | # Get the 1600 speakers with the most no. 
of utterances 67 | train_speakers={} 68 | for i, (ID) in enumerate(sort_speakers): 69 | if i == 800: 70 | break 71 | train_speakers[ID[0]] = 0 72 | 73 | for ln in tmp_list: 74 | ID = ln[1] 75 | if ID in train_speakers: 76 | if train_speakers[ID] < 12: 77 | val_list.append(ln) 78 | train_speakers[ID] +=1 79 | elif train_speakers[ID] >=62: 80 | continue 81 | else: 82 | train_list.append(ln) 83 | train_speakers[ID] +=1 84 | 85 | 86 | # Create mixture list 87 | print("Creating mixture list") 88 | f=open(args.mixture_data_list,'w') 89 | w=csv.writer(f) 90 | 91 | 92 | # create test set and validation set 93 | for data_list in [train_list, val_list]: 94 | if len(data_list)<20000: 95 | data='val' 96 | length = args.val_samples 97 | else: 98 | data='train' 99 | length = args.train_samples 100 | 101 | count_list=[] 102 | for ln in data_list: 103 | if not ln[1] in count_list: 104 | count_list.append(ln[1]) 105 | print("In %s list: %s speakers, %s utterances"%(data, len(count_list), len(data_list))) 106 | 107 | cache_list = data_list[:] 108 | count = 0 109 | while (len(data_list) >= args.C): 110 | mixtures=[data] 111 | shortest = 200 112 | cache = [] 113 | while len(cache) < args.C: 114 | idx = np.random.randint(0, len(data_list)) 115 | if data_list[idx][1] in cache: 116 | continue 117 | cache.append(data_list[idx][1]) 118 | mixtures = mixtures + list(data_list[idx]) 119 | if float(mixtures[-1]) < shortest: shortest = float(mixtures[-1]) 120 | del mixtures[-1] 121 | if len(cache)==1: db_ratio =0 122 | else: db_ratio = np.random.uniform(-args.mix_db,args.mix_db) 123 | mixtures.append(db_ratio) 124 | data_list.pop(idx) 125 | mixtures.append(shortest) 126 | w.writerow(mixtures) 127 | count +=1 128 | if count == length: 129 | break 130 | 131 | if count < length: 132 | for j in range(length-count): 133 | mixtures=[data] 134 | shortest = 200 135 | cache = [] 136 | while len(cache) < args.C: 137 | idx = np.random.randint(0, len(cache_list)) 138 | if cache_list[idx][1] in cache: 139 | continue 140 | cache.append(cache_list[idx][1]) 141 | mixtures = mixtures + list(cache_list[idx]) 142 | if float(mixtures[-1]) < shortest: shortest = float(mixtures[-1]) 143 | del mixtures[-1] 144 | if len(cache)==1: db_ratio =0 145 | else: db_ratio = np.random.uniform(-args.mix_db,args.mix_db) 146 | mixtures.append(db_ratio) 147 | mixtures.append(shortest) 148 | w.writerow(mixtures) 149 | 150 | 151 | data_list=test_list 152 | data='test' 153 | length = args.test_samples 154 | count_list=[] 155 | for ln in data_list: 156 | if not ln[1] in count_list: 157 | count_list.append(ln[1]) 158 | print("In %s list: %s speakers, %s utterances"%(data, len(count_list), len(data_list))) 159 | 160 | for _ in range(length): 161 | mixtures=[data] 162 | shortest = 200 163 | cache = [] 164 | while len(cache) < args.C: 165 | idx = np.random.randint(0, len(data_list)) 166 | if data_list[idx][1] in cache: 167 | continue 168 | cache.append(data_list[idx][1]) 169 | mixtures = mixtures + list(data_list[idx]) 170 | if float(mixtures[-1]) < shortest: shortest = float(mixtures[-1]) 171 | del mixtures[-1] 172 | if len(cache)==1: db_ratio =0 173 | else: db_ratio = np.random.uniform(-args.mix_db,args.mix_db) 174 | mixtures.append(db_ratio) 175 | mixtures.append(shortest) 176 | w.writerow(mixtures) 177 | 178 | 179 | f.close() 180 | 181 | if __name__ == '__main__': 182 | parser = argparse.ArgumentParser(description='LRS2 dataset') 183 | parser.add_argument('--data_direc', type=str) 184 | parser.add_argument('--C', type=int) 185 | parser.add_argument('--mix_db', 
type=float) 186 | parser.add_argument('--train_samples', type=int) 187 | parser.add_argument('--val_samples', type=int) 188 | parser.add_argument('--test_samples', type=int) 189 | parser.add_argument('--audio_data_direc', type=str) 190 | parser.add_argument('--min_length', type=int) 191 | parser.add_argument('--sampling_rate', type=int) 192 | parser.add_argument('--mixture_data_list', type=str) 193 | args = parser.parse_args() 194 | 195 | main(args) -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch.distributed as dist 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.data as data 7 | import scipy.io.wavfile as wavfile 8 | from itertools import permutations 9 | from apex import amp 10 | import tqdm 11 | import os 12 | 13 | EPS = 1e-6 14 | 15 | class dataset(data.Dataset): 16 | def __init__(self, 17 | speaker_dict, 18 | mix_lst_path, 19 | audio_direc, 20 | visual_direc, 21 | mixture_direc, 22 | batch_size, 23 | partition='test', 24 | audio_only=False, 25 | sampling_rate=16000, 26 | max_length=4, 27 | mix_no=2): 28 | 29 | self.minibatch =[] 30 | self.audio_only = audio_only 31 | self.audio_direc = audio_direc 32 | self.visual_direc = visual_direc 33 | self.mixture_direc = mixture_direc 34 | self.sampling_rate = sampling_rate 35 | self.partition = partition 36 | self.max_length = max_length 37 | self.C=mix_no 38 | self.speaker_id=speaker_dict 39 | 40 | mix_lst=open(mix_lst_path).read().splitlines() 41 | mix_lst=list(filter(lambda x: x.split(',')[0]==partition, mix_lst)) 42 | 43 | assert (batch_size%self.C) == 0, "input batch_size should be multiples of mixture speakers" 44 | 45 | self.batch_size = int(batch_size/self.C ) 46 | sorted_mix_lst = sorted(mix_lst, key=lambda data: float(data.split(',')[-1]), reverse=True) 47 | start = 0 48 | while True: 49 | end = min(len(sorted_mix_lst), start + self.batch_size) 50 | self.minibatch.append(sorted_mix_lst[start:end]) 51 | if end == len(sorted_mix_lst): 52 | break 53 | start = end 54 | 55 | def __getitem__(self, index): 56 | batch_lst = self.minibatch[index] 57 | min_length = int(float(batch_lst[-1].split(',')[-1])*self.sampling_rate) 58 | 59 | mixtures=[] 60 | audios=[] 61 | visuals=[] 62 | speakers=[] 63 | for line in batch_lst: 64 | mixture_path=self.mixture_direc+self.partition+'/'+ line.replace(',','_').replace('/','_')+'.wav' 65 | _, mixture = wavfile.read(mixture_path) 66 | mixture = self._audio_norm(mixture[:min_length]) 67 | 68 | line=line.split(',') 69 | for c in range(self.C): 70 | # read target audio 71 | audio_path=self.audio_direc+line[c*4+1]+'/'+line[c*4+2]+'/'+line[c*4+3]+'.wav' 72 | _, audio = wavfile.read(audio_path) 73 | audios.append(self._audio_norm(audio[:min_length])) 74 | 75 | # read target audio id 76 | if self.partition == 'test': 77 | speakers.append(0) 78 | else: speakers.append(self.speaker_id[line[c*4+2]]) 79 | 80 | # read target visual reference 81 | visual_path=self.visual_direc+line[c*4+1]+'/'+line[c*4+2]+'/'+line[c*4+3]+'.npy' 82 | visual = np.load(visual_path) 83 | length = math.floor(min_length/self.sampling_rate*25) 84 | visual = visual[:length,...] 
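# The lip embeddings are stored at 25 frames per second, so `length` is the number of
# video frames corresponding to `min_length` audio samples; clips with fewer frames
# are edge-padded up to that length just below.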
85 | a = visual.shape[0] 86 | if visual.shape[0] < length: 87 | visual = np.pad(visual, ((0,int(length - visual.shape[0])),(0,0)), mode = 'edge') 88 | visuals.append(visual) 89 | 90 | # read overlapped speech 91 | mixtures.append(mixture) 92 | 93 | return np.asarray(mixtures)[...,:self.max_length*self.sampling_rate], \ 94 | np.asarray(audios)[...,:self.max_length*self.sampling_rate], \ 95 | np.asarray(visuals)[...,:self.max_length*25,:], \ 96 | np.asarray(speakers) 97 | 98 | def __len__(self): 99 | return len(self.minibatch) 100 | 101 | def _audio_norm(self,audio): 102 | return np.divide(audio, np.max(np.abs(audio))) 103 | 104 | 105 | class DistributedSampler(data.Sampler): 106 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0): 107 | if num_replicas is None: 108 | if not dist.is_available(): 109 | raise RuntimeError("Requires distributed package to be available") 110 | num_replicas = dist.get_world_size() 111 | if rank is None: 112 | if not dist.is_available(): 113 | raise RuntimeError("Requires distributed package to be available") 114 | rank = dist.get_rank() 115 | self.dataset = dataset 116 | self.num_replicas = num_replicas 117 | self.rank = rank 118 | self.epoch = 0 119 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 120 | self.total_size = self.num_samples * self.num_replicas 121 | self.shuffle = shuffle 122 | self.seed = seed 123 | 124 | def __iter__(self): 125 | if self.shuffle: 126 | # deterministically shuffle based on epoch and seed 127 | g = torch.Generator() 128 | g.manual_seed(self.seed + self.epoch) 129 | # indices = torch.randperm(len(self.dataset), generator=g).tolist() 130 | ind = torch.randperm(int(len(self.dataset)/self.num_replicas), generator=g)*self.num_replicas 131 | indices = [] 132 | for i in range(self.num_replicas): 133 | indices = indices + (ind+i).tolist() 134 | else: 135 | indices = list(range(len(self.dataset))) 136 | 137 | # add extra samples to make it evenly divisible 138 | indices += indices[:(self.total_size - len(indices))] 139 | assert len(indices) == self.total_size 140 | 141 | # subsample 142 | # indices = indices[self.rank:self.total_size:self.num_replicas] 143 | indices = indices[self.rank*self.num_samples:(self.rank+1)*self.num_samples] 144 | assert len(indices) == self.num_samples 145 | 146 | return iter(indices) 147 | 148 | def __len__(self): 149 | return self.num_samples 150 | 151 | def set_epoch(self, epoch): 152 | self.epoch = epoch 153 | 154 | def get_dataloader(args, partition): 155 | datasets = dataset( 156 | speaker_dict =args.speaker_dict, 157 | mix_lst_path=args.mix_lst_path, 158 | audio_direc=args.audio_direc, 159 | visual_direc=args.visual_direc, 160 | mixture_direc=args.mixture_direc, 161 | batch_size=args.batch_size, 162 | max_length=args.max_length, 163 | partition=partition, 164 | mix_no=args.C) 165 | 166 | sampler = DistributedSampler( 167 | datasets, 168 | num_replicas=args.world_size, 169 | rank=args.local_rank) if args.distributed else None 170 | 171 | generator = data.DataLoader(datasets, 172 | batch_size = 1, 173 | shuffle = (sampler is None), 174 | num_workers = args.num_workers, 175 | sampler=sampler) 176 | 177 | return sampler, generator 178 | 179 | @amp.float_function 180 | def cal_SISNR(source, estimate_source): 181 | """Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR) 182 | Args: 183 | source: torch tensor, [batch size, sequence length] 184 | estimate_source: torch tensor, [batch size, sequence length] 185 | Returns: 186 | SISNR, [batch size] 187 
| """ 188 | assert source.size() == estimate_source.size() 189 | 190 | # Step 1. Zero-mean norm 191 | source = source - torch.mean(source, axis = -1, keepdim=True) 192 | estimate_source = estimate_source - torch.mean(estimate_source, axis = -1, keepdim=True) 193 | 194 | # Step 2. SI-SNR 195 | # s_target = s / ||s||^2 196 | ref_energy = torch.sum(source ** 2, axis = -1, keepdim=True) + EPS 197 | proj = torch.sum(source * estimate_source, axis = -1, keepdim=True) * source / ref_energy 198 | # e_noise = s' - s_target 199 | noise = estimate_source - proj 200 | # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2) 201 | ratio = torch.sum(proj ** 2, axis = -1) / (torch.sum(noise ** 2, axis = -1) + EPS) 202 | sisnr = 10 * torch.log10(ratio + EPS) 203 | 204 | return sisnr 205 | 206 | 207 | if __name__ == '__main__': 208 | datasets = dataset( 209 | mix_lst_path='/home/panzexu/datasets/LRS3/audio/2_mix_min/mixture_data_list_2mix.csv', 210 | audio_direc='/home/panzexu/datasets/LRS3/audio/Audio/', 211 | visual_direc='/home/panzexu/datasets/LRS3/lip/', 212 | mixture_direc='/home/panzexu/datasets/LRS3/audio/2_mix_min/', 213 | batch_size=8, 214 | partition='train') 215 | data_loader = data.DataLoader(datasets, 216 | batch_size = 1, 217 | shuffle= True, 218 | num_workers = 4) 219 | 220 | for a_mix, a_tgt, v_tgt, speakers in tqdm.tqdm(data_loader): 221 | # print(a_mix.squeeze().size()) 222 | # print(a_tgt.squeeze().size()) 223 | # print(v_tgt.squeeze().size()) 224 | pass 225 | 226 | # a = np.ones((24,512)) 227 | # print(a.shape) 228 | # a = np.pad(a, ((0,-1), (0,0)), 'edge') 229 | # print(a.shape) 230 | 231 | # a = np.random.rand(2,2,3) 232 | # print(a) 233 | # a = a.reshape(4,3) 234 | # print(a) -------------------------------------------------------------------------------- /src/solver.py: -------------------------------------------------------------------------------- 1 | import time 2 | from utils import * 3 | from apex import amp 4 | from apex.parallel import DistributedDataParallel as DDP 5 | from torch.utils.tensorboard import SummaryWriter 6 | import torch.distributed as dist 7 | 8 | class Solver(object): 9 | def __init__(self, train_data, validation_data, test_data, model, optimizer, args): 10 | self.train_data = train_data 11 | self.validation_data = validation_data 12 | self.test_data = test_data 13 | self.args = args 14 | self.amp = amp 15 | self.ae_loss = nn.CrossEntropyLoss() 16 | 17 | self.print = False 18 | if (self.args.distributed and self.args.local_rank ==0) or not self.args.distributed: 19 | self.print = True 20 | if self.args.use_tensorboard: 21 | self.writer = SummaryWriter('logs/%s/tensorboard/' % args.log_name) 22 | 23 | self.model, self.optimizer = self.amp.initialize(model, optimizer, 24 | opt_level=args.opt_level, 25 | patch_torch_functions=args.patch_torch_functions) 26 | 27 | if self.args.distributed: 28 | self.model = DDP(self.model) 29 | 30 | self._reset() 31 | 32 | def _reset(self): 33 | self.halving = False 34 | if self.args.continue_from: 35 | checkpoint = torch.load('logs/%s/model_dict.pt' % self.args.continue_from, map_location='cpu') 36 | 37 | self.model.load_state_dict(checkpoint['model']) 38 | self.optimizer.load_state_dict(checkpoint['optimizer']) 39 | self.amp.load_state_dict(checkpoint['amp']) 40 | 41 | self.start_epoch=checkpoint['epoch'] 42 | self.prev_val_loss = checkpoint['prev_val_loss'] 43 | self.best_val_loss = checkpoint['best_val_loss'] 44 | self.val_no_impv = checkpoint['val_no_impv'] 45 | 46 | if self.print: print("Resume training from epoch: 
{}".format(self.start_epoch)) 47 | 48 | else: 49 | self.prev_val_loss = float("inf") 50 | self.best_val_loss = float("inf") 51 | self.val_no_impv = 0 52 | self.start_epoch=1 53 | if self.print: print('Start new training') 54 | 55 | def train(self): 56 | for epoch in range(self.start_epoch, self.args.epochs+1): 57 | self.joint_loss_weight=epoch 58 | if self.args.distributed: self.args.train_sampler.set_epoch(epoch) 59 | # Train 60 | self.model.train() 61 | start = time.time() 62 | tr_loss,tr_loss_speaker = self._run_one_epoch(data_loader = self.train_data, state='train') 63 | reduced_tr_loss = self._reduce_tensor(tr_loss) 64 | reduced_tr_loss_speaker = self._reduce_tensor(tr_loss_speaker) 65 | 66 | if self.print: print('Train Summary | End of Epoch {0} | Time {1:.2f}s | ' 67 | 'Train Loss {2:.3f}'.format( 68 | epoch, time.time() - start, reduced_tr_loss)) 69 | 70 | # Validation 71 | self.model.eval() 72 | start = time.time() 73 | with torch.no_grad(): 74 | val_loss, val_loss_speaker = self._run_one_epoch(data_loader = self.validation_data, state='val') 75 | reduced_val_loss = self._reduce_tensor(val_loss) 76 | reduced_val_loss_speaker = self._reduce_tensor(val_loss_speaker) 77 | 78 | if self.print: print('Valid Summary | End of Epoch {0} | Time {1:.2f}s | ' 79 | 'Valid Loss {2:.3f}'.format( 80 | epoch, time.time() - start, reduced_val_loss)) 81 | 82 | # test 83 | self.model.eval() 84 | start = time.time() 85 | with torch.no_grad(): 86 | test_loss,_ = self._run_one_epoch(data_loader = self.test_data, state='test') 87 | reduced_test_loss = self._reduce_tensor(test_loss) 88 | 89 | if self.print: print('Test Summary | End of Epoch {0} | Time {1:.2f}s | ' 90 | 'Test Loss {2:.3f}'.format( 91 | epoch, time.time() - start, reduced_test_loss)) 92 | 93 | 94 | # Check whether to adjust learning rate and early stop 95 | if reduced_val_loss >= self.prev_val_loss: 96 | self.val_no_impv += 1 97 | if self.val_no_impv >= 3: 98 | self.halving = True 99 | if self.val_no_impv >= 6: 100 | if self.print: print("No imporvement for 6 epochs, early stopping.") 101 | break 102 | else: 103 | self.val_no_impv = 0 104 | 105 | # Halfing the learning rate 106 | if self.halving: 107 | optim_state = self.optimizer.state_dict() 108 | optim_state['param_groups'][0]['lr'] = optim_state['param_groups'][0]['lr']/2 109 | self.optimizer.load_state_dict(optim_state) 110 | if self.print: print('Learning rate adjusted to: {lr:.6f}'.format( 111 | lr=optim_state['param_groups'][0]['lr'])) 112 | self.halving = False 113 | self.prev_val_loss = reduced_val_loss 114 | 115 | if self.print: 116 | # Tensorboard logging 117 | if self.args.use_tensorboard: 118 | self.writer.add_scalar('Train loss', reduced_tr_loss, epoch) 119 | self.writer.add_scalar('Validation loss', reduced_val_loss, epoch) 120 | self.writer.add_scalar('Test loss', reduced_test_loss, epoch) 121 | self.writer.add_scalar('Validation loss speaker', reduced_val_loss_speaker, epoch) 122 | self.writer.add_scalar('Train loss speaker', reduced_tr_loss_speaker, epoch) 123 | 124 | # Save model 125 | if reduced_val_loss < self.best_val_loss: 126 | self.best_val_loss = reduced_val_loss 127 | checkpoint = {'model': self.model.state_dict(), 128 | 'optimizer': self.optimizer.state_dict(), 129 | 'amp': self.amp.state_dict(), 130 | 'epoch': epoch+1, 131 | 'prev_val_loss': self.prev_val_loss, 132 | 'best_val_loss': self.best_val_loss, 133 | 'val_no_impv': self.val_no_impv} 134 | torch.save(checkpoint, "logs/"+ self.args.log_name+"/model_dict.pt") 135 | print("Fund new best model, dict 
saved") 136 | 137 | 138 | def _run_one_epoch(self, data_loader, state): 139 | total_loss = 0 140 | total_loss_speaker = 0 141 | # total_acc_0=0 142 | # total_acc_1=0 143 | # total_acc_2=0 144 | # total_acc_3=0 145 | speaker_loss=0 146 | for i, (a_mix, a_tgt, v_tgt, speaker) in enumerate(data_loader): 147 | a_mix = a_mix.cuda().squeeze().float() 148 | a_tgt = a_tgt.cuda().squeeze().float() 149 | v_tgt = v_tgt.cuda().squeeze().float() 150 | speaker = speaker.cuda().squeeze() 151 | 152 | est_speaker, est_a_tgt = self.model(a_mix, v_tgt) 153 | max_snr = cal_SISNR(a_tgt, est_a_tgt) 154 | 155 | if state !='test': 156 | sisnr_loss = 0 - torch.mean(max_snr) 157 | speaker_loss = self.ae_loss(est_speaker[0], speaker) + \ 158 | self.ae_loss(est_speaker[1], speaker) + \ 159 | self.ae_loss(est_speaker[2], speaker) + \ 160 | self.ae_loss(est_speaker[3], speaker) 161 | loss = sisnr_loss + 0.1* speaker_loss #*np.power(0.96,self.joint_loss_weight-1) 162 | 163 | # total_acc_0 += self.cal_acc(est_speaker[0],speaker) 164 | # total_acc_1 += self.cal_acc(est_speaker[1],speaker) 165 | # total_acc_2 += self.cal_acc(est_speaker[2],speaker) 166 | # total_acc_3 += self.cal_acc(est_speaker[3],speaker) 167 | 168 | if state == 'train': 169 | self.optimizer.zero_grad() 170 | with self.amp.scale_loss(loss, self.optimizer) as scaled_loss: 171 | scaled_loss.backward() 172 | torch.nn.utils.clip_grad_norm_(self.amp.master_params(self.optimizer), 173 | self.args.max_norm) 174 | self.optimizer.step() 175 | 176 | # if state == 'val': 177 | # loss = sisnr_loss 178 | 179 | else: loss = 0 - torch.mean(max_snr[::self.args.C]) 180 | 181 | total_loss += loss.data 182 | total_loss_speaker += speaker_loss 183 | 184 | # print("speaker recognition acc: %s"%str(total_acc_0/(i+1)*100)) 185 | # print("speaker recognition acc: %s"%str(total_acc_1/(i+1)*100)) 186 | # print("speaker recognition acc: %s"%str(total_acc_2/(i+1)*100)) 187 | # print("speaker recognition acc: %s"%str(total_acc_3/(i+1)*100)) 188 | return total_loss / (i+1), total_loss_speaker/ (i+1) 189 | 190 | def _reduce_tensor(self, tensor): 191 | if not self.args.distributed: return tensor 192 | rt = tensor.clone() 193 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 194 | rt /= self.args.world_size 195 | return rt 196 | 197 | def cal_acc(self, output, target): 198 | pred = output.argmax(dim=1, keepdim=False) 199 | correct = 0 200 | total = 0 201 | for i in range(target.shape[0]): 202 | total += 1 203 | if (pred[i] == target[i]): 204 | correct += 1 205 | return correct/total 206 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from apex import amp 6 | import copy 7 | 8 | EPS = 1e-8 9 | 10 | def _clones(module, N): 11 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 12 | 13 | class muse(nn.Module): 14 | def __init__(self, N, L, B, H, P, X, R, C, M): 15 | super(muse, self).__init__() 16 | self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C 17 | 18 | self.encoder = Encoder(L, N) 19 | self.separator = TemporalConvNet(N, B, H, P, X, R, C, M) 20 | self.decoder = Decoder(N, L) 21 | 22 | for p in self.parameters(): 23 | if p.dim() > 1: 24 | nn.init.xavier_normal_(p) 25 | 26 | def forward(self, mixture, visual): 27 | mixture_w = self.encoder(mixture) 28 | est_a_emb, est_mask = self.separator(mixture_w, visual) 29 | 
est_source = self.decoder(mixture_w, est_mask) 30 | 31 | # T changed after conv1d in encoder, fix it here 32 | T_origin = mixture.size(-1) 33 | T_conv = est_source.size(-1) 34 | est_source = F.pad(est_source, (0, T_origin - T_conv)) 35 | return est_a_emb, est_source 36 | 37 | class Encoder(nn.Module): 38 | def __init__(self, L, N): 39 | super(Encoder, self).__init__() 40 | self.L, self.N = L, N 41 | self.conv1d_U = nn.Conv1d(1, N, kernel_size=L, stride=L // 2, bias=False) 42 | 43 | def forward(self, mixture): 44 | mixture = torch.unsqueeze(mixture, 1) # [M, 1, T] 45 | mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K] 46 | return mixture_w 47 | 48 | class Decoder(nn.Module): 49 | def __init__(self, N, L): 50 | super(Decoder, self).__init__() 51 | self.N, self.L = N, L 52 | self.basis_signals = nn.Linear(N, L, bias=False) 53 | 54 | def forward(self, mixture_w, est_mask): 55 | est_source = mixture_w * est_mask # [M, N, K] 56 | est_source = torch.transpose(est_source, 2, 1) # [M, K, N] 57 | est_source = self.basis_signals(est_source) # [M, K, L] 58 | est_source = overlap_and_add(est_source, self.L//2) # M x C x T 59 | return est_source 60 | 61 | 62 | class TemporalConvNet(nn.Module): 63 | def __init__(self, N, B, H, P, X, R, C, M): 64 | super(TemporalConvNet, self).__init__() 65 | self.C = C 66 | self.layer_norm = ChannelWiseLayerNorm(N) 67 | self.bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False) 68 | 69 | # Audio TCN 70 | tcn_blocks = [] 71 | tcn_blocks += [nn.Conv1d(B*3, B, 1, bias=False)] 72 | for x in range(X): 73 | dilation = 2**x 74 | padding = (P - 1) * dilation // 2 75 | tcn_blocks += [TemporalBlock(B, H, P, stride=1, 76 | padding=padding, 77 | dilation=dilation)] 78 | self.tcn = _clones(nn.Sequential(*tcn_blocks), R) 79 | 80 | # visual blocks 81 | ve_blocks = [] 82 | for x in range(5): 83 | ve_blocks +=[VisualConv1D()] 84 | self.visual_conv = nn.Sequential(*ve_blocks) 85 | 86 | # Audio and visual seprated layers before concatenation 87 | self.ve_conv1x1 = _clones(nn.Conv1d(512, B, 1, bias=False),R) 88 | self.ve_conv1x1_SE = _clones(nn.Conv1d(512, B, 1, bias=False),R) 89 | 90 | # speaker embedding extraction and classification 91 | self.se_net=_clones(SpeakerEmbedding(B), R) 92 | self.audio_linear=_clones(nn.Linear(B, M),R) 93 | 94 | # Mask generation layer 95 | self.mask_conv1x1 = nn.Conv1d(B, N, 1, bias=False) 96 | 97 | 98 | def forward(self, x, visual): 99 | visual = visual.transpose(1,2) 100 | visual = self.visual_conv(visual) 101 | 102 | x = self.layer_norm(x) 103 | x = self.bottleneck_conv1x1(x) 104 | 105 | mixture = x 106 | 107 | batch, B, K = x.size() 108 | 109 | est_a_emb=[] 110 | 111 | for i in range(len(self.tcn)): 112 | v = self.ve_conv1x1[i](visual) 113 | v = F.interpolate(v, (32*v.size()[-1]), mode='linear') 114 | v = F.pad(v,(0,K-v.size()[-1])) 115 | v_2 = self.ve_conv1x1_SE[i](visual) 116 | v_2 = F.interpolate(v_2, (32*v_2.size()[-1]), mode='linear') 117 | v_2 = F.pad(v_2,(0,K-v_2.size()[-1])) 118 | a = mixture*F.relu(x) 119 | a = self.se_net[i](torch.cat((a,v_2),1)) 120 | est_a_emb.append(self.audio_linear[i](a.squeeze())) 121 | a = torch.repeat_interleave(a, repeats=K, dim=2) 122 | x = torch.cat((a, x, v),1) 123 | x = self.tcn[i](x) 124 | 125 | x = self.mask_conv1x1(x) 126 | x = F.relu(x) 127 | est_a_emb = torch.stack(est_a_emb) 128 | return est_a_emb, x 129 | 130 | class SpeakerEmbedding(nn.Module): 131 | def __init__(self, B, R=3, H=256): 132 | super(SpeakerEmbedding, self).__init__() 133 | self.conv_proj = nn.Conv1d(B*2, B, 1, bias=False) 134 | 
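# conv_proj fuses the concatenated audio and visual streams (2*B channels) back to B
# channels; the R residual 1x1-conv blocks and pooling layers built below then squeeze
# the sequence into a single speaker embedding via the final adaptive average pooling.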
Conv_1=nn.Conv1d(B, H, 1, bias=False) 135 | norm_1=nn.BatchNorm1d(H) 136 | prelu_1=nn.PReLU() 137 | Conv_2=nn.Conv1d(H, B, 1, bias=False) 138 | norm_2=nn.BatchNorm1d(B) 139 | self.resnet=_clones(nn.Sequential(Conv_1, norm_1,\ 140 | prelu_1, Conv_2, norm_2), R) 141 | self.prelu=_clones(nn.PReLU(),R) 142 | self.maxPool=_clones(nn.AvgPool1d(3),R) 143 | 144 | self.conv=nn.Conv1d(B,B,1) 145 | self.avgPool=nn.AdaptiveAvgPool1d(1) 146 | 147 | def forward(self, x): 148 | x = self.conv_proj(x) 149 | for i in range(len(self.resnet)): 150 | residual = x 151 | x = self.resnet[i](x) 152 | x = self.prelu[i](x+residual) 153 | x = self.maxPool[i](x) 154 | 155 | x = self.conv(x) 156 | x = self.avgPool(x) 157 | 158 | return x 159 | 160 | 161 | 162 | class VisualConv1D(nn.Module): 163 | def __init__(self): 164 | super(VisualConv1D, self).__init__() 165 | relu = nn.ReLU() 166 | norm_1 = nn.BatchNorm1d(512) 167 | dsconv = nn.Conv1d(512, 512, 3, stride=1, padding=1,dilation=1, groups=512, bias=False) 168 | prelu = nn.PReLU() 169 | norm_2 = nn.BatchNorm1d(512) 170 | pw_conv = nn.Conv1d(512, 512, 1, bias=False) 171 | 172 | self.net = nn.Sequential(relu, norm_1 ,dsconv, prelu, norm_2, pw_conv) 173 | 174 | def forward(self, x): 175 | out = self.net(x) 176 | return out + x 177 | 178 | class TemporalBlock(nn.Module): 179 | def __init__(self, in_channels, out_channels, kernel_size, 180 | stride, padding, dilation): 181 | super(TemporalBlock, self).__init__() 182 | conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False) 183 | prelu = nn.PReLU() 184 | norm = GlobalLayerNorm(out_channels) 185 | dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, 186 | stride, padding, dilation) 187 | # Put together 188 | self.net = nn.Sequential(conv1x1, prelu, norm, dsconv) 189 | 190 | def forward(self, x): 191 | 192 | residual = x 193 | out = self.net(x) 194 | return out + residual # look like w/o F.relu is better than w/ F.relu 195 | 196 | 197 | class DepthwiseSeparableConv(nn.Module): 198 | def __init__(self, in_channels, out_channels, kernel_size, 199 | stride, padding, dilation): 200 | super(DepthwiseSeparableConv, self).__init__() 201 | depthwise_conv = nn.Conv1d(in_channels, in_channels, kernel_size, 202 | stride=stride, padding=padding, 203 | dilation=dilation, groups=in_channels, 204 | bias=False) 205 | 206 | prelu = nn.PReLU() 207 | norm = GlobalLayerNorm(in_channels) 208 | pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False) 209 | self.net = nn.Sequential(depthwise_conv, prelu, norm, 210 | pointwise_conv) 211 | 212 | def forward(self, x): 213 | return self.net(x) 214 | 215 | class ChannelWiseLayerNorm(nn.LayerNorm): 216 | @amp.float_function 217 | def __init__(self, *args, **kwargs): 218 | super(ChannelWiseLayerNorm, self).__init__(*args, **kwargs) 219 | 220 | @amp.float_function 221 | def forward(self, x): 222 | if x.dim() != 3: 223 | raise RuntimeError("{} accept 3D tensor as input".format( 224 | self.__name__)) 225 | # N x C x T => N x T x C 226 | x = torch.transpose(x, 1, 2) 227 | # LN 228 | x = super().forward(x) 229 | # N x C x T => N x T x C 230 | x = torch.transpose(x, 1, 2) 231 | return x 232 | 233 | 234 | class GlobalLayerNorm(nn.Module): 235 | """Global Layer Normalization (gLN)""" 236 | @amp.float_function 237 | def __init__(self, channel_size): 238 | super(GlobalLayerNorm, self).__init__() 239 | self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] 240 | self.beta = nn.Parameter(torch.Tensor(1, channel_size,1 )) # [1, N, 1] 241 | self.reset_parameters() 
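# gamma starts at ones and beta at zeros (see reset_parameters below), so gLN initially
# performs a plain mean/variance normalisation computed jointly over the channel and
# time dimensions of each utterance.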
242 | 243 | @amp.float_function 244 | def reset_parameters(self): 245 | self.gamma.data.fill_(1) 246 | self.beta.data.zero_() 247 | 248 | @amp.float_function 249 | def forward(self, y): 250 | """ 251 | Args: 252 | y: [M, N, K], M is batch size, N is channel size, K is length 253 | Returns: 254 | gLN_y: [M, N, K] 255 | """ 256 | # TODO: in torch 1.0, torch.mean() support dim list 257 | mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) #[M, 1, 1] 258 | var = (torch.pow(y-mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) 259 | gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta 260 | return gLN_y 261 | 262 | @amp.float_function 263 | def overlap_and_add(signal, frame_step): 264 | """Reconstructs a signal from a framed representation. 265 | Adds potentially overlapping frames of a signal with shape 266 | `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`. 267 | The resulting tensor has shape `[..., output_size]` where 268 | output_size = (frames - 1) * frame_step + frame_length 269 | Args: 270 | signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2. 271 | frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length. 272 | Returns: 273 | A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions. 274 | output_size = (frames - 1) * frame_step + frame_length 275 | Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py 276 | """ 277 | outer_dimensions = signal.size()[:-2] 278 | frames, frame_length = signal.size()[-2:] 279 | 280 | subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor 281 | subframe_step = frame_step // subframe_length 282 | subframes_per_frame = frame_length // subframe_length 283 | output_size = frame_step * (frames - 1) + frame_length 284 | output_subframes = output_size // subframe_length 285 | 286 | subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) 287 | 288 | frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step) 289 | frame = signal.new_tensor(frame).long() # signal may in GPU or CPU 290 | frame = frame.contiguous().view(-1) 291 | 292 | result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) 293 | result.index_add_(-2, frame, subframe_signal) 294 | result = result.view(*outer_dimensions, -1) 295 | return result 296 | -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from utils import * 4 | import os 5 | from aveNet import aveNet 6 | from mir_eval.separation import bss_eval_sources 7 | from pystoi import stoi 8 | from pesq import pesq 9 | 10 | MAX_INT16 = np.iinfo(np.int16).max 11 | 12 | def write_wav(fname, samps, sampling_rate=16000, normalize=True): 13 | """ 14 | Write wav files in int16, support single/multi-channel 15 | """ 16 | # for multi-channel, accept ndarray [Nsamples, Nchannels] 17 | samps = np.divide(samps, np.max(np.abs(samps))) 18 | 19 | # same as MATLAB and kaldi 20 | if normalize: 21 | samps = samps * MAX_INT16 22 | samps = samps.astype(np.int16) 23 | fdir = os.path.dirname(fname) 24 | if fdir and not os.path.exists(fdir): 25 | os.makedirs(fdir) 26 | # NOTE: librosa 0.6.0 seems could not write non-float narray 27 | # so use 
scipy.io.wavfile instead 28 | wavfile.write(fname, sampling_rate, samps) 29 | 30 | def SDR(est, egs, mix): 31 | ''' 32 | calculate SDR 33 | est: Network generated audio 34 | egs: Ground Truth 35 | ''' 36 | sdr, _, _, _ = bss_eval_sources(egs, est) 37 | mix_sdr, _, _, _ = bss_eval_sources(egs, mix) 38 | return float(sdr-mix_sdr) 39 | 40 | class dataset(data.Dataset): 41 | def __init__(self, 42 | mix_lst_path, 43 | audio_direc, 44 | visual_direc, 45 | mixture_direc, 46 | batch_size=1, 47 | partition='test', 48 | sampling_rate=16000, 49 | mix_no=2): 50 | 51 | self.minibatch =[] 52 | self.audio_direc = audio_direc 53 | self.visual_direc = visual_direc 54 | self.mixture_direc = mixture_direc 55 | self.sampling_rate = sampling_rate 56 | self.partition = partition 57 | self.C=mix_no 58 | 59 | mix_csv=open(mix_lst_path).read().splitlines() 60 | self.mix_lst=list(filter(lambda x: x.split(',')[0]==partition, mix_csv)) 61 | 62 | def __getitem__(self, index): 63 | line = self.mix_lst[index] 64 | 65 | mixture_path=self.mixture_direc+self.partition+'/'+ line.replace(',','_').replace('/', '_') +'.wav' 66 | _, mixture = wavfile.read(mixture_path) 67 | mixture = self._audio_norm(mixture) 68 | 69 | min_length = mixture.shape[0] 70 | 71 | line=line.split(',') 72 | c=0 73 | audio_path=self.audio_direc+line[c*4+1]+'/'+line[c*4+2]+'/'+line[c*4+3]+'.wav' 74 | _, audio = wavfile.read(audio_path) 75 | audio = self._audio_norm(audio[:min_length]) 76 | 77 | # read visual 78 | visual_path=self.visual_direc+line[c*4+1]+'/'+line[c*4+2]+'/'+line[c*4+3]+'.npy' 79 | visual = np.load(visual_path) 80 | length = math.floor(min_length/self.sampling_rate*25) 81 | visual = visual[:length,...] 82 | a = visual.shape[0] 83 | if visual.shape[0] < length: 84 | visual = np.pad(visual, ((0,int(length - visual.shape[0])),(0,0)), mode = 'edge') 85 | 86 | embedding = 0 87 | 88 | return mixture, audio, visual, embedding, (line[c*4+2]+'/'+line[c*4+3]) 89 | 90 | def __len__(self): 91 | return len(self.mix_lst) 92 | 93 | def _audio_norm(self,audio): 94 | return np.divide(audio, np.max(np.abs(audio))) 95 | 96 | def main(args): 97 | # Model 98 | model = aveNet(args.N, args.L, args.B, args.H, args.P, args.X, args.R, 99 | args.C, 800) 100 | 101 | model = model.cuda() 102 | pretrained_model = torch.load('%smodel_dict.pt' % args.continue_from, map_location='cpu')['model'] 103 | 104 | state = model.state_dict() 105 | for key in state.keys(): 106 | pretrain_key = 'module.' 
107 |         if pretrain_key in pretrained_model.keys():
108 |             state[key] = pretrained_model[pretrain_key]
109 |     model.load_state_dict(state)
110 | 
111 |     datasets = dataset(
112 |                 mix_lst_path=args.mix_lst_path,
113 |                 audio_direc=args.audio_direc,
114 |                 visual_direc=args.visual_direc,
115 |                 mixture_direc=args.mixture_direc,
116 |                 mix_no=args.C)
117 | 
118 |     test_generator = data.DataLoader(datasets,
119 |                 batch_size = 1,
120 |                 shuffle = False,
121 |                 num_workers = args.num_workers)
122 | 
123 | 
124 |     model.eval()
125 |     with torch.no_grad():
126 |         avg_sisnri = 0
127 |         avg_sdri = 0
128 |         avg_pesqi = 0
129 |         avg_stoii = 0
130 |         for i, (a_mix, a_tgt, v_tgt, a_emb, fname) in enumerate(tqdm.tqdm(test_generator)):
131 |             a_mix = a_mix.cuda().squeeze().float().unsqueeze(0)
132 |             a_tgt = a_tgt.cuda().squeeze().float().unsqueeze(0)
133 |             v_tgt = v_tgt.cuda().squeeze().float().unsqueeze(0)
134 |             a_emb = a_emb.cuda().squeeze()
135 | 
136 |             est_speaker, estimate_source = model(a_mix, v_tgt)
137 | 
138 |             sisnr_mix = cal_SISNR(a_tgt, a_mix)
139 |             sisnr_est = cal_SISNR(a_tgt, estimate_source)
140 |             sisnri = sisnr_est - sisnr_mix
141 |             avg_sisnri += sisnri
142 |             print(sisnri)
143 | 
144 |             estimate_source = estimate_source.squeeze().cpu().numpy()
145 |             a_tgt = a_tgt.squeeze().cpu().numpy()
146 |             a_mix = a_mix.squeeze().cpu().numpy()
147 | 
148 | 
149 |             # save_path = "/home/panzexu/samples/avaNet1/" + fname[0] +'.wav'
150 |             # write_wav(save_path, estimate_source)
151 | 
152 |             save_path = "/home/panzexu/samples/avaNet1_emb/test/" + fname[0] +'.npy'
153 |             # print(save_path)
154 |             if not os.path.exists(save_path.rsplit('/', 1)[0]):
155 |                 os.makedirs(save_path.rsplit('/', 1)[0])
156 |             np.save(save_path, est_speaker.cpu().numpy())
157 | 
158 |             # avg_sdri += SDR(estimate_source, a_tgt, a_mix)
159 |             # # print(SDR(estimate_source, a_tgt, a_mix))
160 |             # avg_pesqi += (pesq(16000, a_tgt, estimate_source, 'wb') - pesq(16000, a_tgt, a_mix, 'wb'))
161 |             # # print(pesq(16000, a_tgt, estimate_source, 'wb') - pesq(16000, a_tgt, a_mix, 'wb'))
162 |             # avg_stoii += (stoi(a_tgt, estimate_source, 16000, extended=False) - stoi(a_tgt, a_mix, 16000, extended=False))
163 |             # # print((stoi(a_tgt, estimate_source, 16000, extended=False) - stoi(a_tgt, a_mix, 16000, extended=False)))
164 |             # # if sisnri < 10 and sisnr_est < 10:
165 |             # #     wavfile.write('/home/panzexu/listen/ava1/%s_tgt.wav' %i, 16000, a_tgt)
166 |             # #     wavfile.write('/home/panzexu/listen/ava1/%s_mix_%s.wav' %(i, sisnr_mix.item()), 16000, a_mix)
167 |             # #     wavfile.write('/home/panzexu/listen/ava1/%s_est_%s.wav' %(i, sisnr_est.item()), 16000, np.divide(estimate_source, np.max(np.abs(estimate_source))))
168 |             # # # if i >=10:
169 |             # #     break
170 | 
171 |         avg_sisnri = avg_sisnri / (i+1)
172 |         avg_sdri = avg_sdri / (i+1)  # SDRi/PESQi/STOIi stay at 0 unless the commented-out metric lines above are enabled
173 |         avg_pesqi = avg_pesqi / (i+1)
174 |         avg_stoii = avg_stoii / (i+1)
175 |         print('Avg SISNRi:', avg_sisnri)
176 |         print('Avg SDRi:', avg_sdri)
177 |         print('Avg PESQi:', avg_pesqi)
178 |         print('Avg STOIi:', avg_stoii)
179 | 
180 | if __name__ == '__main__':
181 |     parser = argparse.ArgumentParser("avConv-tasnet")
182 | 
183 |     # Dataloader
184 |     # parser.add_argument('--mix_lst_path', type=str, default='/home/panzexu/datasets/voxceleb2/audio_mixture/2_mix_min_800/mixture_data_list_2mix.csv',
185 |     #                     help='directory including train data')
186 |     # parser.add_argument('--audio_direc', type=str, default='/home/panzexu/datasets/voxceleb2/audio_clean/',
187 |     #                     help='directory including validation data')
188 |     # parser.add_argument('--visual_direc', type=str, default='/home/panzexu/datasets/voxceleb2/visual_embedding/lip/',
189 |     #                     help='directory including test data')
190 |     # parser.add_argument('--mixture_direc', type=str, default='/home/panzexu/datasets/voxceleb2/audio_mixture/2_mix_min_800/',
191 |     #                     help='directory of audio')
192 |     # parser.add_argument('--mix_lst_path', type=str, default='/home/panzexu/datasets/LRS3/audio_mixture/2_mix_min/mixture_data_list_2mix.csv',
193 |     #                     help='directory including train data')
194 |     # parser.add_argument('--audio_direc', type=str, default='/home/panzexu/datasets/LRS3/audio_clean/',
195 |     #                     help='directory including validation data')
196 |     # parser.add_argument('--visual_direc', type=str, default='/home/panzexu/datasets/LRS3/visual_embedding/lip/',
197 |     #                     help='directory including test data')
198 |     # parser.add_argument('--mixture_direc', type=str, default='/home/panzexu/datasets/LRS3/audio_mixture/2_mix_min/',
199 |     #                     help='directory of audio')
200 |     # # # grid
201 |     # TCDTIMITDataset
202 |     # LRS2
203 |     # LRS3
204 |     # avspeech_subset
205 | 
206 |     # Log and Visualization
207 |     # parser.add_argument('--continue_from', type=str,
208 |     #     default='/home/panzexu/workspace/avss_speaker_embedding/log/voxceleb_800/2_mix/avaNet_2020-09-29(19:27:48)_wo/')
209 |     # parser.add_argument('--continue_from', type=str,
210 |     #     default='/home/panzexu/workspace/avss_speaker_embedding/src/avaNet1/logs/avaNet_2020-09-29(19:27:48)/')
211 | 
212 |     # parser.add_argument('--mix_lst_path', type=str, default='/home/panzexu/datasets/voxceleb2/audio_mixture/3_mix_min_800/mixture_data_list_3mix.csv',
213 |     #                     help='directory including train data')
214 |     # parser.add_argument('--audio_direc', type=str, default='/home/panzexu/datasets/voxceleb2/audio_clean/',
215 |     #                     help='directory including validation data')
216 |     # parser.add_argument('--visual_direc', type=str, default='/home/panzexu/datasets/voxceleb2/visual_embedding/lip/',
217 |     #                     help='directory including test data')
218 |     # parser.add_argument('--mixture_direc', type=str, default='/home/panzexu/datasets/voxceleb2/audio_mixture/3_mix_min_800/',
219 |     #                     help='directory of audio')
220 |     parser.add_argument('--mix_lst_path', type=str, default='/home/panzexu/datasets/LRS3/audio_mixture/3_mix_min/mixture_data_list_3mix.csv',
221 |                         help='path to the mixture data list (CSV)')
222 |     parser.add_argument('--audio_direc', type=str, default='/home/panzexu/datasets/LRS3/audio_clean/',
223 |                         help='directory of clean target audio')
224 |     parser.add_argument('--visual_direc', type=str, default='/home/panzexu/datasets/LRS3/visual_embedding/lip/',
225 |                         help='directory of lip (visual) embeddings')
226 |     parser.add_argument('--mixture_direc', type=str, default='/home/panzexu/datasets/LRS3/audio_mixture/3_mix_min/',
227 |                         help='directory of mixture audio')
228 | 
229 |     parser.add_argument('--continue_from', type=str,
230 |         default='/home/panzexu/workspace/avss_speaker_embedding/log/voxceleb_800/3_mix/avaNet_2020-10-04(16:49:32)/')
231 | 
232 | 
233 |     # Training
234 |     parser.add_argument('--num_workers', default=4, type=int,
235 |                         help='Number of dataloader workers used to generate minibatches')
236 | 
237 |     # Model hyperparameters
238 |     parser.add_argument('--L', default=40, type=int,
239 |                         help='Length of the filters in samples (40 = 2.5 ms at 16 kHz)')
240 |     parser.add_argument('--N', default=256, type=int,
241 |                         help='Number of filters in autoencoder')
242 |     parser.add_argument('--B', default=256, type=int,
243 |                         help='Number of channels in bottleneck 1 × 1-conv block')
244 |     parser.add_argument('--C', type=int, default=2,
245 |                         help='number of speakers to mix')
246 |     parser.add_argument('--H', default=512, type=int,
247 |                         help='Number of channels in convolutional blocks')
248 |     parser.add_argument('--P', default=3, type=int,
249 |                         help='Kernel size in convolutional blocks')
250 |     parser.add_argument('--X', default=8, type=int,
251 |                         help='Number of convolutional blocks in each repeat')
252 |     parser.add_argument('--R', default=4, type=int,
253 |                         help='Number of repeats')
254 | 
255 |     args = parser.parse_args()
256 | 
257 |     main(args)
--------------------------------------------------------------------------------
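cal_SISNR is imported from utils.py, which is not shown in this listing. As a reference, here is a minimal sketch of the standard scale-invariant SI-SNR it is presumed to compute; this is an illustrative implementation, not the repository's own utils.cal_SISNR, which may differ in argument order and details:

    import torch

    def si_snr(estimate, target, eps=1e-8):
        # zero-mean both signals, then project the estimate onto the target
        estimate = estimate - estimate.mean(dim=-1, keepdim=True)
        target = target - target.mean(dim=-1, keepdim=True)
        s_target = (torch.sum(estimate * target, dim=-1, keepdim=True)
                    / (torch.sum(target ** 2, dim=-1, keepdim=True) + eps)) * target
        e_noise = estimate - s_target
        # SI-SNR in dB
        return 10 * torch.log10(torch.sum(s_target ** 2, dim=-1)
                                / (torch.sum(e_noise ** 2, dim=-1) + eps))

The SI-SNR improvement reported by evaluate.py is the per-utterance difference between the SI-SNR of the model estimate and the SI-SNR of the unprocessed mixture, both measured against the clean target, averaged over the test set.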