├── .gitattributes ├── AudioData.py ├── README.md ├── conv_stft.py ├── criteria.py ├── eval_composite.py ├── gen_pair.py ├── metric.py ├── model.py ├── test.py ├── train.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /AudioData.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from preprocess import SignalToFrames, ToTensor 4 | import numpy as np 5 | import random 6 | import h5py 7 | 8 | 9 | class TrainingDataset(Dataset): 10 | r"""Training dataset.""" 11 | 12 | def __init__(self, file_path, frame_size=512, frame_shift=256, nsamples=64000): 13 | 14 | with open(file_path, 'r') as train_file_list: 15 | self.file_list = [line.strip() for line in train_file_list.readlines()] 16 | 17 | self.nsamples = nsamples 18 | self.get_frames = SignalToFrames(frame_size=frame_size, 19 | frame_shift=frame_shift) 20 | self.to_tensor = ToTensor() 21 | 22 | def __len__(self): 23 | #print(len(self.file_list)) 24 | return len(self.file_list) 25 | 26 | def __getitem__(self, index): 27 | filename = self.file_list[index] 28 | reader = h5py.File(filename, 'r') 29 | feature = reader['noisy_raw'][:] 30 | label = reader['clean_raw'][:] 31 | reader.close() 32 | 33 | 34 | size = feature.shape[0] 35 | start = random.randint(0, max(0, size - self.nsamples)) 36 | feature = feature[start:start + self.nsamples] 37 | label = label[start:start + self.nsamples] 38 | 39 | #print(feature.shape) 40 | feature = np.reshape(feature, [1, -1]) # [1, sig_len] 41 | #print(feature.shape) 42 | label = np.reshape(label, [1, -1]) # [1, sig_len] 43 | 44 | # feature = self.get_frames(feature) # [1, num_frames, frame_size] 45 | feature = self.to_tensor(feature) # [1, sig_len] 46 | label = self.to_tensor(label) # [1, sig_len] 47 | 48 | 49 | sig_len = feature.shape[-1] 50 | feature_ = torch.zeros((1, self.nsamples)) 51 | label_ = torch.zeros((1, self.nsamples)) 52 | feature_[:,:sig_len] = feature 53 | label_[:,:sig_len] = label 54 | #print(feature.shape) 55 | 56 | #return feature, label 57 | return feature_, label_, sig_len 58 | 59 | 60 | class EvalDataset(Dataset): 61 | r"""Evaluation dataset.""" 62 | 63 | def __init__(self, file_path, frame_size=512, frame_shift=256): 64 | 65 | #self.filename = filename 66 | with open(file_path, 'r') as validation_file_list: 67 | self.file_list = [line.strip() for line in validation_file_list.readlines()] 68 | 69 | self.get_frames = SignalToFrames(frame_size=frame_size, 70 | frame_shift=frame_shift) 71 | self.to_tensor = ToTensor() 72 | 73 | def __len__(self): 74 | return len(self.file_list) 75 | 76 | def __getitem__(self, index): 77 | filename = self.file_list[index] 78 | reader = h5py.File(filename, 'r') 79 | 80 | feature = reader['noisy_raw'][:] 81 | label = reader['clean_raw'][:] 82 | reader.close() 83 | feature = np.reshape(feature, [1, -1]) # [1, sig_len] 84 | 85 | # feature = self.get_frames(feature) # [1, 1, num_frames, frame_size] 86 | # print(feature.shape) 87 | feature = self.to_tensor(feature) # [1, sig_len] 88 | label = self.to_tensor(label) # [sig_len, ] 89 | 90 | return feature, label -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 
DPT-FSNet 2 | 3 | This project provides the source code for the DPT-FSNet. 4 | -------------------------------------------------------------------------------- /conv_stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from scipy.signal import get_window 6 | 7 | 8 | def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False): 9 | if win_type == 'None' or win_type is None: 10 | window = np.ones(win_len) 11 | else: 12 | window = get_window(win_type, win_len, fftbins=True)#**0.5 13 | 14 | N = fft_len 15 | fourier_basis = np.fft.rfft(np.eye(N))[:win_len] 16 | real_kernel = np.real(fourier_basis) 17 | imag_kernel = np.imag(fourier_basis) 18 | kernel = np.concatenate([real_kernel, imag_kernel], 1).T 19 | 20 | if invers : 21 | kernel = np.linalg.pinv(kernel).T 22 | 23 | kernel = kernel*window 24 | kernel = kernel[:, None, :] 25 | return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None,:,None].astype(np.float32)) 26 | 27 | 28 | class ConvSTFT(nn.Module): 29 | 30 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True): 31 | super(ConvSTFT, self).__init__() 32 | 33 | if fft_len == None: 34 | self.fft_len = np.int(2**np.ceil(np.log2(win_len))) 35 | else: 36 | self.fft_len = fft_len 37 | 38 | kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type) 39 | #self.weight = nn.Parameter(kernel, requires_grad=(not fix)) 40 | self.register_buffer('weight', kernel) 41 | self.feature_type = feature_type 42 | self.stride = win_inc 43 | self.win_len = win_len 44 | self.dim = self.fft_len 45 | 46 | def forward(self, inputs): 47 | if inputs.dim() == 2: 48 | inputs = torch.unsqueeze(inputs, 1) 49 | inputs = F.pad(inputs,[self.win_len-self.stride, self.win_len-self.stride]) 50 | outputs = F.conv1d(inputs, self.weight, stride=self.stride) 51 | 52 | if self.feature_type == 'complex': 53 | return outputs 54 | else: 55 | dim = self.dim//2+1 56 | real = outputs[:, :dim, :] 57 | imag = outputs[:, dim:, :] 58 | mags = torch.sqrt(real**2+imag**2) 59 | phase = torch.atan2(imag, real) 60 | return mags, phase 61 | 62 | class ConviSTFT(nn.Module): 63 | 64 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True): 65 | super(ConviSTFT, self).__init__() 66 | if fft_len == None: 67 | self.fft_len = np.int(2**np.ceil(np.log2(win_len))) 68 | else: 69 | self.fft_len = fft_len 70 | kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True) 71 | #self.weight = nn.Parameter(kernel, requires_grad=(not fix)) 72 | self.register_buffer('weight', kernel) 73 | self.feature_type = feature_type 74 | self.win_type = win_type 75 | self.win_len = win_len 76 | self.stride = win_inc 77 | self.stride = win_inc 78 | self.dim = self.fft_len 79 | self.register_buffer('window', window) 80 | self.register_buffer('enframe', torch.eye(win_len)[:,None,:]) 81 | 82 | def forward(self, inputs, phase=None): 83 | """ 84 | inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags) 85 | phase: [B, N//2+1, T] (if not none) 86 | """ 87 | 88 | if phase is not None: 89 | real = inputs*torch.cos(phase) 90 | imag = inputs*torch.sin(phase) 91 | inputs = torch.cat([real, imag], 1) 92 | outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride) 93 | 94 | # this is from torch-stft: https://github.com/pseeth/torch-stft 95 | t = 
self.window.repeat(1,1,inputs.size(-1))**2 96 | coff = F.conv_transpose1d(t, self.enframe, stride=self.stride) 97 | outputs = outputs/(coff+1e-8) 98 | #outputs = torch.where(coff == 0, outputs, outputs/coff) 99 | outputs = outputs[...,self.win_len-self.stride:-(self.win_len-self.stride)] 100 | 101 | return outputs 102 | 103 | def test_fft(): 104 | torch.manual_seed(20) 105 | win_len = 320 106 | win_inc = 160 107 | fft_len = 512 108 | inputs = torch.randn([1, 1, 16000*4]) 109 | fft = ConvSTFT(win_len, win_inc, fft_len, win_type='hanning', feature_type='real') 110 | import librosa 111 | 112 | outputs1 = fft(inputs)[0] 113 | outputs1 = outputs1.numpy()[0] 114 | np_inputs = inputs.numpy().reshape([-1]) 115 | librosa_stft = librosa.stft(np_inputs, win_length=win_len, n_fft=fft_len, hop_length=win_inc, center=False) 116 | print(np.mean((outputs1 - np.abs(librosa_stft))**2)) 117 | 118 | 119 | 120 | def test_ifft1(): 121 | import soundfile as sf 122 | N = 400 123 | inc = 100 124 | fft_len=512 125 | torch.manual_seed(N) 126 | data = np.random.randn(16000*8)[None,None,:] 127 | # data = sf.read('../ori.wav')[0] 128 | inputs = data.reshape([1,1,-1]) 129 | fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex') 130 | ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex') 131 | inputs = torch.from_numpy(inputs.astype(np.float32)) 132 | outputs1 = fft(inputs) 133 | print(outputs1.shape) 134 | outputs2 = ifft(outputs1) 135 | sf.write('conv_stft.wav', outputs2.numpy()[0,0,:],16000) 136 | print('wav MSE', torch.mean(torch.abs(inputs[...,:outputs2.size(2)]-outputs2)**2)) 137 | 138 | 139 | def test_ifft2(): 140 | N = 400 141 | inc = 100 142 | fft_len=512 143 | np.random.seed(20) 144 | torch.manual_seed(20) 145 | t = np.random.randn(16000*4)*0.001 146 | t = np.clip(t, -1, 1) 147 | #input = torch.randn([1,16000*4]) 148 | input = torch.from_numpy(t[None,None,:].astype(np.float32)) 149 | 150 | fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex') 151 | ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex') 152 | 153 | out1 = fft(input) 154 | output = ifft(out1) 155 | print('random MSE', torch.mean(torch.abs(input-output)**2)) 156 | import soundfile as sf 157 | sf.write('zero.wav', output[0,0].numpy(),16000) 158 | 159 | 160 | if __name__ == '__main__': 161 | #test_fft() 162 | test_ifft1() 163 | #test_ifft2() -------------------------------------------------------------------------------- /criteria.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from conv_stft import ConvSTFT 3 | 4 | 5 | class stftm_loss(object): 6 | def __init__(self, frame_size=512, frame_shift=256, loss_type='mae'): 7 | self.stft = ConvSTFT(frame_size, frame_shift, frame_size, 'hanning', 'complex', fix=True).cuda() 8 | self.fft_len = 512 9 | self.loss_type = loss_type 10 | 11 | def __call__(self, outputs, labels): 12 | out_real, out_imag = 
self.get_stftm(outputs) 13 | lab_real, lab_imag = self.get_stftm(labels) 14 | 15 | if self.loss_type == 'mae': 16 | loss = torch.mean(torch.abs(out_real-lab_real)+torch.abs(out_imag-lab_imag)) 17 | elif self.loss_type == 'char': 18 | loss = self.char_loss(out_real, lab_real) + self.char_loss(out_imag, lab_imag) 19 | elif self.loss_type == 'hybrid': 20 | loss = (self.edge_loss(out_real, lab_real) + self.edge_loss(out_imag, lab_imag)) * 0.05 +\ 21 | self.char_loss(out_real, lab_real) + self.char_loss(out_imag, lab_imag) 22 | 23 | 24 | return loss 25 | 26 | 27 | def get_stftm(self, ipt): 28 | specs = self.stft(ipt) 29 | 30 | real = specs[:,:self.fft_len//2+1] 31 | imag = specs[:,self.fft_len//2+1:] 32 | 33 | return real, imag 34 | 35 | 36 | 37 | 38 | class mag_loss(object): 39 | def __init__(self, frame_size=512, frame_shift=256, loss_type='mae'): 40 | self.stft = ConvSTFT(frame_size, frame_shift, frame_size, 'hanning', 'complex', fix=True).cuda() 41 | self.fft_len = 512 42 | self.loss_type = loss_type 43 | def __call__(self, outputs, labels): 44 | out_mags = self.get_mag(outputs) 45 | lab_mags = self.get_mag(labels) 46 | 47 | if self.loss_type == 'mae': 48 | loss = torch.mean(torch.abs(out_mags-lab_mags)) 49 | 50 | return loss 51 | 52 | def get_mag(self, ipt): 53 | 54 | specs = self.stft(ipt) 55 | real = specs[:,:self.fft_len//2+1] 56 | imag = specs[:,self.fft_len//2+1:] 57 | spec_mags = torch.sqrt(real**2+imag**2+1e-8) 58 | 59 | return spec_mags 60 | 61 | 62 | -------------------------------------------------------------------------------- /eval_composite.py: -------------------------------------------------------------------------------- 1 | from scipy.linalg import toeplitz 2 | from pystoi import stoi 3 | from tqdm import tqdm 4 | from pesq import pesq 5 | import librosa 6 | import numpy as np 7 | import os 8 | import sys 9 | 10 | 11 | def eval_composite(ref_wav, deg_wav, sr=16000): 12 | ref_wav = ref_wav.reshape(-1) 13 | deg_wav = deg_wav.reshape(-1) 14 | 15 | alpha = 0.95 16 | len_ = min(ref_wav.shape[0], deg_wav.shape[0]) 17 | ref_wav = ref_wav[:len_] 18 | ref_len = ref_wav.shape[0] 19 | deg_wav = deg_wav[:len_] 20 | 21 | # Compute WSS measure 22 | wss_dist_vec = wss(ref_wav, deg_wav, 16000) 23 | wss_dist_vec = sorted(wss_dist_vec, reverse=False) 24 | wss_dist = np.mean(wss_dist_vec[:int(round(len(wss_dist_vec) * alpha))]) 25 | 26 | # Compute LLR measure 27 | LLR_dist = llr(ref_wav, deg_wav, 16000) 28 | LLR_dist = sorted(LLR_dist, reverse=False) 29 | LLRs = LLR_dist 30 | LLR_len = round(len(LLR_dist) * alpha) 31 | llr_mean = np.mean(LLRs[:LLR_len]) 32 | 33 | # Compute the SSNR 34 | snr_mean, segsnr_mean = SSNR(ref_wav, deg_wav, 16000) 35 | segSNR = np.mean(segsnr_mean) 36 | 37 | # Compute the PESQ 38 | #pesq_raw = PESQ(ref_wav, deg_wav) 39 | pesq_raw = get_pesq(ref_wav, deg_wav, sr) 40 | stoi_raw = get_stoi(ref_wav, deg_wav, sr) 41 | 42 | Csig = 3.093 - 1.029 * llr_mean + 0.603 * pesq_raw - 0.009 * wss_dist 43 | Csig = trim_mos(Csig) 44 | Cbak = 1.634 + 0.478 * pesq_raw - 0.007 * wss_dist + 0.063 * segSNR 45 | Cbak = trim_mos(Cbak) 46 | Covl = 1.594 + 0.805 * pesq_raw - 0.512 * llr_mean - 0.007 * wss_dist 47 | Covl = trim_mos(Covl) 48 | 49 | return {'csig': Csig, 'cbak': Cbak, 'covl': Covl, 'stoi': stoi_raw, 'pesq': pesq_raw, 'ssnr': segSNR} 50 | 51 | 52 | # ----------------------------- HELPERS ------------------------------------ # 53 | def trim_mos(val): 54 | return min(max(val, 1), 5) 55 | 56 | 57 | def lpcoeff(speech_frame, model_order): 58 | # (1) Compute Autocor lags 59 | winlength = speech_frame.shape[0] 
60 | R = [] 61 | for k in range(model_order + 1): 62 | first = speech_frame[:(winlength - k)] 63 | second = speech_frame[k:winlength] 64 | R.append(np.sum(first * second)) 65 | 66 | # (2) Lev-Durbin 67 | a = np.ones((model_order,)) 68 | E = np.zeros((model_order + 1,)) 69 | rcoeff = np.zeros((model_order,)) 70 | E[0] = R[0] 71 | for i in range(model_order): 72 | if i == 0: 73 | sum_term = 0 74 | else: 75 | a_past = a[:i] 76 | sum_term = np.sum(a_past * np.array(R[i:0:-1])) 77 | rcoeff[i] = (R[i + 1] - sum_term) / E[i] 78 | a[i] = rcoeff[i] 79 | if i > 0: 80 | a[:i] = a_past[:i] - rcoeff[i] * a_past[::-1] 81 | E[i + 1] = (1 - rcoeff[i] * rcoeff[i]) * E[i] 82 | acorr = np.array(R, dtype=np.float32) 83 | refcoeff = np.array(rcoeff, dtype=np.float32) 84 | a = a * -1 85 | lpparams = np.array([1] + list(a), dtype=np.float32) 86 | acorr = np.array(acorr, dtype=np.float32) 87 | refcoeff = np.array(refcoeff, dtype=np.float32) 88 | lpparams = np.array(lpparams, dtype=np.float32) 89 | 90 | return acorr, refcoeff, lpparams 91 | 92 | 93 | # -------------------------------------------------------------------------- # 94 | 95 | # ---------------------- Speech Quality Metric ----------------------------- # 96 | def PESQ(ref_wav, deg_wav): 97 | rate = 16000 98 | return pesq(rate, ref_wav, deg_wav, 'wb') 99 | 100 | 101 | def get_pesq(ref, deg, sr): 102 | 103 | score = pesq(sr, ref, deg, 'wb') 104 | 105 | return score 106 | 107 | def get_stoi(ref, deg, sr): 108 | 109 | score = stoi(ref, deg, sr, extended=False) 110 | 111 | return score 112 | 113 | def SSNR(ref_wav, deg_wav, srate=16000, eps=1e-10): 114 | """ Segmental Signal-to-Noise Ratio Objective Speech Quality Measure 115 | This function implements the segmental signal-to-noise ratio 116 | as defined in [1, p. 45] (see Equation 2.12). 117 | """ 118 | clean_speech = ref_wav 119 | processed_speech = deg_wav 120 | clean_length = ref_wav.shape[0] 121 | processed_length = deg_wav.shape[0] 122 | 123 | # scale both to have same dynamic range. Remove DC too. 124 | clean_speech -= clean_speech.mean() 125 | processed_speech -= processed_speech.mean() 126 | processed_speech *= (np.max(np.abs(clean_speech)) / np.max(np.abs(processed_speech))) 127 | 128 | # Signal-to-Noise Ratio 129 | dif = ref_wav - deg_wav 130 | overall_snr = 10 * np.log10(np.sum(ref_wav ** 2) / (np.sum(dif ** 2) + 131 | 10e-20)) 132 | # global variables 133 | winlength = int(np.round(30 * srate / 1000)) # 30 msecs 134 | skiprate = winlength // 4 135 | MIN_SNR = -10 136 | MAX_SNR = 35 137 | 138 | # For each frame, calculate SSNR 139 | num_frames = int(clean_length / skiprate - (winlength / skiprate)) 140 | start = 0 141 | time = np.linspace(1, winlength, winlength) / (winlength + 1) 142 | window = 0.5 * (1 - np.cos(2 * np.pi * time)) 143 | segmental_snr = [] 144 | 145 | for frame_count in range(int(num_frames)): 146 | # (1) get the frames for the test and ref speech. 
147 | # Apply Hanning Window 148 | clean_frame = clean_speech[start:start + winlength] 149 | processed_frame = processed_speech[start:start + winlength] 150 | clean_frame = clean_frame * window 151 | processed_frame = processed_frame * window 152 | 153 | # (2) Compute Segmental SNR 154 | signal_energy = np.sum(clean_frame ** 2) 155 | noise_energy = np.sum((clean_frame - processed_frame) ** 2) 156 | segmental_snr.append(10 * np.log10(signal_energy / (noise_energy + eps) + eps)) 157 | segmental_snr[-1] = max(segmental_snr[-1], MIN_SNR) 158 | segmental_snr[-1] = min(segmental_snr[-1], MAX_SNR) 159 | start += int(skiprate) 160 | return overall_snr, segmental_snr 161 | 162 | 163 | def wss(ref_wav, deg_wav, srate): 164 | clean_speech = ref_wav 165 | processed_speech = deg_wav 166 | clean_length = ref_wav.shape[0] 167 | processed_length = deg_wav.shape[0] 168 | 169 | assert clean_length == processed_length, clean_length 170 | 171 | winlength = round(30 * srate / 1000.) # 240 wlen in samples 172 | skiprate = np.floor(winlength / 4) 173 | max_freq = srate / 2 174 | num_crit = 25 # num of critical bands 175 | 176 | USE_FFT_SPECTRUM = 1 177 | n_fft = int(2 ** np.ceil(np.log(2 * winlength) / np.log(2))) 178 | n_fftby2 = int(n_fft / 2) 179 | Kmax = 20 180 | Klocmax = 1 181 | 182 | # Critical band filter definitions (Center frequency and BW in Hz) 183 | cent_freq = [50., 120, 190, 260, 330, 400, 470, 540, 617.372, 184 | 703.378, 798.717, 904.128, 1020.38, 1148.30, 185 | 1288.72, 1442.54, 1610.70, 1794.16, 1993.93, 186 | 2211.08, 2446.71, 2701.97, 2978.04, 3276.17, 187 | 3597.63] 188 | bandwidth = [70., 70, 70, 70, 70, 70, 70, 77.3724, 86.0056, 189 | 95.3398, 105.411, 116.256, 127.914, 140.423, 190 | 153.823, 168.154, 183.457, 199.776, 217.153, 191 | 235.631, 255.255, 276.072, 298.126, 321.465, 192 | 346.136] 193 | 194 | bw_min = bandwidth[0] # min critical bandwidth 195 | 196 | # set up critical band filters. Note here that Gaussianly shaped filters 197 | # are used. Also, the sum of the filter weights are equivalent for each 198 | # critical band filter. Filter less than -30 dB and set to zero. 199 | min_factor = np.exp(-30. / (2 * 2.303)) # -30 dB point of filter 200 | 201 | crit_filter = np.zeros((num_crit, n_fftby2)) 202 | all_f0 = [] 203 | for i in range(num_crit): 204 | f0 = (cent_freq[i] / max_freq) * (n_fftby2) 205 | all_f0.append(np.floor(f0)) 206 | bw = (bandwidth[i] / max_freq) * (n_fftby2) 207 | norm_factor = np.log(bw_min) - np.log(bandwidth[i]) 208 | j = list(range(n_fftby2)) 209 | crit_filter[i, :] = np.exp(-11 * (((j - np.floor(f0)) / bw) ** 2) + \ 210 | norm_factor) 211 | crit_filter[i, :] = crit_filter[i, :] * (crit_filter[i, :] > \ 212 | min_factor) 213 | 214 | # For each frame of input speech, compute Weighted Spectral Slope Measure 215 | num_frames = int(clean_length / skiprate - (winlength / skiprate)) 216 | start = 0 # starting sample 217 | time = np.linspace(1, winlength, winlength) / (winlength + 1) 218 | window = 0.5 * (1 - np.cos(2 * np.pi * time)) 219 | distortion = [] 220 | 221 | for frame_count in range(num_frames): 222 | # (1) Get the Frames for the test and reference speeech. 223 | # Multiply by Hanning window. 
224 | clean_frame = clean_speech[start:start + winlength] 225 | processed_frame = processed_speech[start:start + winlength] 226 | clean_frame = clean_frame * window 227 | processed_frame = processed_frame * window 228 | 229 | # (2) Compuet Power Spectrum of clean and processed 230 | clean_spec = (np.abs(np.fft.fft(clean_frame, n_fft)) ** 2) 231 | processed_spec = (np.abs(np.fft.fft(processed_frame, n_fft)) ** 2) 232 | clean_energy = [None] * num_crit 233 | processed_energy = [None] * num_crit 234 | 235 | # (3) Compute Filterbank output energies (in dB) 236 | for i in range(num_crit): 237 | clean_energy[i] = np.sum(clean_spec[:n_fftby2] * \ 238 | crit_filter[i, :]) 239 | processed_energy[i] = np.sum(processed_spec[:n_fftby2] * \ 240 | crit_filter[i, :]) 241 | clean_energy = np.array(clean_energy).reshape(-1, 1) 242 | eps = np.ones((clean_energy.shape[0], 1)) * 1e-10 243 | clean_energy = np.concatenate((clean_energy, eps), axis=1) 244 | clean_energy = 10 * np.log10(np.max(clean_energy, axis=1)) 245 | processed_energy = np.array(processed_energy).reshape(-1, 1) 246 | processed_energy = np.concatenate((processed_energy, eps), axis=1) 247 | processed_energy = 10 * np.log10(np.max(processed_energy, axis=1)) 248 | 249 | # (4) Compute Spectral Shape (dB[i+1] - dB[i]) 250 | clean_slope = clean_energy[1:num_crit] - clean_energy[:num_crit - 1] 251 | processed_slope = processed_energy[1:num_crit] - \ 252 | processed_energy[:num_crit - 1] 253 | 254 | # (5) Find the nearest peak locations in the spectra to each 255 | # critical band. If the slope is negative, we search 256 | # to the left. If positive, we search to the right. 257 | clean_loc_peak = [] 258 | processed_loc_peak = [] 259 | for i in range(num_crit - 1): 260 | if clean_slope[i] > 0: 261 | # search to the right 262 | n = i 263 | while n < num_crit - 1 and clean_slope[n] > 0: 264 | n += 1 265 | clean_loc_peak.append(clean_energy[n - 1]) 266 | else: 267 | # search to the left 268 | n = i 269 | while n >= 0 and clean_slope[n] <= 0: 270 | n -= 1 271 | clean_loc_peak.append(clean_energy[n + 1]) 272 | # find the peaks in the processed speech signal 273 | if processed_slope[i] > 0: 274 | n = i 275 | while n < num_crit - 1 and processed_slope[n] > 0: 276 | n += 1 277 | processed_loc_peak.append(processed_energy[n - 1]) 278 | else: 279 | n = i 280 | while n >= 0 and processed_slope[n] <= 0: 281 | n -= 1 282 | processed_loc_peak.append(processed_energy[n + 1]) 283 | 284 | # (6) Compuet the WSS Measure for this frame. This includes 285 | # determination of the weighting functino 286 | dBMax_clean = max(clean_energy) 287 | dBMax_processed = max(processed_energy) 288 | 289 | # The weights are calculated by averaging individual 290 | # weighting factors from the clean and processed frame. 291 | # These weights W_clean and W_processed should range 292 | # from 0 to 1 and place more emphasis on spectral 293 | # peaks and less emphasis on slope differences in spectral 294 | # valleys. This procedure is described on page 1280 of 295 | # Klatt's 1982 ICASSP paper. 
296 | clean_loc_peak = np.array(clean_loc_peak) 297 | processed_loc_peak = np.array(processed_loc_peak) 298 | Wmax_clean = Kmax / (Kmax + dBMax_clean - clean_energy[:num_crit - 1]) 299 | Wlocmax_clean = Klocmax / (Klocmax + clean_loc_peak - clean_energy[:num_crit - 1]) 300 | W_clean = Wmax_clean * Wlocmax_clean 301 | Wmax_processed = Kmax / (Kmax + dBMax_processed - processed_energy[:num_crit - 1]) 302 | Wlocmax_processed = Klocmax / (Klocmax + processed_loc_peak - processed_energy[:num_crit - 1]) 303 | W_processed = Wmax_processed * Wlocmax_processed 304 | W = (W_clean + W_processed) / 2 305 | distortion.append(np.sum(W * (clean_slope[:num_crit - 1] - processed_slope[:num_crit - 1]) ** 2)) 306 | 307 | # this normalization is not part of Klatt's paper, but helps 308 | # to normalize the meaasure. Here we scale the measure by the sum of the 309 | # weights 310 | distortion[frame_count] = distortion[frame_count] / np.sum(W) 311 | start += int(skiprate) 312 | return distortion 313 | 314 | 315 | def llr(ref_wav, deg_wav, srate): 316 | clean_speech = ref_wav 317 | processed_speech = deg_wav 318 | clean_length = ref_wav.shape[0] 319 | processed_length = deg_wav.shape[0] 320 | assert clean_length == processed_length, clean_length 321 | 322 | winlength = round(30 * srate / 1000.) # 240 wlen in samples 323 | skiprate = np.floor(winlength / 4) 324 | if srate < 10000: 325 | # LPC analysis order 326 | P = 10 327 | else: 328 | P = 16 329 | 330 | # For each frame of input speech, calculate the Log Likelihood Ratio 331 | num_frames = int(clean_length / skiprate - (winlength / skiprate)) 332 | start = 0 333 | time = np.linspace(1, winlength, winlength) / (winlength + 1) 334 | window = 0.5 * (1 - np.cos(2 * np.pi * time)) 335 | distortion = [] 336 | 337 | for frame_count in range(num_frames): 338 | # (1) Get the Frames for the test and reference speeech. 339 | # Multiply by Hanning window. 
340 | clean_frame = clean_speech[start:start + winlength] 341 | processed_frame = processed_speech[start:start + winlength] 342 | clean_frame = clean_frame * window 343 | processed_frame = processed_frame * window 344 | 345 | # (2) Get the autocorrelation logs and LPC params used 346 | # to compute the LLR measure 347 | R_clean, Ref_clean, A_clean = lpcoeff(clean_frame, P) 348 | R_processed, Ref_processed, A_processed = lpcoeff(processed_frame, P) 349 | A_clean = A_clean[None, :] 350 | A_processed = A_processed[None, :] 351 | 352 | # (3) Compute the LLR measure 353 | numerator = A_processed.dot(toeplitz(R_clean)).dot(A_processed.T) 354 | denominator = A_clean.dot(toeplitz(R_clean)).dot(A_clean.T) 355 | 356 | if (numerator / denominator) <= 0: 357 | print(f'Numerator: {numerator}') 358 | print(f'Denominator: {denominator}') 359 | 360 | log_ = np.log(numerator / denominator) 361 | distortion.append(np.squeeze(log_)) 362 | start += int(skiprate) 363 | return np.nan_to_num(np.array(distortion)) -------------------------------------------------------------------------------- /gen_pair.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import librosa 3 | import glob 4 | import os 5 | import h5py 6 | import time 7 | 8 | 9 | def gen_pair(): 10 | 11 | train_clean_path = '/isilon/backup_netapp/dangfeng/VCTK_DEMAND/clean_trainset_28spk_wav_16k' 12 | train_noisy_path = '/isilon/backup_netapp/dangfeng/VCTK_DEMAND/noisy_trainset_28spk_wav_16k' 13 | train_mix_path = './dataset/voice_bank_mix/trainset' 14 | 15 | train_clean_name = sorted(os.listdir(train_clean_path)) 16 | train_noisy_name = sorted(os.listdir(train_noisy_path)) 17 | 18 | # print(train_clean_name) 19 | #print(train_noisy_name) 20 | 21 | for count in range(len(train_clean_name)): 22 | 23 | clean_name = train_clean_name[count] 24 | noisy_name = train_noisy_name[count] 25 | #print(clean_name, noisy_name) 26 | if clean_name == noisy_name: 27 | file_name = '%s_%d' % ('train_mix', count+1) 28 | train_writer = h5py.File(train_mix_path + '/' + file_name, 'w') 29 | 30 | clean_audio, sr = librosa.load(os.path.join(train_clean_path, clean_name), sr=16000) 31 | noisy_audio, sr1 = librosa.load(os.path.join(train_noisy_path, noisy_name), sr=16000) 32 | 33 | train_writer.create_dataset('noisy_raw', data=noisy_audio.astype(np.float32), chunks=True) 34 | train_writer.create_dataset('clean_raw', data=clean_audio.astype(np.float32), chunks=True) 35 | train_writer.close() 36 | else: 37 | raise TypeError('clean file and noisy file do not match') 38 | 39 | # save .txt file 40 | print('sleep for 3 secs...') 41 | time.sleep(3) 42 | print('begin save .txt file...') 43 | train_file_list = sorted(glob.glob(os.path.join(train_mix_path, '*'))) 44 | read_train = open("train_file_list", "w+") 45 | 46 | for i in range(len(train_file_list)): 47 | read_train.write("%s\n" % (train_file_list[i])) 48 | 49 | read_train.close() 50 | print('making training data finished!') 51 | 52 | 53 | if __name__ == "__main__": 54 | gen_pair() 55 | -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | from pesq import pesq 2 | from pystoi import stoi 3 | #from STOI import stoi 4 | 5 | def get_pesq(ref, deg, sr): 6 | 7 | score = pesq(sr, ref, deg, 'wb') 8 | 9 | return score 10 | 11 | def get_stoi(ref, deg, sr): 12 | 13 | score = stoi(ref, deg, sr, extended=False) 14 | 15 | return score 
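The two helpers in metric.py are thin wrappers around the `pesq` and `pystoi` packages; both expect the reference (clean) signal first and the degraded/enhanced signal second, sampled at 16 kHz for wide-band PESQ. A minimal usage sketch follows; the file names and the use of `soundfile` for loading are illustrative assumptions, not part of the original repository:

# Hedged usage sketch for metric.py (not from the original repo).
# 'clean_example.wav' and 'enhanced_example.wav' are placeholder paths.
import soundfile as sf
from metric import get_pesq, get_stoi

clean, sr = sf.read('clean_example.wav')        # reference signal, 16 kHz mono
enhanced, _ = sf.read('enhanced_example.wav')   # enhanced/degraded signal

# Both helpers take (reference, degraded, sample_rate); PESQ is run in
# wide-band ('wb') mode, so sr should be 16000.
print('PESQ: {:.3f}'.format(get_pesq(clean, enhanced, sr)))
print('STOI: {:.3f}'.format(get_stoi(clean, enhanced, sr)))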
-------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | from torch.nn.modules.module import Module 6 | from torch.nn.modules.activation import MultiheadAttention 7 | from torch.nn.modules.container import ModuleList 8 | from torch.nn.init import xavier_uniform_ 9 | from torch.nn.modules.dropout import Dropout 10 | from torch.nn.modules.linear import Linear 11 | from torch.nn.modules.rnn import LSTM, GRU 12 | from torch.nn.modules.normalization import LayerNorm 13 | from utils import show_params, show_model 14 | from conv_stft import ConvSTFT, ConviSTFT 15 | 16 | class TransformerEncoderLayer(Module): 17 | def __init__(self, d_model, nhead, bidirectional=True, dropout=0, activation="relu"): 18 | super(TransformerEncoderLayer, self).__init__() 19 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 20 | self.gru = GRU(d_model, d_model*2, 1, bidirectional=bidirectional) 21 | self.dropout = Dropout(dropout) 22 | if bidirectional: 23 | self.linear2 = Linear(d_model*2*2, d_model) 24 | else: 25 | self.linear2 = Linear(d_model*2, d_model) 26 | 27 | self.norm1 = LayerNorm(d_model) 28 | self.norm2 = LayerNorm(d_model) 29 | self.dropout1 = Dropout(dropout) 30 | self.dropout2 = Dropout(dropout) 31 | 32 | self.activation = _get_activation_fn(activation) 33 | 34 | def __setstate__(self, state): 35 | if 'activation' not in state: 36 | state['activation'] = F.relu 37 | super(TransformerEncoderLayer, self).__setstate__(state) 38 | 39 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 40 | # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor 41 | r"""Pass the input through the encoder layer. 42 | Args: 43 | src: the sequnce to the encoder layer (required). 44 | src_mask: the mask for the src sequence (optional). 45 | src_key_padding_mask: the mask for the src keys per batch (optional). 46 | Shape: 47 | see the docs in Transformer class. 
48 | """ 49 | src2 = self.self_attn(src, src, src, attn_mask=src_mask, 50 | key_padding_mask=src_key_padding_mask)[0] 51 | src = src + self.dropout1(src2) 52 | src = self.norm1(src) 53 | self.gru.flatten_parameters() 54 | out, h_n = self.gru(src) 55 | del h_n 56 | src2 = self.linear2(self.dropout(self.activation(out))) 57 | src = src + self.dropout2(src2) 58 | src = self.norm2(src) 59 | return src 60 | 61 | 62 | def _get_clones(module, N): 63 | return ModuleList([copy.deepcopy(module) for i in range(N)]) 64 | 65 | 66 | def _get_activation_fn(activation): 67 | if activation == "relu": 68 | return F.relu 69 | elif activation == "gelu": 70 | return F.gelu 71 | 72 | raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) 73 | 74 | class Dual_Transformer(nn.Module): 75 | def __init__(self, input_size, output_size, dropout=0, num_layers=1): 76 | super(Dual_Transformer, self).__init__() 77 | 78 | self.input_size = input_size 79 | self.output_size = output_size 80 | 81 | self.input = nn.Sequential( 82 | nn.Conv2d(input_size, input_size // 2, kernel_size=1), 83 | nn.PReLU() 84 | ) 85 | 86 | # dual-path RNN 87 | self.row_trans = nn.ModuleList([]) 88 | self.col_trans = nn.ModuleList([]) 89 | self.row_norm = nn.ModuleList([]) 90 | self.col_norm = nn.ModuleList([]) 91 | for i in range(num_layers): 92 | self.row_trans.append(TransformerEncoderLayer(d_model=input_size//2, nhead=4, dropout=dropout, bidirectional=True)) 93 | self.col_trans.append(TransformerEncoderLayer(d_model=input_size//2, nhead=4, dropout=dropout, bidirectional=True)) 94 | self.row_norm.append(nn.GroupNorm(1, input_size//2, eps=1e-8)) 95 | self.col_norm.append(nn.GroupNorm(1, input_size//2, eps=1e-8)) 96 | 97 | # output layer 98 | self.output = nn.Sequential(nn.PReLU(), 99 | nn.Conv2d(input_size//2, output_size, 1) 100 | ) 101 | 102 | def forward(self, input): 103 | # input --- [b, c, num_frames, frame_size] --- [b, c, t, f] 104 | b, c, t, f = input.shape 105 | output = self.input(input) 106 | for i in range(len(self.row_trans)): 107 | row_input = output.permute(2, 0, 3, 1).contiguous().view(t, b*f, -1) # [t, b*f, c] 108 | row_output = self.row_trans[i](row_input) 109 | row_output = row_output.view(t, b, f, -1).permute(1, 3, 0, 2).contiguous() # [b, c, t, f] 110 | row_output = self.row_norm[i](row_output) 111 | output = output + row_output 112 | 113 | col_input = output.permute(3, 0, 2, 1).contiguous().view(f, b*t, -1) # [f, b*t, c] 114 | col_output = self.col_trans[i](col_input) 115 | col_output = col_output.view(f, b, t, -1).permute(1, 3, 2, 0).contiguous() # [b, c, t, f] 116 | col_output = self.col_norm[i](col_output) 117 | output = output + col_output 118 | 119 | del row_input, row_output, col_input, col_output 120 | output = self.output(output) # [b, c, t, f] 121 | 122 | return output 123 | 124 | 125 | 126 | class SPConvTranspose2d(nn.Module): 127 | def __init__(self, in_channels, out_channels, kernel_size, r=1): 128 | # upconvolution only along second dimension of image 129 | # Upsampling using sub pixel layers 130 | super(SPConvTranspose2d, self).__init__() 131 | self.out_channels = out_channels 132 | self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1)) 133 | self.r = r 134 | 135 | def forward(self, x): 136 | out = self.conv(x) 137 | batch_size, nchannels, H, W = out.shape 138 | out = out.view((batch_size, self.r, nchannels // self.r, H, W)) 139 | out = out.permute(0, 2, 3, 4, 1) 140 | out = out.contiguous().view((batch_size, nchannels // self.r, H, -1)) 141 | 
return out 142 | 143 | 144 | class DenseBlock(nn.Module): 145 | def __init__(self, input_size, depth=5, in_channels=64): 146 | super(DenseBlock, self).__init__() 147 | self.depth = depth 148 | self.in_channels = in_channels 149 | self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.) 150 | self.twidth = 2 151 | self.kernel_size = (self.twidth, 3) 152 | for i in range(self.depth): 153 | dil = 2 ** i 154 | pad_length = self.twidth + (dil - 1) * (self.twidth - 1) - 1 155 | setattr(self, 'pad{}'.format(i + 1), nn.ConstantPad2d((1, 1, pad_length, 0), value=0.)) 156 | setattr(self, 'conv{}'.format(i + 1), 157 | nn.Conv2d(self.in_channels * (i + 1), self.in_channels, kernel_size=self.kernel_size, 158 | dilation=(dil, 1))) 159 | setattr(self, 'norm{}'.format(i + 1), nn.LayerNorm(input_size)) 160 | setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels)) 161 | 162 | def forward(self, x): 163 | skip = x 164 | for i in range(self.depth): 165 | out = getattr(self, 'pad{}'.format(i + 1))(skip) 166 | out = getattr(self, 'conv{}'.format(i + 1))(out) 167 | out = getattr(self, 'norm{}'.format(i + 1))(out) 168 | out = getattr(self, 'prelu{}'.format(i + 1))(out) 169 | skip = torch.cat([out, skip], dim=1) 170 | return out 171 | 172 | 173 | 174 | class Net(nn.Module): 175 | def __init__(self, L=512, width=64): 176 | super(Net, self).__init__() 177 | self.L = L 178 | self.frame_shift = self.L // 2 179 | self.N = 256 180 | self.B = 256 181 | self.H = 512 182 | self.P = 3 183 | # self.device = device 184 | self.in_channels = 2 185 | self.out_channels = 2 186 | self.kernel_size = (2, 3) 187 | # self.elu = nn.SELU(inplace=True) 188 | self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.) 189 | self.pad1 = nn.ConstantPad2d((1, 1, 0, 0), value=0.) 190 | self.width = width 191 | 192 | self.inp_conv = nn.Conv2d(in_channels=self.in_channels, out_channels=self.width, kernel_size=(1, 1)) # [b, 64, nframes, 256] 193 | self.inp_norm = nn.LayerNorm(256) 194 | self.inp_prelu = nn.PReLU(self.width) 195 | 196 | self.enc_dense1 = DenseBlock(256, 4, self.width) 197 | 198 | self.dual_transformer = Dual_Transformer(64, 64, num_layers=4) # # [b, 64, nframes, 8] 199 | 200 | # gated output layer 201 | self.output1 = nn.Sequential( 202 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1), 203 | nn.Tanh() 204 | ) 205 | 206 | self.output2 = nn.Sequential( 207 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1), 208 | nn.Sigmoid() 209 | ) 210 | 211 | self.dec_dense1 = DenseBlock(256, 4, self.width) 212 | 213 | self.out_conv = nn.Conv2d(in_channels=self.width, out_channels=self.out_channels, kernel_size=(1, 1)) 214 | self.stft = ConvSTFT(self.L, self.frame_shift, self.L, 'hanning', 'real', fix=True) 215 | self.istft = ConviSTFT(self.L, self.frame_shift, self.L, 'hanning', 'complex', fix=True) 216 | 217 | show_model(self) 218 | show_params(self) 219 | 220 | def forward(self, x, masking_mode='C'): 221 | 222 | 223 | x = self.stft(x) 224 | real = x[0] 225 | imag = x[1] 226 | x = torch.stack([real,imag],1) 227 | x = x.permute(0,1,3,2) # [B, 2, num_frames, num_bins] 228 | x = x[...,1:] 229 | #print(x.shape) 230 | 231 | out = self.inp_prelu(self.inp_norm(self.inp_conv(x))) # [b, 64, num_frames, frame_size] 232 | x1 = self.enc_dense1(out) # [b, 64, num_frames, frame_size] 233 | out = self.dual_transformer(x1) # [b, 64, num_frames, 256] 234 | out = self.output1(out) * self.output2(out) # mask [b, 64, num_frames, 256] 235 | out = self.dec_dense1(out) 236 | out = self.out_conv(out) 237 | 238 | real = x[:,0] 239 | imag = x[:,1] 240 
| 241 | mask_real = out[:,0] 242 | mask_imag = out[:,1] 243 | 244 | 245 | if masking_mode == 'E' : 246 | mask_mags = (mask_real**2+mask_imag**2)**0.5 247 | real_phase = mask_real/(mask_mags+1e-8) 248 | imag_phase = mask_imag/(mask_mags+1e-8) 249 | mask_phase = torch.atan2( imag_phase, real_phase ) 250 | #mask_mags = torch.clamp_(mask_mags,0,100) 251 | mask_mags = torch.tanh(mask_mags) 252 | est_mags = mask_mags*spec_mags 253 | est_phase = spec_phase + mask_phase 254 | real = est_mags*torch.cos(est_phase) 255 | imag = est_mags*torch.sin(est_phase) 256 | elif masking_mode == 'C': 257 | real,imag = real*mask_real-imag*mask_imag, real*mask_imag+imag*mask_real 258 | elif masking_mode == 'R': 259 | real, imag = real*mask_real, imag*mask_imag 260 | 261 | #print(out.shape) 262 | 263 | real = torch.cat((torch.zeros((real.size()[0], real.size()[1], 1)).to(device='cuda'), real), -1) 264 | imag = torch.cat((torch.zeros((imag.size()[0], imag.size()[1], 1)).to(device='cuda'), imag), -1) 265 | out = torch.cat([real,imag],-1).permute(0,2,1) 266 | 267 | out = self.istft(out) 268 | 269 | return out 270 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from metric import get_stoi, get_pesq 4 | from scipy.io import wavfile 5 | import numpy as np 6 | from checkpoints import Checkpoint 7 | from torch.utils.data import DataLoader 8 | from helper_funcs import snr, numParams 9 | from eval_composite import eval_composite 10 | from AudioData import EvalDataset, EvalCollate 11 | from new_model import Net 12 | import h5py 13 | import os 14 | 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1' 16 | 17 | sr = 16000 18 | 19 | test_file_list_path = './test_file_list' 20 | audio_file_save = './audio_file/best_ept' 21 | if not os.path.isdir(audio_file_save): 22 | os.makedirs(audio_file_save) 23 | 24 | with open(test_file_list_path, 'r') as test_file_list: 25 | file_list = [line.strip() for line in test_file_list.readlines()] 26 | #audio_name = os.path.basename(file_list[0]) 27 | 28 | print(file_list) 29 | 30 | 31 | test_data = EvalDataset(test_file_list_path, frame_size=512, frame_shift=256) 32 | test_loader = DataLoader(test_data, 33 | batch_size=1, 34 | shuffle=False, 35 | num_workers=4, 36 | collate_fn=EvalCollate()) 37 | 38 | ckpt_path = '' 39 | 40 | model = Net() 41 | model = nn.DataParallel(model, device_ids=[0, 1]) 42 | checkpoint = Checkpoint() 43 | checkpoint.load(ckpt_path) 44 | model.load_state_dict(checkpoint.state_dict) 45 | model.cuda() 46 | print(checkpoint.start_epoch) 47 | print(checkpoint.best_val_loss) 48 | print(numParams(model)) 49 | 50 | 51 | # test function 52 | def evaluate(net, eval_loader): 53 | net.eval() 54 | 55 | print('********Starting metrics evaluation on test dataset**********') 56 | total_stoi = 0.0 57 | total_ssnr = 0.0 58 | total_pesq = 0.0 59 | total_csig = 0.0 60 | total_cbak = 0.0 61 | total_covl = 0.0 62 | 63 | with torch.no_grad(): 64 | count, total_eval_loss = 0, 0.0 65 | for k, (features, labels) in enumerate(eval_loader): 66 | features = features.cuda() # [1, 1, num_frames,frame_size] 67 | labels = labels.cuda() # [signal_len, ] 68 | 69 | output = net(features) # [1, 1, sig_len_recover] 70 | output = output.squeeze() # [sig_len_recover, ] 71 | 72 | # keep length same (output label) 73 | labels = labels[...,:output.shape[-1]].squeeze() 74 | 75 | eval_loss = torch.mean((output - labels) ** 2) 76 | total_eval_loss += 
eval_loss.data.item() 77 | 78 | est_sp = output.cpu().numpy() 79 | cln_raw = labels.cpu().numpy() 80 | 81 | eval_metric = eval_composite(cln_raw, est_sp, sr) 82 | 83 | total_pesq += eval_metric['pesq'] 84 | total_ssnr += eval_metric['ssnr'] 85 | total_stoi += eval_metric['stoi'] 86 | total_cbak += eval_metric['cbak'] 87 | total_csig += eval_metric['csig'] 88 | total_covl += eval_metric['covl'] 89 | 90 | wavfile.write(os.path.join(audio_file_save, os.path.basename(file_list[k])), sr, est_sp.astype(np.float32)) 91 | 92 | count += 1 93 | avg_eval_loss = total_eval_loss / count 94 | 95 | return avg_eval_loss, total_stoi / count, total_pesq / count, total_ssnr / count, total_csig / count, total_cbak / count, total_covl / count 96 | 97 | 98 | 99 | 100 | avg_eval, avg_stoi, avg_pesq, avg_ssnr, avg_csig, avg_cbak, avg_covl = evaluate(model, test_loader) 101 | 102 | #print('Avg_loss: {:.4f}'.format(avg_eval)) 103 | print('STOI: {:.4f}'.format(avg_stoi)) 104 | print('SSNR: {:.4f}'.format(avg_ssnr)) 105 | print('PESQ: {:.4f}'.format(avg_pesq)) 106 | print('CSIG: {:.4f}'.format(avg_csig)) 107 | print('CBAK: {:.4f}'.format(avg_cbak)) 108 | print('COVL: {:.4f}'.format(avg_covl)) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from AudioData import TrainingDataset, EvalDataset 3 | from torch.utils.data import DataLoader 4 | from model import Net 5 | from metric import get_pesq, get_stoi 6 | from criteria import stftm_loss, mag_loss 7 | from utils import Checkpoint 8 | from tqdm import tqdm 9 | import os 10 | import warnings 11 | 12 | # hyperparameter 13 | frame_size = 512 14 | overlap = 0.5 15 | frame_shift = int(512 * (1 - overlap)) 16 | max_epochs = 200 17 | batch_size = 8 18 | lr_init = 64 ** (-0.5) 19 | eval_steps = 50000 20 | weight_delay = 1e-7 21 | batches_per_epoch = 40000 22 | 23 | # lr scheduling 24 | step_num = 0 25 | warm_ups = 4000 26 | 27 | sr = 16000 28 | 29 | resume_model = None#'./checkpoints/temp/latest.model-3.model' 30 | model_save_path = './checkpoints/temp/' 31 | 32 | if not os.path.isdir(model_save_path): 33 | os.makedirs(model_save_path) 34 | 35 | early_stop = False 36 | 37 | # file path 38 | train_file_list_path = './train_file_list' 39 | validation_file_list_path = './validation_file_list' 40 | 41 | # data and data_loader 42 | train_data = TrainingDataset(train_file_list_path, frame_size=512, frame_shift=256) 43 | train_loader = DataLoader(train_data, 44 | batch_size=batch_size, 45 | shuffle=True, 46 | num_workers=4,) 47 | 48 | validation_data = EvalDataset(validation_file_list_path, frame_size=512, frame_shift=256) 49 | validation_loader = DataLoader(validation_data, 50 | batch_size=1, 51 | shuffle=False, 52 | num_workers=4,) 53 | 54 | # define model 55 | model = Net() 56 | model = torch.nn.DataParallel(model) 57 | model = model.cuda() 58 | print('Number of learnable parameters: %d' % numParams(model)) 59 | 60 | optimizer = torch.optim.Adam(model.parameters(), lr=lr_init, weight_decay=weight_delay) 61 | 62 | mag_loss = mag_loss() 63 | freq_loss = stftm_loss() 64 | 65 | def validate(net, eval_loader, test_metric=False): 66 | net.eval() 67 | if test_metric: 68 | print('********Starting metrics evaluation on val dataset**********') 69 | total_stoi = 0.0 70 | total_snr = 0.0 71 | total_pesq = 0.0 72 | 73 | with torch.no_grad(): 74 | count, total_eval_loss = 0, 0.0 75 | for k, (features, labels) in enumerate(eval_loader): 76 | features = 
features.cuda() # [1, 1, num_frames, frame_size] 77 | labels = labels.cuda() # [signal_len, ] 78 | 79 | output = net(features) # [1, 1, sig_len_recover] 80 | output = output.squeeze() # [sig_len_recover,] 81 | 82 | #print(output.shape) 83 | labels = labels[:,:output.shape[-1]].squeeze() # keep length same (output label) 84 | #print(labels.shape) 85 | #eval_loss = torch.mean((output - labels) ** 2) * 1000 86 | 87 | eval_loss = net.module.loss(output, labels, loss_mode='SI-SNR')#torch.mean((output - labels) ** 2) * 1000 88 | total_eval_loss += eval_loss.item() 89 | ##print(k) 90 | 91 | est_sp = output.cpu().numpy() 92 | cln_raw = labels.cpu().numpy() 93 | if test_metric: 94 | st = get_stoi(cln_raw, est_sp, sr) 95 | pe = get_pesq(cln_raw, est_sp, sr) 96 | sn = eval_loss 97 | total_pesq += pe 98 | 99 | total_snr += sn 100 | total_stoi += st 101 | count += 1 102 | # print(count) 103 | avg_eval_loss = total_eval_loss / count 104 | net.train() 105 | if test_metric: 106 | return avg_eval_loss, total_stoi / count, total_pesq / count, total_snr / count 107 | else: 108 | return avg_eval_loss 109 | 110 | # train model 111 | if resume_model: 112 | print('Resume model from "%s"' % resume_model) 113 | checkpoint = Checkpoint() 114 | checkpoint.load(resume_model) 115 | 116 | warm_ups = 0 117 | start_epoch = checkpoint.start_epoch + 1 118 | best_val_loss = checkpoint.best_val_loss 119 | prev_val_loss = checkpoint.prev_val_loss 120 | num_no_improv = checkpoint.num_no_improv 121 | half_lr = checkpoint.half_lr 122 | model.load_state_dict(checkpoint.state_dict) 123 | optimizer.load_state_dict(checkpoint.optimizer) 124 | 125 | else: 126 | print('Training from scratch.') 127 | start_epoch = 0 128 | best_val_loss = float("inf") 129 | prev_val_loss = float("inf") 130 | num_no_improv = 0 131 | half_lr = False 132 | 133 | for epoch in range(start_epoch, max_epochs): 134 | model.train() 135 | total_train_loss, count, ave_train_loss = 0.0, 0, 0.0 136 | 137 | pbar = tqdm(enumerate(train_loader)) 138 | for index, (features, labels,sig_len ) in enumerate(train_loader): 139 | step_num += 1 140 | 141 | if step_num <= warm_ups: 142 | 143 | lr = 0.2 * lr_init * min(step_num ** (-0.5), 144 | step_num * (warm_ups ** (-1.5))) 145 | else: 146 | lr = 0.0004 * (0.98 ** ((epoch - 1) // 2)) 147 | 148 | for param_group in optimizer.param_groups: 149 | param_group['lr'] = lr 150 | 151 | 152 | # feature -- [batch_size, 2, nframes, frame_size] 153 | features = features.cuda() 154 | # label -- [batch_size, 1, signal_length] 155 | labels = labels.cuda() 156 | 157 | 158 | optimizer.zero_grad() 159 | 160 | output = model(features) # output -- [batch_size, 1, sig_len_recover] 161 | 162 | 163 | output = output[:, :, :labels.shape[-1]] # [batch_size, 1, sig_len] 164 | labels = labels[...,:output.shape[-1]] 165 | 166 | 167 | #loss_time = model.module.loss(output, labels, loss_mode='SI-SNR') 168 | loss_mag = mag_loss(output, labels) 169 | loss_freq = freq_loss(output, labels, loss_mask) 170 | loss = loss_mag +loss_freq #+ 0.05 * loss_time 171 | 172 | 173 | ref_loss = model.module.loss(output, labels, loss_mode='SI-SNR') 174 | loss.backward() 175 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5) 176 | 177 | 178 | optimizer.step() 179 | 180 | train_loss = loss.item() 181 | total_train_loss += train_loss 182 | 183 | count += 1 184 | 185 | #del loss, output, features, labels 186 | 187 | del loss, loss_mag, loss_freq, output, features, labels 188 | 189 | pbar.set_description('iter = {}/{}, epoch = {}/{}, loss = {:.5f}, ref_loss = 
{:.5f}, lr = {:.6f}'.format(index + 1, len(train_loader), epoch + 1, max_epochs, train_loss, ref_loss,lr)) 190 | 191 | if (index + 1) % eval_steps == 0: 192 | ave_train_loss = total_train_loss / count 193 | 194 | # validation 195 | avg_eval_loss = validate(model, validation_loader) 196 | model.train() 197 | 198 | print('Epoch [%d/%d], Iter [%d/%d], ( TrainLoss: %.4f | EvalLoss: %.4f )' % ( 199 | epoch + 1, max_epochs, index + 1, len(train_loader), ave_train_loss, avg_eval_loss)) 200 | 201 | count = 0 202 | total_train_loss = 0.0 203 | 204 | 205 | if (index + 1) % batches_per_epoch == 0: 206 | break 207 | 208 | # validate metric 209 | avg_eval, avg_stoi, avg_pesq, avg_snr = validate(model, validation_loader, test_metric=True) 210 | model.train() 211 | print('#' * 50) 212 | print('') 213 | print('After epoch {} the performance on the validation set is as follows:'.format(epoch + 1)) 214 | print('') 215 | print('Avg_loss: {:.4f}'.format(avg_eval)) 216 | print('STOI: {:.4f}'.format(avg_stoi)) 217 | print('SNR: {:.4f}'.format(avg_snr)) 218 | print('PESQ: {:.4f}'.format(avg_pesq)) 219 | 220 | 221 | store_to_file = 'After epoch {} the performance on the validation set is as follows:'.format(epoch + 1) +\ 222 | ' Avg_loss: {:.4f}'.format(avg_eval) +\ 223 | ' STOI: {:.4f}'.format(avg_stoi) +\ 224 | ' SNR: {:.4f}'.format(avg_snr) +\ 225 | ' PESQ: {:.4f}'.format(avg_pesq) 226 | 227 | doc = store_to_file + '\n' 228 | file_name = './log/train_log' 229 | with open(file_name, "a") as myfile: 230 | myfile.write(doc) 231 | 232 | 233 | # adjust learning rate and early stop 234 | if avg_eval >= prev_val_loss: 235 | num_no_improv += 1 236 | if num_no_improv == 2: 237 | half_lr = True 238 | if num_no_improv >= 10 and early_stop is True: 239 | print("No improvement and apply early stop") 240 | break 241 | else: 242 | num_no_improv = 0 243 | 244 | prev_val_loss = avg_eval 245 | 246 | if avg_eval < best_val_loss: 247 | best_val_loss = avg_eval 248 | is_best_model = True 249 | else: 250 | is_best_model = False 251 | 252 | # save model 253 | latest_model = 'latest.model' 254 | best_model = 'best.model' 255 | 256 | checkpoint = Checkpoint(start_epoch=epoch, 257 | best_val_loss=best_val_loss, 258 | prev_val_loss=prev_val_loss, 259 | state_dict=model.state_dict(), 260 | optimizer=optimizer.state_dict(), 261 | num_no_improv=num_no_improv, 262 | half_lr=half_lr) 263 | checkpoint.save(is_best=is_best_model, 264 | filename=os.path.join(model_save_path, latest_model + '-{}.model'.format(epoch + 1)), 265 | best_model=os.path.join(model_save_path, best_model)) 266 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import os 7 | import shutil 8 | 9 | 10 | 11 | class Checkpoint(object): 12 | def __init__(self, start_epoch=None, start_iter=None, train_loss=None, eval_loss=None, best_val_loss=float("inf"), 13 | prev_val_loss=float("inf"), state_dict=None, optimizer=None, num_no_improv=0, half_lr=False): 14 | self.start_epoch = start_epoch 15 | self.start_iter = start_iter 16 | self.train_loss = train_loss 17 | self.eval_loss = eval_loss 18 | 19 | self.best_val_loss = best_val_loss 20 | self.prev_val_loss = prev_val_loss 21 | 22 | self.state_dict = state_dict 23 | self.optimizer = optimizer 24 | 25 | self.num_no_improv = num_no_improv 26 | self.half_lr = half_lr 27 | 28 
| 29 | def save(self, is_best, filename, best_model): 30 | print('Saving checkpoint at "%s"' % filename) 31 | torch.save(self, filename) 32 | if is_best: 33 | print('Saving the best model at "%s"' % best_model) 34 | shutil.copyfile(filename, best_model) 35 | print('\n') 36 | 37 | 38 | def load(self, filename): 39 | # filename : model path 40 | if os.path.isfile(filename): 41 | print('Loading checkpoint from "%s"\n' % filename) 42 | checkpoint = torch.load(filename, map_location='cpu') 43 | 44 | self.start_epoch = checkpoint.start_epoch 45 | self.start_iter = checkpoint.start_iter 46 | self.train_loss = checkpoint.train_loss 47 | self.eval_loss = checkpoint.eval_loss 48 | 49 | self.best_val_loss = checkpoint.best_val_loss 50 | self.prev_val_loss = checkpoint.prev_val_loss 51 | self.state_dict = checkpoint.state_dict 52 | self.optimizer = checkpoint.optimizer 53 | self.num_no_improv = checkpoint.num_no_improv 54 | self.half_lr = checkpoint.half_lr 55 | else: 56 | raise ValueError('No checkpoint found at "%s"' % filename) 57 | 58 | def show_params(nnet): 59 | print("=" * 40, "Model Parameters", "=" * 40) 60 | num_params = 0 61 | for module_name, m in nnet.named_modules(): 62 | if module_name == '': 63 | for name, params in m.named_parameters(): 64 | print(name, params.size()) 65 | i = 1 66 | for j in params.size(): 67 | i = i * j 68 | num_params += i 69 | print('[*] Parameter Size: {}'.format(num_params)) 70 | print("=" * 98) 71 | 72 | 73 | def show_model(nnet): 74 | print("=" * 40, "Model Structures", "=" * 40) 75 | for module_name, m in nnet.named_modules(): 76 | if module_name == '': 77 | print(m) 78 | print("=" * 98) 79 | ~ --------------------------------------------------------------------------------
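Note: AudioData.py imports `SignalToFrames` and `ToTensor` from a `preprocess` module that is not included in this listing. The sketch below is a best-guess stand-in, assuming `ToTensor` only converts a NumPy array to a float32 torch tensor (the one call AudioData.py actually relies on) and `SignalToFrames` slices the last axis of a signal into overlapping frames using the `frame_size`/`frame_shift` arguments seen in the dataset constructors; the upstream implementation may differ.

# preprocess.py -- hypothetical sketch, not part of the original repository.
import numpy as np
import torch


class ToTensor(object):
    """Convert a NumPy array to a float32 torch tensor."""
    def __call__(self, x):
        return torch.from_numpy(np.asarray(x, dtype=np.float32))


class SignalToFrames(object):
    """Slice the last axis of a signal into overlapping frames.

    Output shape: [..., num_frames, frame_size]; assumes the signal is at
    least frame_size samples long.
    """
    def __init__(self, frame_size=512, frame_shift=256):
        self.frame_size = frame_size
        self.frame_shift = frame_shift

    def __call__(self, signal):
        sig_len = signal.shape[-1]
        num_frames = 1 + (sig_len - self.frame_size) // self.frame_shift
        frames = np.zeros(signal.shape[:-1] + (num_frames, self.frame_size),
                          dtype=signal.dtype)
        for i in range(num_frames):
            start = i * self.frame_shift
            frames[..., i, :] = signal[..., start:start + self.frame_size]
        return frames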