├── .gitignore ├── README.md ├── data └── LRS2 │ ├── 1_create_mixture_list.py │ ├── 2_create_mixture.py │ ├── mixture_data_list_2mix.csv │ └── preprocess.sh └── src ├── .gitignore └── av-dprnn ├── avDprnn.py ├── main.py ├── run.sh ├── solver.py ├── stft_loss.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | dict/ 2 | ignore/ 3 | log/**/*.pt 4 | **/__pycache__/ 5 | **/hlt.txt 6 | **/slurm* 7 | **/tmp_list*.npy 8 | **/*.tar 9 | **/score 10 | **/egs1/ 11 | **/.DS_Store 12 | **/._* 13 | 14 | deep_avsr_weights/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## The Hybrid Continuity Loss for Target Speaker Extraction 2 | 3 | A PyTorch implementation of the [A Hybrid Continuity Loss Function to Reduce Over-Suppression for Time-domain Target Speaker Extraction](https://arxiv.org/abs/2203.16843) 4 | 5 | ## Project Structure 6 | 7 | `/data`: Scripts to pre-process the LRS2 dataset. 8 | 9 | `/src`: The training scripts of the network. 10 | 11 | -------------------------------------------------------------------------------- /data/LRS2/1_create_mixture_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import csv 5 | import tqdm 6 | import scipy.io.wavfile as wavfile 7 | 8 | np.random.seed(0) 9 | 10 | def extract_wav_from_mp4(line): 11 | # Extract .wav file from mp4 12 | video_from_path=args.data_direc + line[0]+'/'+line[1]+'/'+line[2]+'.mp4' 13 | audio_save_direc=args.audio_data_direc + line[0]+'/'+line[1]+'/' 14 | if not os.path.exists(audio_save_direc): 15 | os.makedirs(audio_save_direc) 16 | audio_save_path=audio_save_direc+line[2]+'.wav' 17 | if not os.path.exists(audio_save_path): 18 | os.system("ffmpeg -i %s %s"%(video_from_path, audio_save_path)) 19 | 20 | sr, audio = wavfile.read(audio_save_path) 21 | assert sr==args.sampling_rate , "sampling_rate mismatch" 22 | sample_length = audio.shape[0]/args.sampling_rate 23 | return sample_length # In seconds 24 | 25 | 26 | def main(args): 27 | # read the datalist and separate into train, val and test set 28 | train_list=[] 29 | val_list=[] 30 | test_list=[] 31 | 32 | test_data_list = open(args.test_list).read().splitlines() 33 | for line in tqdm.tqdm(test_data_list, desc='Processing Test List'): 34 | line = line.split(' ') 35 | line = line[0].split('/') 36 | ln = ('main',line[0],line[1]) 37 | sample_length = extract_wav_from_mp4(ln) 38 | test_list.append(('main',line[0],line[1], sample_length)) 39 | 40 | val_data_list = open(args.val_list).read().splitlines() 41 | for line in tqdm.tqdm(val_data_list, desc='Processing Validation List'): 42 | line = line.split('/') 43 | ln = ('main',line[0],line[1]) 44 | sample_length = extract_wav_from_mp4(ln) 45 | val_list.append(('main',line[0],line[1], sample_length)) 46 | 47 | train_data_list = open(args.train_list).read().splitlines() 48 | for line in tqdm.tqdm(train_data_list, desc='Processing Train List'): 49 | line = line.split('/') 50 | ln = ('main',line[0],line[1]) 51 | sample_length = extract_wav_from_mp4(ln) 52 | ln=('main',line[0],line[1], sample_length) 53 | train_list.append(ln) 54 | 55 | # Create mixture list 56 | f=open(args.mixture_data_list,'w') 57 | w=csv.writer(f) 58 | create_mixture_list(args, 'test', args.test_samples, test_list, w) 59 | create_mixture_list(args, 'val', args.val_samples, val_list, w) 60 | create_mixture_list(args, 'train', args.train_samples, train_list, w) 61 | f.close() 62 | 63 | 64 | def create_mixture_list(args, data, length, data_list, w): 65 | # data_list = sorted(data_list, key=lambda data: data[3], reverse=True) 66 | for _ in range(length): 67 | mixtures=[data] 68 | cache = [] 69 | 70 | # target speaker 71 | idx = np.random.randint(0, len(data_list)) 72 | cache.append(idx) 73 | mixtures = mixtures + list(data_list[idx]) 74 | shortest = mixtures[-1] 75 | del mixtures[-1] 76 | mixtures.append(0) 77 | 78 | while len(cache) < args.C: 79 | idx = np.random.randint(0, len(data_list)) 80 | if idx in cache: 81 | continue 82 | cache.append(idx) 83 | mixtures = mixtures + list(data_list[idx]) 84 | del mixtures[-1] 85 | db_ratio = np.random.uniform(-args.mix_db,args.mix_db) 86 | mixtures.append(db_ratio) 87 | mixtures.append(shortest) 88 | w.writerow(mixtures) 89 | 90 | if __name__ == '__main__': 91 | parser = argparse.ArgumentParser(description='LRS2 dataset') 92 | parser.add_argument('--data_direc', type=str) 93 | parser.add_argument('--pretrain_list', type=str) 94 | parser.add_argument('--train_list', type=str) 95 | parser.add_argument('--val_list', type=str) 96 | parser.add_argument('--test_list', type=str) 97 | parser.add_argument('--C', type=int) 98 | parser.add_argument('--mix_db', type=float) 99 | parser.add_argument('--train_samples', type=int) 100 | parser.add_argument('--val_samples', type=int) 101 | parser.add_argument('--test_samples', type=int) 102 | parser.add_argument('--audio_data_direc', type=str) 103 | parser.add_argument('--sampling_rate', type=int) 104 | parser.add_argument('--mixture_data_list', type=str) 105 | args = parser.parse_args() 106 | 107 | main(args) -------------------------------------------------------------------------------- /data/LRS2/2_create_mixture.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import tqdm 5 | import scipy.io.wavfile as wavfile 6 | 7 | MAX_INT16 = np.iinfo(np.int16).max 8 | 9 | def write_wav(fname, samps, sampling_rate=16000, normalize=True): 10 | """ 11 | Write wav files in int16, support single/multi-channel 12 | """ 13 | # for multi-channel, accept ndarray [Nsamples, Nchannels] 14 | if samps.ndim != 1 and samps.shape[0] < samps.shape[1]: 15 | samps = np.transpose(samps) 16 | samps = np.squeeze(samps) 17 | # same as MATLAB and kaldi 18 | if normalize: 19 | samps = samps * MAX_INT16 20 | samps = samps.astype(np.int16) 21 | fdir = os.path.dirname(fname) 22 | if fdir and not os.path.exists(fdir): 23 | os.makedirs(fdir) 24 | # NOTE: librosa 0.6.0 seems could not write non-float narray 25 | # so use scipy.io.wavfile instead 26 | wavfile.write(fname, sampling_rate, samps) 27 | 28 | 29 | def read_wav(fname, normalize=True): 30 | """ 31 | Read wave files using scipy.io.wavfile(support multi-channel) 32 | """ 33 | # samps_int16: N x C or N 34 | # N: number of samples 35 | # C: number of channels 36 | sampling_rate, samps_int16 = wavfile.read(fname) 37 | # N x C => C x N 38 | samps = samps_int16.astype(np.float) 39 | # tranpose because I used to put channel axis first 40 | if samps.ndim != 1: 41 | samps = np.transpose(samps) 42 | # normalize like MATLAB and librosa 43 | if normalize: 44 | samps = samps / MAX_INT16 45 | return sampling_rate, samps 46 | 47 | def main(args): 48 | # create mixture 49 | mixture_data_list = open(args.mixture_data_list).read().splitlines() 50 | print(len(mixture_data_list)) 51 | for line in tqdm.tqdm(mixture_data_list,desc = "Generating audio mixtures"): 52 | data = line.split(',') 53 | save_direc=args.mixture_audio_direc+data[0]+'/' 54 | if not os.path.exists(save_direc): 55 | os.makedirs(save_direc) 56 | 57 | mixture_save_path=save_direc+line.replace(',','_')+'.wav' 58 | if os.path.exists(mixture_save_path): 59 | continue 60 | 61 | # read target audio 62 | _, audio_mix=read_wav(args.audio_data_direc+data[1]+'/'+data[2]+'/'+data[3]+'.wav') 63 | target_power = np.linalg.norm(audio_mix, 2)**2 / audio_mix.size 64 | 65 | # read inteference audio 66 | for c in range(1, args.C): 67 | audio_path=args.audio_data_direc+data[c*4+1]+'/'+data[c*4+2]+'/'+data[c*4+3]+'.wav' 68 | _, audio = read_wav(audio_path) 69 | intef_power = np.linalg.norm(audio, 2)**2 / audio.size 70 | 71 | # audio = audio_norm(audio) 72 | scalar = (10**(float(data[c*4+4])/20)) * np.sqrt(target_power/intef_power) 73 | audio = audio * scalar 74 | 75 | # truncate long audio with short audio in the mixture 76 | if audio_mix.shape[0] > audio.shape[0]: 77 | audio_mix[:audio.shape[0]] = audio_mix[:audio.shape[0]] + audio 78 | else: audio_mix = audio_mix + audio[:audio_mix.shape[0]] 79 | 80 | audio_mix = np.divide(audio_mix, np.max(np.abs(audio_mix))) 81 | write_wav(mixture_save_path, audio_mix) 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser(description='LRS2 dataset') 86 | parser.add_argument('--C', type=int) 87 | parser.add_argument('--audio_data_direc', type=str) 88 | parser.add_argument('--mixture_audio_direc', type=str) 89 | parser.add_argument('--mixture_data_list', type=str) 90 | args = parser.parse_args() 91 | main(args) -------------------------------------------------------------------------------- /data/LRS2/preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | direc=/home/panzexu/datasets/LRS2/ 4 | 5 | data_direc=${direc}mvlrs_v1/ 6 | pretrain_list=${data_direc}pretrain_list.txt 7 | train_list=${data_direc}train_list.txt 8 | val_list=${data_direc}val_list.txt 9 | test_list=${data_direc}test_list.txt 10 | 11 | train_samples=200000 # no. of train mixture samples simulated 12 | val_samples=5000 # no. of validation mixture samples simulated 13 | test_samples=3000 # no. of test mixture samples simulated 14 | C=2 # no. of speakers in the mixture 15 | mix_db=10 # random db ratio from -10 to 10db 16 | mixture_data_list=mixture_data_list_${C}mix.csv #mixture datalist 17 | sampling_rate=16000 # audio sampling rate 18 | 19 | audio_data_direc=${direc}audio_clean/ # Target audio saved directory 20 | mixture_audio_direc=${direc}audio_mixture/${C}_mix_min_asr/ # Audio mixture saved directory 21 | 22 | #stage 1: create mixture list 23 | echo 'stage 1: create mixture list' 24 | python 1_create_mixture_list.py \ 25 | --data_direc $data_direc \ 26 | --pretrain_list $pretrain_list \ 27 | --train_list $train_list \ 28 | --val_list $val_list \ 29 | --test_list $test_list \ 30 | --C $C \ 31 | --mix_db $mix_db \ 32 | --train_samples $train_samples \ 33 | --val_samples $val_samples \ 34 | --test_samples $test_samples \ 35 | --audio_data_direc $audio_data_direc \ 36 | --sampling_rate $sampling_rate \ 37 | --mixture_data_list $mixture_data_list \ 38 | 39 | # stage 2: create audio mixture from list 40 | echo 'stage 2: create mixture audios' 41 | python 2_create_mixture.py \ 42 | --C $C \ 43 | --audio_data_direc $audio_data_direc \ 44 | --mixture_audio_direc $mixture_audio_direc \ 45 | --mixture_data_list $mixture_data_list \ 46 | 47 | -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | 133 | deep_avsr_weights/ -------------------------------------------------------------------------------- /src/av-dprnn/avDprnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from apex import amp 6 | 7 | EPS = 1e-8 8 | 9 | class avDprnn(nn.Module): 10 | def __init__(self, N, L, B, H, K, R, C): 11 | 12 | super(avDprnn, self).__init__() 13 | self.N, self.L, self.B, self.H, self.K, self.R, self.C = N, L, B, H, K, R, C 14 | 15 | self.encoder = Encoder(L, N) 16 | self.separator = rnn(N, B, H, K, R, C) 17 | self.decoder = Decoder(N, L) 18 | 19 | for p in self.parameters(): 20 | if p.dim() > 1: 21 | nn.init.xavier_normal_(p) 22 | 23 | def forward(self, mixture, visual): 24 | """ 25 | Args: 26 | mixture: [M, T], M is batch size, T is #samples 27 | Returns: 28 | est_source: [M, C, T] 29 | """ 30 | mixture_w = self.encoder(mixture) 31 | 32 | # # pad to *8 for tensor core 33 | # T_mix=mixture_w.size(-1) 34 | # res = T_mix%8 35 | # if res != 0: 36 | # mixture_w=F.pad(mixture_w,(0,8-res)) 37 | 38 | est_mask = self.separator(mixture_w, visual) 39 | est_source = self.decoder(mixture_w, est_mask) 40 | 41 | # T changed after conv1d in encoder, fix it here 42 | T_origin = mixture.size(-1) 43 | T_conv = est_source.size(-1) 44 | est_source = F.pad(est_source, (0, T_origin - T_conv)) 45 | return est_source 46 | 47 | class Encoder(nn.Module): 48 | def __init__(self, L, N): 49 | super(Encoder, self).__init__() 50 | self.L, self.N = L, N 51 | self.conv1d_U = nn.Conv1d(1, N, kernel_size=L, stride=L // 2, bias=False) 52 | 53 | def forward(self, mixture): 54 | """ 55 | Args: 56 | mixture: [M, T], M is batch size, T is #samples 57 | Returns: 58 | mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1 59 | """ 60 | mixture = torch.unsqueeze(mixture, 1) # [M, 1, T] 61 | mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K] 62 | return mixture_w 63 | 64 | 65 | class Decoder(nn.Module): 66 | def __init__(self, N, L): 67 | super(Decoder, self).__init__() 68 | self.N, self.L = N, L 69 | self.basis_signals = nn.Linear(N, L, bias=False) 70 | 71 | def forward(self, mixture_w, est_mask): 72 | """ 73 | Args: 74 | mixture_w: [M, N, K] 75 | est_mask: [M, C, N, K] 76 | Returns: 77 | est_source: [M, C, T] 78 | """ 79 | est_source = mixture_w * est_mask # [M, N, K] 80 | est_source = torch.transpose(est_source, 2, 1) # [M, K, N] 81 | est_source = self.basis_signals(est_source) # [M, K, L] 82 | est_source = overlap_and_add(est_source, self.L//2) # M x C x T 83 | return est_source 84 | 85 | class Dual_RNN_Block(nn.Module): 86 | ''' 87 | Implementation of the intra-RNN and the inter-RNN 88 | input: 89 | in_channels: The number of expected features in the input x 90 | out_channels: The number of features in the hidden state h 91 | rnn_type: RNN, LSTM, GRU 92 | norm: gln = "Global Norm", cln = "Cumulative Norm", ln = "Layer Norm" 93 | dropout: If non-zero, introduces a Dropout layer on the outputs 94 | of each LSTM layer except the last layer, 95 | with dropout probability equal to dropout. Default: 0 96 | bidirectional: If True, becomes a bidirectional LSTM. Default: False 97 | ''' 98 | 99 | def __init__(self, out_channels, 100 | hidden_channels, rnn_type='LSTM', 101 | dropout=0, bidirectional=False, num_spks=2): 102 | super(Dual_RNN_Block, self).__init__() 103 | # RNN model 104 | self.intra_rnn = getattr(nn, rnn_type)( 105 | out_channels, hidden_channels, 1, batch_first=True, dropout=dropout, bidirectional=bidirectional) 106 | self.inter_rnn = getattr(nn, rnn_type)( 107 | out_channels, hidden_channels, 1, batch_first=True, dropout=dropout, bidirectional=bidirectional) 108 | # Norm 109 | self.intra_norm = nn.GroupNorm(1, out_channels, eps=1e-8) 110 | self.inter_norm = nn.GroupNorm(1, out_channels, eps=1e-8) 111 | # Linear 112 | self.intra_linear = nn.Linear( 113 | hidden_channels*2 if bidirectional else hidden_channels, out_channels) 114 | self.inter_linear = nn.Linear( 115 | hidden_channels*2 if bidirectional else hidden_channels, out_channels) 116 | 117 | 118 | def forward(self, x): 119 | ''' 120 | x: [B, N, K, S] 121 | out: [Spks, B, N, K, S] 122 | ''' 123 | B, N, K, S = x.shape 124 | # intra RNN 125 | # [BS, K, N] 126 | intra_rnn = x.permute(0, 3, 2, 1).contiguous().view(B*S, K, N) 127 | # [BS, K, H] 128 | intra_rnn, _ = self.intra_rnn(intra_rnn) 129 | # [BS, K, N] 130 | intra_rnn = self.intra_linear(intra_rnn.contiguous().view(B*S*K, -1)).view(B*S, K, -1) 131 | # [B, S, K, N] 132 | intra_rnn = intra_rnn.view(B, S, K, N) 133 | # [B, N, K, S] 134 | intra_rnn = intra_rnn.permute(0, 3, 2, 1).contiguous() 135 | intra_rnn = self.intra_norm(intra_rnn) 136 | 137 | # [B, N, K, S] 138 | intra_rnn = intra_rnn + x 139 | 140 | # inter RNN 141 | # [BK, S, N] 142 | inter_rnn = intra_rnn.permute(0, 2, 3, 1).contiguous().view(B*K, S, N) 143 | # [BK, S, H] 144 | inter_rnn, _ = self.inter_rnn(inter_rnn) 145 | # [BK, S, N] 146 | inter_rnn = self.inter_linear(inter_rnn.contiguous().view(B*S*K, -1)).view(B*K, S, -1) 147 | # [B, K, S, N] 148 | inter_rnn = inter_rnn.view(B, K, S, N) 149 | # [B, N, K, S] 150 | inter_rnn = inter_rnn.permute(0, 3, 1, 2).contiguous() 151 | inter_rnn = self.inter_norm(inter_rnn) 152 | # [B, N, K, S] 153 | out = inter_rnn + intra_rnn 154 | 155 | return out 156 | 157 | class rnn(nn.Module): 158 | def __init__(self, N, B, H, K, R, C): 159 | super(rnn, self).__init__() 160 | self.C ,self.K , self.R = C, K, R 161 | # [M, N, K] -> [M, N, K] 162 | self.layer_norm = nn.GroupNorm(1, N, eps=1e-8) 163 | # [M, N, K] -> [M, B, K] 164 | self.bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False) 165 | 166 | self.dual_rnn = nn.ModuleList([]) 167 | for i in range(R): 168 | self.dual_rnn.append(Dual_RNN_Block(B, H, 169 | rnn_type='LSTM', dropout=0, 170 | bidirectional=True)) 171 | 172 | self.prelu = nn.PReLU() 173 | self.mask_conv1x1 = nn.Conv1d(B, N, 1, bias=False) 174 | 175 | 176 | # visual 177 | stacks = [] 178 | for x in range(5): 179 | stacks +=[VisualConv1D(V=256, H=512)] 180 | self.visual_conv = nn.Sequential(*stacks) 181 | self.v_ds = nn.Conv1d(512, 256, 1, bias=False) 182 | self.av_conv = nn.Conv1d(B+256, B, 1, bias=False) 183 | 184 | 185 | def forward(self, x, visual): 186 | """ 187 | Keep this API same with TasNet 188 | Args: 189 | mixture_w: [M, N, K], M is batch size 190 | returns: 191 | est_mask: [M, C, N, K] 192 | """ 193 | M, N, D = x.size() 194 | 195 | visual = visual.transpose(1,2) 196 | visual = self.v_ds(visual) 197 | visual = self.visual_conv(visual) 198 | visual = F.interpolate(visual, (D), mode='linear') 199 | 200 | x = self.layer_norm(x) # [M, N, K] 201 | x = self.bottleneck_conv1x1(x) # [M, B, K] 202 | 203 | x = torch.cat((x, visual),1) 204 | x = self.av_conv(x) 205 | 206 | x, gap = self._Segmentation(x, self.K) # [M, B, k, S] 207 | 208 | for i in range(self.R): 209 | x = self.dual_rnn[i](x) 210 | 211 | x = self._over_add(x, gap) 212 | 213 | x = self.prelu(x) 214 | x = self.mask_conv1x1(x) 215 | 216 | x = x.view(M, N, D) # [M, C*N, K] -> [M, C, N, K] 217 | x = F.relu(x) 218 | return x 219 | 220 | def _padding(self, input, K): 221 | ''' 222 | padding the audio times 223 | K: chunks of length 224 | P: hop size 225 | input: [B, N, L] 226 | ''' 227 | B, N, L = input.shape 228 | P = K // 2 229 | gap = K - (P + L % K) % K 230 | if gap > 0: 231 | pad = torch.Tensor(torch.zeros(B, N, gap)).type(input.type()) 232 | input = torch.cat([input, pad], dim=2) 233 | 234 | _pad = torch.Tensor(torch.zeros(B, N, P)).type(input.type()) 235 | input = torch.cat([_pad, input, _pad], dim=2) 236 | 237 | return input, gap 238 | 239 | def _Segmentation(self, input, K): 240 | ''' 241 | the segmentation stage splits 242 | K: chunks of length 243 | P: hop size 244 | input: [B, N, L] 245 | output: [B, N, K, S] 246 | ''' 247 | B, N, L = input.shape 248 | P = K // 2 249 | input, gap = self._padding(input, K) 250 | # [B, N, K, S] 251 | input1 = input[:, :, :-P].contiguous().view(B, N, -1, K) 252 | input2 = input[:, :, P:].contiguous().view(B, N, -1, K) 253 | input = torch.cat([input1, input2], dim=3).view( 254 | B, N, -1, K).transpose(2, 3) 255 | 256 | return input.contiguous(), gap 257 | 258 | 259 | def _over_add(self, input, gap): 260 | ''' 261 | Merge sequence 262 | input: [B, N, K, S] 263 | gap: padding length 264 | output: [B, N, L] 265 | ''' 266 | B, N, K, S = input.shape 267 | P = K // 2 268 | # [B, N, S, K] 269 | input = input.transpose(2, 3).contiguous().view(B, N, -1, K * 2) 270 | 271 | input1 = input[:, :, :, :K].contiguous().view(B, N, -1)[:, :, P:] 272 | input2 = input[:, :, :, K:].contiguous().view(B, N, -1)[:, :, :-P] 273 | input = input1 + input2 274 | # [B, N, L] 275 | if gap > 0: 276 | input = input[:, :, :-gap] 277 | 278 | return input 279 | 280 | class VisualConv1D(nn.Module): 281 | def __init__(self, V=256, H=512): 282 | super(VisualConv1D, self).__init__() 283 | relu_0 = nn.ReLU() 284 | norm_0 = GlobalLayerNorm(V) 285 | conv1x1 = nn.Conv1d(V, H, 1, bias=False) 286 | relu = nn.ReLU() 287 | norm_1 = GlobalLayerNorm(H) 288 | dsconv = nn.Conv1d(H, H, 3, stride=1, padding=1,dilation=1, groups=H, bias=False) 289 | prelu = nn.PReLU() 290 | norm_2 = GlobalLayerNorm(H) 291 | pw_conv = nn.Conv1d(H, V, 1, bias=False) 292 | self.net = nn.Sequential(relu_0, norm_0, conv1x1, relu, norm_1 ,dsconv, prelu, norm_2, pw_conv) 293 | 294 | def forward(self, x): 295 | out = self.net(x) 296 | return out + x 297 | 298 | class GlobalLayerNorm(nn.Module): 299 | """Global Layer Normalization (gLN)""" 300 | @amp.float_function 301 | def __init__(self, channel_size): 302 | super(GlobalLayerNorm, self).__init__() 303 | self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] 304 | self.beta = nn.Parameter(torch.Tensor(1, channel_size,1 )) # [1, N, 1] 305 | self.reset_parameters() 306 | 307 | @amp.float_function 308 | def reset_parameters(self): 309 | self.gamma.data.fill_(1) 310 | self.beta.data.zero_() 311 | 312 | @amp.float_function 313 | def forward(self, y): 314 | """ 315 | Args: 316 | y: [M, N, K], M is batch size, N is channel size, K is length 317 | Returns: 318 | gLN_y: [M, N, K] 319 | """ 320 | # TODO: in torch 1.0, torch.mean() support dim list 321 | mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) #[M, 1, 1] 322 | var = (torch.pow(y-mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) 323 | gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta 324 | return gLN_y 325 | 326 | @amp.float_function 327 | def overlap_and_add(signal, frame_step): 328 | """Reconstructs a signal from a framed representation. 329 | 330 | Adds potentially overlapping frames of a signal with shape 331 | `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`. 332 | The resulting tensor has shape `[..., output_size]` where 333 | 334 | output_size = (frames - 1) * frame_step + frame_length 335 | 336 | Args: 337 | signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2. 338 | frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length. 339 | 340 | Returns: 341 | A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions. 342 | output_size = (frames - 1) * frame_step + frame_length 343 | 344 | Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py 345 | """ 346 | outer_dimensions = signal.size()[:-2] 347 | frames, frame_length = signal.size()[-2:] 348 | 349 | subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor 350 | subframe_step = frame_step // subframe_length 351 | subframes_per_frame = frame_length // subframe_length 352 | output_size = frame_step * (frames - 1) + frame_length 353 | output_subframes = output_size // subframe_length 354 | 355 | subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) 356 | 357 | frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step) 358 | frame = signal.new_tensor(frame).long().cuda() # signal may in GPU or CPU 359 | frame = frame.contiguous().view(-1) 360 | 361 | result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) 362 | result.index_add_(-2, frame, subframe_signal) 363 | result = result.view(*outer_dimensions, -1) 364 | return result 365 | -------------------------------------------------------------------------------- /src/av-dprnn/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from utils import * 4 | import os 5 | from avDprnn import avDprnn 6 | from solver import Solver 7 | 8 | 9 | def main(args): 10 | if args.distributed: 11 | torch.manual_seed(0) 12 | torch.cuda.set_device(args.local_rank) 13 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 14 | 15 | # Model 16 | model = avDprnn(args.N, args.L, args.B, args.H, args.K, args.R, 17 | args.C) 18 | 19 | if (args.distributed and args.local_rank ==0) or args.distributed == False: 20 | print("started on " + args.log_name + '\n') 21 | print(args) 22 | print(model) 23 | print("\nTotal number of parameters: {} \n".format(sum(p.numel() for p in model.parameters()))) 24 | 25 | model = model.cuda() 26 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 27 | 28 | train_sampler, train_generator = get_dataloader(args,'train') 29 | _, val_generator = get_dataloader(args, 'val') 30 | _, test_generator = get_dataloader(args, 'test') 31 | args.train_sampler=train_sampler 32 | 33 | solver = Solver(args=args, 34 | model = model, 35 | optimizer = optimizer, 36 | train_data = train_generator, 37 | validation_data = val_generator, 38 | test_data = test_generator) 39 | solver.train() 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser("avConv-tasnet") 43 | 44 | # Dataloader 45 | parser.add_argument('--mix_lst_path', type=str, default='/home/panzexu/datasets/LRS2/audio/2_mix_min/mixture_data_list_2mix.csv', 46 | help='directory including train data') 47 | parser.add_argument('--audio_direc', type=str, default='/home/panzexu/datasets/LRS2/audio/Audio/', 48 | help='directory including validation data') 49 | parser.add_argument('--visual_direc', type=str, default='/home/panzexu/datasets/LRS2/lip/', 50 | help='directory including test data') 51 | parser.add_argument('--mixture_direc', type=str, default='/home/panzexu/datasets/LRS2/audio/2_mix_min/', 52 | help='directory of audio') 53 | 54 | # Training 55 | parser.add_argument('--batch_size', default=8, type=int, 56 | help='Batch size') 57 | parser.add_argument('--max_length', default=6, type=int, 58 | help='max_length of mixture in training') 59 | parser.add_argument('--num_workers', default=4, type=int, 60 | help='Number of workers to generate minibatch') 61 | parser.add_argument('--epochs', default=100, type=int, 62 | help='Number of maximum epochs') 63 | 64 | # Model hyperparameters 65 | parser.add_argument('--L', default=40, type=int, 66 | help='Length of the filters in samples (80=5ms at 16kHZ)') 67 | parser.add_argument('--N', default=256, type=int, 68 | help='Number of filters in autoencoder') 69 | parser.add_argument('--B', default=64, type=int, 70 | help='Number of output channels') 71 | parser.add_argument('--C', type=int, default=2, 72 | help='number of speakers to mix') 73 | parser.add_argument('--H', default=128, type=int, 74 | help='Number of hidden size in rnn') 75 | parser.add_argument('--K', default=100, type=int, 76 | help='Number of chunk size') 77 | parser.add_argument('--R', default=6, type=int, 78 | help='Number of layers') 79 | 80 | # optimizer 81 | parser.add_argument('--lr', default=0.001, type=float, 82 | help='Init learning rate') 83 | parser.add_argument('--max_norm', default=5, type=float, 84 | help='Gradient norm threshold to clip') 85 | 86 | 87 | # Log and Visulization 88 | parser.add_argument('--log_name', type=str, default=None, 89 | help='the name of the log') 90 | parser.add_argument('--use_tensorboard', type=int, default=0, 91 | help='Whether to use use_tensorboard') 92 | parser.add_argument('--continue_from', type=str, default='', 93 | help='Whether to use use_tensorboard') 94 | 95 | # Distributed training 96 | parser.add_argument('--opt-level', default='O0', type=str) 97 | parser.add_argument("--local_rank", default=0, type=int) 98 | parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) 99 | parser.add_argument('--patch_torch_functions', type=str, default=None) 100 | 101 | args = parser.parse_args() 102 | 103 | args.distributed = False 104 | args.world_size = 1 105 | if 'WORLD_SIZE' in os.environ: 106 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 107 | args.world_size = int(os.environ['WORLD_SIZE']) 108 | 109 | assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." 110 | 111 | main(args) -------------------------------------------------------------------------------- /src/av-dprnn/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | gpu_id=11,12 4 | 5 | continue_from= 6 | 7 | if [ -z ${continue_from} ]; then 8 | log_name='avDprnn_'$(date '+%d-%m-%Y(%H:%M:%S)') 9 | mkdir logs/$log_name 10 | else 11 | log_name=${continue_from} 12 | fi 13 | 14 | CUDA_VISIBLE_DEVICES="$gpu_id" \ 15 | python -W ignore \ 16 | -m torch.distributed.launch \ 17 | --nproc_per_node=2 \ 18 | --master_port=2191 \ 19 | main.py \ 20 | --log_name $log_name \ 21 | \ 22 | --audio_direc '/home/panzexu/datasets/LRS2/audio_clean/' \ 23 | --visual_direc '/home/panzexu/datasets/LRS2/visual_embedding/lip/' \ 24 | --mix_lst_path '/home/panzexu/datasets/LRS2/audio_mixture/2_mix_min_asr/mixture_data_list_2mix.csv' \ 25 | --mixture_direc '/home/panzexu/datasets/LRS2/audio_mixture/2_mix_min_asr/' \ 26 | --C 3 \ 27 | \ 28 | --batch_size 24 \ 29 | --num_workers 2 \ 30 | \ 31 | --epochs 100 \ 32 | --use_tensorboard 1 \ 33 | >logs/$log_name/console.txt 2>&1 34 | 35 | # --continue_from ${continue_from} \ 36 | 37 | -------------------------------------------------------------------------------- /src/av-dprnn/solver.py: -------------------------------------------------------------------------------- 1 | import time 2 | from utils import * 3 | from apex import amp 4 | from apex.parallel import DistributedDataParallel as DDP 5 | from torch.utils.tensorboard import SummaryWriter 6 | import torch.distributed as dist 7 | import torch.nn.functional as F 8 | from stft_loss import MultiResolutionSTFTLoss 9 | 10 | class Solver(object): 11 | def __init__(self, train_data, validation_data, test_data, model, optimizer, args): 12 | self.train_data = train_data 13 | self.validation_data = validation_data 14 | self.test_data = test_data 15 | self.args = args 16 | self.amp = amp 17 | self.stft_loss = MultiResolutionSTFTLoss() 18 | 19 | self.print = False 20 | if (self.args.distributed and self.args.local_rank ==0) or not self.args.distributed: 21 | self.print = True 22 | if self.args.use_tensorboard: 23 | self.writer = SummaryWriter('logs/%s/tensorboard/' % args.log_name) 24 | 25 | self.model, self.optimizer = self.amp.initialize(model, optimizer, 26 | opt_level=args.opt_level, 27 | patch_torch_functions=args.patch_torch_functions) 28 | 29 | if self.args.distributed: 30 | self.model = DDP(self.model) 31 | 32 | self._reset() 33 | 34 | def _reset(self): 35 | self.halving = False 36 | if self.args.continue_from: 37 | checkpoint = torch.load('logs/%s/model_dict.pt' % self.args.continue_from, map_location='cpu') 38 | 39 | self.model.load_state_dict(checkpoint['model']) 40 | self.optimizer.load_state_dict(checkpoint['optimizer']) 41 | self.amp.load_state_dict(checkpoint['amp']) 42 | 43 | self.start_epoch=checkpoint['epoch'] 44 | self.prev_val_loss = checkpoint['prev_val_loss'] 45 | self.best_val_loss = checkpoint['best_val_loss'] 46 | self.val_no_impv = checkpoint['val_no_impv'] 47 | 48 | if self.print: print("Resume training from epoch: {}".format(self.start_epoch)) 49 | 50 | else: 51 | self.prev_val_loss = float("inf") 52 | self.best_val_loss = float("inf") 53 | self.val_no_impv = 0 54 | self.start_epoch=1 55 | if self.print: print('Start new training') 56 | 57 | 58 | def train(self): 59 | for epoch in range(self.start_epoch, self.args.epochs+1): 60 | if self.args.distributed: self.args.train_sampler.set_epoch(epoch) 61 | # Train 62 | self.model.train() 63 | start = time.time() 64 | tr_loss, tr_snr_loss, tr_stft_loss = self._run_one_epoch(data_loader = self.train_data) 65 | reduced_tr_loss = self._reduce_tensor(tr_loss) 66 | reduced_tr_snr_loss = self._reduce_tensor(tr_snr_loss) 67 | reduced_tr_stft_loss = self._reduce_tensor(tr_stft_loss) 68 | 69 | if self.print: print('Train Summary | End of Epoch {0} | Time {1:.2f}s | ' 70 | 'Train Loss {2:.3f}'.format( 71 | epoch, time.time() - start, reduced_tr_loss)) 72 | 73 | # Validation 74 | self.model.eval() 75 | start = time.time() 76 | with torch.no_grad(): 77 | val_loss, val_snr_loss, val_stft_loss = self._run_one_epoch(data_loader = self.validation_data, state='val') 78 | reduced_val_loss = self._reduce_tensor(val_loss) 79 | reduced_val_snr_loss = self._reduce_tensor(val_snr_loss) 80 | reduced_val_stft_loss = self._reduce_tensor(val_stft_loss) 81 | 82 | if self.print: print('Valid Summary | End of Epoch {0} | Time {1:.2f}s | ' 83 | 'Valid Loss {2:.3f}'.format( 84 | epoch, time.time() - start, reduced_val_loss)) 85 | 86 | # test 87 | self.model.eval() 88 | start = time.time() 89 | with torch.no_grad(): 90 | test_loss, test_snr_loss, test_stft_loss = self._run_one_epoch(data_loader = self.test_data, state='test') 91 | reduced_test_loss = self._reduce_tensor(test_loss) 92 | reduced_test_snr_loss = self._reduce_tensor(test_snr_loss) 93 | reduced_test_stft_loss = self._reduce_tensor(test_stft_loss) 94 | 95 | if self.print: print('Test Summary | End of Epoch {0} | Time {1:.2f}s | ' 96 | 'Test Loss {2:.3f}'.format( 97 | epoch, time.time() - start, reduced_test_loss)) 98 | 99 | 100 | # Check whether to adjust learning rate and early stop 101 | if reduced_val_loss >= self.best_val_loss: 102 | self.val_no_impv += 1 103 | if self.val_no_impv >= 10: 104 | if self.print: print("No imporvement for 10 epochs, early stopping.") 105 | break 106 | else: 107 | self.val_no_impv = 0 108 | 109 | if self.val_no_impv == 6: #(epoch %2) == 0: 110 | self.halving = True 111 | 112 | # Halfing the learning rate 113 | if self.halving: 114 | optim_state = self.optimizer.state_dict() 115 | optim_state['param_groups'][0]['lr'] = \ 116 | optim_state['param_groups'][0]['lr'] * 0.5 117 | self.optimizer.load_state_dict(optim_state) 118 | if self.print: print('Learning rate adjusted to: {lr:.6f}'.format( 119 | lr=optim_state['param_groups'][0]['lr'])) 120 | self.halving = False 121 | self.prev_val_loss = reduced_val_loss 122 | 123 | if self.print: 124 | # Tensorboard logging 125 | if self.args.use_tensorboard: 126 | self.writer.add_scalar('Train_loss', reduced_tr_loss, epoch) 127 | self.writer.add_scalar('Validation_loss', reduced_val_loss, epoch) 128 | self.writer.add_scalar('Test_loss', reduced_test_loss, epoch) 129 | self.writer.add_scalar('Train_snr_loss', reduced_tr_snr_loss, epoch) 130 | self.writer.add_scalar('Validation_snr_loss', reduced_val_snr_loss, epoch) 131 | self.writer.add_scalar('Test_snr_loss', reduced_test_snr_loss, epoch) 132 | self.writer.add_scalar('Train_stft_loss', reduced_tr_stft_loss, epoch) 133 | self.writer.add_scalar('Validation_stft_loss', reduced_val_stft_loss, epoch) 134 | self.writer.add_scalar('Test_stft_loss', reduced_test_stft_loss, epoch) 135 | 136 | # Save model 137 | if reduced_val_loss < self.best_val_loss: 138 | self.best_val_loss = reduced_val_loss 139 | checkpoint = {'model': self.model.state_dict(), 140 | 'optimizer': self.optimizer.state_dict(), 141 | 'amp': self.amp.state_dict(), 142 | 'epoch': epoch+1, 143 | 'prev_val_loss': self.prev_val_loss, 144 | 'best_val_loss': self.best_val_loss, 145 | 'val_no_impv': self.val_no_impv} 146 | torch.save(checkpoint, "logs/"+ self.args.log_name+"/model_dict.pt") 147 | print("Fund new best model, dict saved") 148 | 149 | def _run_one_epoch(self, data_loader, state='train'): 150 | total_loss = 0 151 | total_snr_loss = 0 152 | total_stft_loss = 0 153 | for i, (a_mix, a_tgt, v_tgt) in enumerate(data_loader): 154 | a_mix = a_mix.cuda().squeeze(0).float() 155 | a_tgt = a_tgt.cuda().squeeze(0).float() 156 | v_tgt = v_tgt.cuda().squeeze(0).float() 157 | 158 | a_tgt_est = self.model(a_mix, v_tgt) 159 | 160 | pos_snr = cal_SISNR(a_tgt, a_tgt_est) 161 | snr_loss = 0 - torch.mean(pos_snr) 162 | 163 | # print(loss.item()) 164 | 165 | stft_loss_sc, stft_loss_mag = self.stft_loss(a_tgt_est, a_tgt) 166 | stft_loss = stft_loss_mag + stft_loss_sc 167 | 168 | gamma = 1 169 | loss = snr_loss + gamma* stft_loss 170 | 171 | if state=='train': 172 | # self._adjust_lr() 173 | self.optimizer.zero_grad() 174 | with self.amp.scale_loss(loss, self.optimizer) as scaled_loss: 175 | scaled_loss.backward() 176 | torch.nn.utils.clip_grad_norm_(self.amp.master_params(self.optimizer), 177 | self.args.max_norm) 178 | self.optimizer.step() 179 | if state=='test': 180 | loss = 0 - torch.mean(pos_snr[::self.args.C]) 181 | 182 | total_loss += loss.data 183 | total_snr_loss += snr_loss 184 | total_stft_loss += stft_loss 185 | 186 | return total_loss / (i+1), total_snr_loss / (i+1), total_stft_loss / (i+1) 187 | 188 | def _reduce_tensor(self, tensor): 189 | if not self.args.distributed: return tensor 190 | rt = tensor.clone() 191 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 192 | rt /= self.args.world_size 193 | return rt 194 | -------------------------------------------------------------------------------- /src/av-dprnn/stft_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright 2019 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """STFT-based Loss modules.""" 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import torchaudio 11 | 12 | from distutils.version import LooseVersion 13 | 14 | is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7") 15 | 16 | 17 | def stft(x, fft_size, hop_size, win_length, window): 18 | """Perform STFT and convert to magnitude spectrogram. 19 | 20 | Args: 21 | x (Tensor): Input signal tensor (B, T). 22 | fft_size (int): FFT size. 23 | hop_size (int): Hop size. 24 | win_length (int): Window length. 25 | window (str): Window function type. 26 | 27 | Returns: 28 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). 29 | 30 | """ 31 | window = window.cuda() 32 | if is_pytorch_17plus: 33 | x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=False) 34 | else: 35 | x_stft = torch.stft(x, fft_size, hop_size, win_length, window) 36 | real = x_stft[..., 0] 37 | imag = x_stft[..., 1] 38 | 39 | # NOTE(kan-bayashi): clamp is needed to avoid nan or inf 40 | return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) 41 | 42 | 43 | class SpectralConvergenceLoss(torch.nn.Module): 44 | """Spectral convergence loss module.""" 45 | 46 | def __init__(self): 47 | """Initilize spectral convergence loss module.""" 48 | super(SpectralConvergenceLoss, self).__init__() 49 | 50 | def forward(self, x_mag, y_mag): 51 | """Calculate forward propagation. 52 | 53 | Args: 54 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 55 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 56 | 57 | Returns: 58 | Tensor: Spectral convergence loss value. 59 | 60 | """ 61 | delta_win_length=5 62 | 63 | loss_1 = torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") 64 | 65 | x_mag=x_mag.transpose(-1,-2) 66 | y_mag=y_mag.transpose(-1,-2) 67 | 68 | x_del = torchaudio.functional.compute_deltas(x_mag,win_length=delta_win_length) 69 | y_del = torchaudio.functional.compute_deltas(y_mag,win_length=delta_win_length) 70 | loss_2 = torch.norm(y_del - x_del, p="fro") / torch.norm(y_del, p="fro") 71 | 72 | x_acc = torchaudio.functional.compute_deltas(x_del,win_length=delta_win_length) 73 | y_acc = torchaudio.functional.compute_deltas(y_del,win_length=delta_win_length) 74 | loss_3 = torch.norm(y_acc - x_acc, p="fro") / torch.norm(y_acc, p="fro") 75 | 76 | return loss_1 + loss_2 + loss_3 77 | 78 | 79 | 80 | 81 | class LogSTFTMagnitudeLoss(torch.nn.Module): 82 | """Log STFT magnitude loss module.""" 83 | 84 | def __init__(self): 85 | """Initilize los STFT magnitude loss module.""" 86 | super(LogSTFTMagnitudeLoss, self).__init__() 87 | 88 | def forward(self, x_mag, y_mag): 89 | """Calculate forward propagation. 90 | 91 | Args: 92 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 93 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 94 | 95 | Returns: 96 | Tensor: Log STFT magnitude loss value. 97 | 98 | """ 99 | delta_win_length=5 100 | 101 | loss_1 = F.l1_loss(torch.log(y_mag), torch.log(x_mag)) 102 | 103 | x_mag=torch.log(x_mag).transpose(-1,-2) 104 | y_mag=torch.log(y_mag).transpose(-1,-2) 105 | 106 | x_del = torchaudio.functional.compute_deltas(x_mag ,win_length=delta_win_length) 107 | y_del = torchaudio.functional.compute_deltas(y_mag ,win_length=delta_win_length) 108 | loss_2 = F.l1_loss(y_del, x_del) 109 | 110 | x_acc = torchaudio.functional.compute_deltas(x_del ,win_length=delta_win_length) 111 | y_acc = torchaudio.functional.compute_deltas(y_del ,win_length=delta_win_length) 112 | loss_3 = F.l1_loss(y_acc, x_acc) 113 | 114 | 115 | return loss_1 + loss_2 + loss_3 116 | 117 | 118 | 119 | class STFTLoss(torch.nn.Module): 120 | """STFT loss module.""" 121 | 122 | def __init__( 123 | self, fft_size=1024, shift_size=120, win_length=600, window="hann_window" 124 | ): 125 | """Initialize STFT loss module.""" 126 | super(STFTLoss, self).__init__() 127 | self.fft_size = fft_size 128 | self.shift_size = shift_size 129 | self.win_length = win_length 130 | self.spectral_convergence_loss = SpectralConvergenceLoss() 131 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 132 | # NOTE(kan-bayashi): Use register_buffer to fix #223 133 | self.register_buffer("window", getattr(torch, window)(win_length)) 134 | 135 | def forward(self, x, y): 136 | """Calculate forward propagation. 137 | 138 | Args: 139 | x (Tensor): Predicted signal (B, T). 140 | y (Tensor): Groundtruth signal (B, T). 141 | 142 | Returns: 143 | Tensor: Spectral convergence loss value. 144 | Tensor: Log STFT magnitude loss value. 145 | 146 | """ 147 | x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) 148 | y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) 149 | sc_loss = self.spectral_convergence_loss(x_mag, y_mag) 150 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 151 | 152 | return sc_loss, mag_loss 153 | 154 | 155 | class MultiResolutionSTFTLoss(torch.nn.Module): 156 | """Multi resolution STFT loss module.""" 157 | 158 | def __init__( 159 | self, 160 | fft_sizes=[1024, 2048, 512], 161 | hop_sizes=[120, 240, 50], 162 | win_lengths=[600, 1200, 240], 163 | window="hann_window", 164 | ): 165 | """Initialize Multi resolution STFT loss module. 166 | 167 | Args: 168 | fft_sizes (list): List of FFT sizes. 169 | hop_sizes (list): List of hop sizes. 170 | win_lengths (list): List of window lengths. 171 | window (str): Window function type. 172 | 173 | """ 174 | super(MultiResolutionSTFTLoss, self).__init__() 175 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 176 | self.stft_losses = torch.nn.ModuleList() 177 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 178 | self.stft_losses += [STFTLoss(fs, ss, wl, window)] 179 | 180 | def forward(self, x, y): 181 | """Calculate forward propagation. 182 | 183 | Args: 184 | x (Tensor): Predicted signal (B, T) or (B, #subband, T). 185 | y (Tensor): Groundtruth signal (B, T) or (B, #subband, T). 186 | 187 | Returns: 188 | Tensor: Multi resolution spectral convergence loss value. 189 | Tensor: Multi resolution log STFT magnitude loss value. 190 | 191 | """ 192 | if len(x.shape) == 3: 193 | x = x.view(-1, x.size(2)) # (B, C, T) -> (B x C, T) 194 | y = y.view(-1, y.size(2)) # (B, C, T) -> (B x C, T) 195 | sc_loss = 0.0 196 | mag_loss = 0.0 197 | for f in self.stft_losses: 198 | sc_l, mag_l = f(x, y) 199 | sc_loss += sc_l 200 | mag_loss += mag_l 201 | sc_loss /= len(self.stft_losses) 202 | mag_loss /= len(self.stft_losses) 203 | 204 | return sc_loss, mag_loss 205 | -------------------------------------------------------------------------------- /src/av-dprnn/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch.distributed as dist 4 | import torch 5 | import torch.nn as nn 6 | import torch.utils.data as data 7 | import scipy.io.wavfile as wavfile 8 | from itertools import permutations 9 | from apex import amp 10 | import tqdm 11 | import os 12 | 13 | EPS = 1e-6 14 | 15 | class dataset(data.Dataset): 16 | def __init__(self, 17 | mix_lst_path, 18 | audio_direc, 19 | visual_direc, 20 | mixture_direc, 21 | batch_size, 22 | partition='test', 23 | audio_only=False, 24 | sampling_rate=16000, 25 | max_length=4, 26 | mix_no=2): 27 | 28 | self.minibatch =[] 29 | self.audio_only = audio_only 30 | self.audio_direc = audio_direc 31 | self.visual_direc = visual_direc 32 | self.mixture_direc = mixture_direc 33 | self.sampling_rate = sampling_rate 34 | self.partition = partition 35 | self.max_length = max_length 36 | self.C=mix_no 37 | 38 | mix_lst=open(mix_lst_path).read().splitlines() 39 | mix_lst=list(filter(lambda x: x.split(',')[0]==partition, mix_lst)) 40 | 41 | # if partition=='train': 42 | # mix_lst = mix_lst[:20000] 43 | 44 | self.batch_size = batch_size 45 | 46 | sorted_mix_lst = sorted(mix_lst, key=lambda data: float(data.split(',')[-1]), reverse=True) 47 | 48 | start = 0 49 | while True: 50 | end = min(len(sorted_mix_lst), start + self.batch_size) 51 | self.minibatch.append(sorted_mix_lst[start:end]) 52 | if end == len(sorted_mix_lst): 53 | break 54 | start = end 55 | 56 | def __getitem__(self, index): 57 | batch_lst = self.minibatch[index] 58 | min_length = int(float(batch_lst[-1].split(',')[-1])*self.sampling_rate) 59 | 60 | mixtures=[] 61 | audios=[] 62 | visuals=[] 63 | for line in batch_lst: 64 | mixture_path=self.mixture_direc+self.partition+'/'+ line.replace(',','_').replace('/','_')+'.wav' 65 | _, mixture = wavfile.read(mixture_path) 66 | mixture = self._audio_norm(mixture[:min_length]) 67 | 68 | line=line.split(',') 69 | for c in range(1): 70 | audio_path=self.audio_direc+line[c*4+1]+'/'+line[c*4+2]+'/'+line[c*4+3]+'.wav' 71 | _, audio = wavfile.read(audio_path) 72 | audios.append(self._audio_norm(audio[:min_length])) 73 | 74 | # read visual 75 | visual_path=self.visual_direc+line[c*4+1]+'/'+line[c*4+2]+'/'+line[c*4+3]+'.npy' 76 | visual = np.load(visual_path) 77 | length = math.floor(min_length/self.sampling_rate*25) 78 | visual = visual[:length,...] 79 | a = visual.shape[0] 80 | if visual.shape[0] < length: 81 | visual = np.pad(visual, ((0,int(length - visual.shape[0])),(0,0)), mode = 'edge') 82 | visuals.append(visual) 83 | 84 | mixtures.append(mixture) 85 | 86 | return np.asarray(mixtures)[...,:self.max_length*self.sampling_rate], \ 87 | np.asarray(audios)[...,:self.max_length*self.sampling_rate], \ 88 | np.asarray(visuals)[...,:self.max_length*25,:] 89 | 90 | def __len__(self): 91 | return len(self.minibatch) 92 | 93 | def _audio_norm(self,audio): 94 | return np.divide(audio, np.max(np.abs(audio))) 95 | 96 | 97 | class DistributedSampler(data.Sampler): 98 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0): 99 | if num_replicas is None: 100 | if not dist.is_available(): 101 | raise RuntimeError("Requires distributed package to be available") 102 | num_replicas = dist.get_world_size() 103 | if rank is None: 104 | if not dist.is_available(): 105 | raise RuntimeError("Requires distributed package to be available") 106 | rank = dist.get_rank() 107 | self.dataset = dataset 108 | self.num_replicas = num_replicas 109 | self.rank = rank 110 | self.epoch = 0 111 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 112 | self.total_size = self.num_samples * self.num_replicas 113 | self.shuffle = shuffle 114 | self.seed = seed 115 | 116 | def __iter__(self): 117 | if self.shuffle: 118 | # deterministically shuffle based on epoch and seed 119 | g = torch.Generator() 120 | g.manual_seed(self.seed + self.epoch) 121 | # indices = torch.randperm(len(self.dataset), generator=g).tolist() 122 | ind = torch.randperm(int(len(self.dataset)/self.num_replicas), generator=g)*self.num_replicas 123 | indices = [] 124 | for i in range(self.num_replicas): 125 | indices = indices + (ind+i).tolist() 126 | else: 127 | indices = list(range(len(self.dataset))) 128 | 129 | # add extra samples to make it evenly divisible 130 | indices += indices[:(self.total_size - len(indices))] 131 | assert len(indices) == self.total_size 132 | 133 | # subsample 134 | # indices = indices[self.rank:self.total_size:self.num_replicas] 135 | indices = indices[self.rank*self.num_samples:(self.rank+1)*self.num_samples] 136 | assert len(indices) == self.num_samples 137 | 138 | return iter(indices) 139 | 140 | def __len__(self): 141 | return self.num_samples 142 | 143 | def set_epoch(self, epoch): 144 | self.epoch = epoch 145 | 146 | def get_dataloader(args, partition): 147 | datasets = dataset( 148 | mix_lst_path=args.mix_lst_path, 149 | audio_direc=args.audio_direc, 150 | visual_direc=args.visual_direc, 151 | mixture_direc=args.mixture_direc, 152 | batch_size=args.batch_size, 153 | max_length=args.max_length, 154 | partition=partition, 155 | mix_no=args.C) 156 | 157 | sampler = DistributedSampler( 158 | datasets, 159 | num_replicas=args.world_size, 160 | rank=args.local_rank) if args.distributed else None 161 | 162 | generator = data.DataLoader(datasets, 163 | batch_size = 1, 164 | shuffle = (sampler is None), 165 | num_workers = args.num_workers, 166 | sampler=sampler) 167 | 168 | return sampler, generator 169 | 170 | @amp.float_function 171 | def cal_SISNR(source, estimate_source): 172 | """Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR) 173 | Args: 174 | source: torch tensor, [batch size, sequence length] 175 | estimate_source: torch tensor, [batch size, sequence length] 176 | Returns: 177 | SISNR, [batch size] 178 | """ 179 | assert source.size() == estimate_source.size() 180 | 181 | # Step 1. Zero-mean norm 182 | source = source - torch.mean(source, axis = -1, keepdim=True) 183 | estimate_source = estimate_source - torch.mean(estimate_source, axis = -1, keepdim=True) 184 | 185 | # Step 2. SI-SNR 186 | # s_target = s / ||s||^2 187 | ref_energy = torch.sum(source ** 2, axis = -1, keepdim=True) + EPS 188 | proj = torch.sum(source * estimate_source, axis = -1, keepdim=True) * source / ref_energy 189 | # e_noise = s' - s_target 190 | noise = estimate_source - proj 191 | # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2) 192 | ratio = torch.sum(proj ** 2, axis = -1) / (torch.sum(noise ** 2, axis = -1) + EPS) 193 | sisnr = 10 * torch.log10(ratio + EPS) 194 | 195 | return sisnr 196 | 197 | 198 | if __name__ == '__main__': 199 | datasets = dataset( 200 | mix_lst_path='/home/panzexu/datasets/LRS2/audio/2_mix_min/mixture_data_list_2mix_5db.csv', 201 | audio_direc='/home/panzexu/datasets/LRS2/audio/Audio/', 202 | visual_direc='/home/panzexu/datasets/LRS2/lip/', 203 | mixture_direc='/home/panzexu/datasets/LRS2/audio/2_mix_min/', 204 | batch_size=8, 205 | partition='train') 206 | data_loader = data.DataLoader(datasets, 207 | batch_size = 1, 208 | shuffle= True, 209 | num_workers = 1) 210 | 211 | for a_mix, a_tgt, v_tgt in tqdm.tqdm(data_loader): 212 | # print(a_mix.squeeze().size()) 213 | # print(a_tgt.squeeze().size()) 214 | # print(v_tgt.squeeze().size()) 215 | pass 216 | 217 | # a = np.ones((24,512)) 218 | # print(a.shape) 219 | # a = np.pad(a, ((0,-1), (0,0)), 'edge') 220 | # print(a.shape) 221 | 222 | # a = np.random.rand(2,2,3) 223 | # print(a) 224 | # a = a.reshape(4,3) 225 | # print(a) --------------------------------------------------------------------------------