├── README.md ├── cal_metrics ├── cal_metrics.py ├── compute_metrics.py └── metric.py ├── data └── make_dataset_json.py ├── dataloaders └── dataloader_vctk.py ├── datasets └── dataset.py ├── flops.py ├── inference.py ├── models ├── codec_module.py ├── discriminator.py ├── generator.py ├── loss.py ├── lsigmoid.py ├── mamba_block.py ├── pcs400.py └── stfts.py ├── recipes ├── Mamba-SEUNet-PCS │ └── Mamba-SEUNet-PCS.yaml └── Mamba-SEUNet │ └── Mamba-SEUNet.yaml ├── requirements.txt ├── train.py └── utils └── util.py /README.md: -------------------------------------------------------------------------------- 1 | # Mamba-SEUNet: Mamba UNet for Monaural Speech Enhancement (Accepted at ICASSP 2025) 2 | 3 | **Abstract:** 4 | In recent speech enhancement (SE) research, transformer and its variants have emerged as the predominant methodologies. However, the quadratic complexity of the self-attention mechanism imposes certain limitations on practical deployment. Mamba, as a novel state-space model (SSM), has gained widespread application in natural language processing and computer vision due to its strong capabilities in modeling long sequences and relatively low computational complexity. In this work, we introduce Mamba-SEUNet, an innovative architecture that integrates Mamba with U-Net for SE tasks. By leveraging bidirectional Mamba to model forward and backward dependencies of speech signals at different resolutions, and incorporating skip connections to capture multi-scale information, our approach achieves state-of-the-art (SOTA) performance. Experimental results on the VCTK+DEMAND dataset indicate that Mamba-SEUNet attains a PESQ score of 3.59, while maintaining low computational complexity. When combined with the Perceptual Contrast Stretching technique, Mamba-SEUNet further improves the PESQ score to 3.73. 5 | 6 | ## Pre-requisites 7 | 1. Python >= 3.8. 8 | 2. Clone this repository. 9 | 3. Install Python requirements. Please refer to requirements.txt. 10 | 4. Download and extract the [VoiceBank+DEMAND dataset](https://datashare.ed.ac.uk/handle/10283/1942). 11 | 12 | ## Training 13 | For a single GPU (recommended), Mamba-SEUNet needs at least 12 GB of GPU memory.
14 | ``` 15 | python train.py 16 | ``` 17 | 18 | ## Training with your own data 19 | Generate the six dataset JSON files using data/make_dataset_json.py: 20 | ``` 21 | python make_dataset_json.py 22 | ``` 23 | 24 | ## Inference 25 | ``` 26 | python inference.py --checkpoint_file /PATH/TO/YOUR/CHECK_POINT/g_xxxxxxx 27 | ``` 28 | 29 | ## Acknowledgements 30 | We referred to [MP-SENet](https://github.com/yxlu-0102/MP-SENet), [MUSE](https://github.com/huaidanquede/MUSE-Speech-Enhancement), and [SEMamba](https://github.com/RoyChao19477/SEMamba). 31 | -------------------------------------------------------------------------------- /cal_metrics/cal_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import librosa 4 | import numpy as np 5 | from compute_metrics import compute_metrics 6 | from rich.progress import track 7 | 8 | def get_dataset_filelist(h): 9 | with open(h.input_test_file, 'r', encoding='utf-8') as fi: 10 | indexes = [x.split('|')[0] for x in fi.read().split('\n') if len(x) > 0] 11 | 12 | return indexes 13 | 14 | def main(h): 15 | indexes = get_dataset_filelist(h) 16 | num = len(indexes) 17 | print(num) 18 | metrics_total = np.zeros(6) 19 | for index in track(indexes): 20 | clean_wav = os.path.join(h.clean_wav_dir, index + '.wav') 21 | noisy_wav = os.path.join(h.noisy_wav_dir, index + '.wav') 22 | clean, sr = librosa.load(clean_wav, sr=h.sampling_rate) 23 | noisy, sr = librosa.load(noisy_wav, sr=h.sampling_rate) 24 | 25 | metrics = compute_metrics(clean, noisy, sr, 0) 26 | metrics = np.array(metrics) 27 | metrics_total += metrics 28 | 29 | metrics_avg = metrics_total / num 30 | print('pesq: ', metrics_avg[0], 'csig: ', metrics_avg[1], 'cbak: ', metrics_avg[2], 31 | 'covl: ', metrics_avg[3], 'ssnr: ', metrics_avg[4], 'stoi: ', metrics_avg[5]) 32 | 33 | if __name__ == '__main__': 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--sampling_rate', default=16000, type=int) 36 | parser.add_argument('--input_test_file', default='dataset_se/test.txt') 37 | parser.add_argument('--clean_wav_dir', default='dataset_se/testset_clean') 38 | parser.add_argument('--noisy_wav_dir', default='generated_files/MP-SENet') 39 | 40 | h = parser.parse_args() 41 | 42 | main(h) -------------------------------------------------------------------------------- /cal_metrics/compute_metrics.py: -------------------------------------------------------------------------------- 1 | # Copied from "https://github.com/ruizhecao96/CMGAN/blob/main/src/tools/compute_metrics.py" 2 | 3 | import numpy as np 4 | from scipy.io import wavfile 5 | from scipy.linalg import toeplitz, norm 6 | from scipy.fftpack import fft 7 | import math 8 | from scipy import signal 9 | from pesq import pesq 10 | 11 | ''' 12 | This is a Python script which can be regarded as an implementation of the MATLAB script "compute_metrics.m".
13 | 14 | Usage: 15 | pesq, csig, cbak, covl, ssnr, stoi = compute_metrics(cleanFile, enhancedFile, Fs, path) 16 | cleanFile: clean audio as array or path if path is equal to 1 17 | enhancedFile: enhanced audio as array or path if path is equal to 1 18 | Fs: sampling rate, usually equals to 8000 or 16000 Hz 19 | path: whether the "cleanFile" and "enhancedFile" arguments are in .wav format or in numpy array format, 20 | 1 indicates "in .wav format" 21 | 22 | Example call: 23 | pesq_output, csig_output, cbak_output, covl_output, ssnr_output, stoi_output = \ 24 | compute_metrics(target_audio, output_audio, 16000, 0) 25 | ''' 26 | 27 | 28 | def compute_metrics(cleanFile, enhancedFile, Fs, path): 29 | alpha = 0.95 30 | 31 | if path == 1: 32 | sampling_rate1, data1 = wavfile.read(cleanFile) 33 | sampling_rate2, data2 = wavfile.read(enhancedFile) 34 | if sampling_rate1 != sampling_rate2: 35 | raise ValueError('The two files do not match!\n') 36 | else: 37 | data1 = cleanFile 38 | data2 = enhancedFile 39 | sampling_rate1 = Fs 40 | sampling_rate2 = Fs 41 | 42 | if len(data1) != len(data2): 43 | length = min(len(data1), len(data2)) 44 | data1 = data1[0: length] + np.spacing(1) 45 | data2 = data2[0: length] + np.spacing(1) 46 | 47 | # compute the WSS measure 48 | wss_dist_vec = wss(data1, data2, sampling_rate1) 49 | wss_dist_vec = np.sort(wss_dist_vec) 50 | wss_dist = np.mean(wss_dist_vec[0: round(np.size(wss_dist_vec) * alpha)]) 51 | 52 | # compute the LLR measure 53 | LLR_dist = llr(data1, data2, sampling_rate1) 54 | LLRs = np.sort(LLR_dist) 55 | LLR_len = round(np.size(LLR_dist) * alpha) 56 | llr_mean = np.mean(LLRs[0: LLR_len]) 57 | 58 | # compute the SNRseg 59 | snr_dist, segsnr_dist = snr(data1, data2, sampling_rate1) 60 | snr_mean = snr_dist 61 | segSNR = np.mean(segsnr_dist) 62 | 63 | # compute the pesq 64 | pesq_mos = pesq(sampling_rate1, data1, data2, 'wb') 65 | 66 | # now compute the composite measures 67 | CSIG = 3.093 - 1.029 * llr_mean + 0.603 * pesq_mos - 0.009 * wss_dist 68 | CSIG = max(1, CSIG) 69 | CSIG = min(5, CSIG) # limit values to [1, 5] 70 | CBAK = 1.634 + 0.478 * pesq_mos - 0.007 * wss_dist + 0.063 * segSNR 71 | CBAK = max(1, CBAK) 72 | CBAK = min(5, CBAK) # limit values to [1, 5] 73 | COVL = 1.594 + 0.805 * pesq_mos - 0.512 * llr_mean - 0.007 * wss_dist 74 | COVL = max(1, COVL) 75 | COVL = min(5, COVL) # limit values to [1, 5] 76 | 77 | STOI = stoi(data1, data2, sampling_rate1) 78 | 79 | return pesq_mos, CSIG, CBAK, COVL, segSNR, STOI 80 | 81 | 82 | def wss(clean_speech, processed_speech, sample_rate): 83 | # Check the length of the clean and processed speech, which must be the same. 
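# The WSS measure (Klatt, 1982) compares the critical-band spectral slopes of the clean
# and processed signals frame by frame; smaller values mean the processed speech preserves
# the clean spectral envelope better. compute_metrics() later averages the smallest 95%
# of the per-frame values (alpha = 0.95) to discount outlier frames.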
84 | clean_length = np.size(clean_speech) 85 | processed_length = np.size(processed_speech) 86 | if clean_length != processed_length: 87 | raise ValueError('Files must have same length.') 88 | 89 | # Global variables 90 | winlength = (np.round(30 * sample_rate / 1000)).astype(int) # window length in samples 91 | skiprate = (np.floor(np.divide(winlength, 4))).astype(int) # window skip in samples 92 | max_freq = (np.divide(sample_rate, 2)).astype(int) # maximum bandwidth 93 | num_crit = 25 # number of critical bands 94 | 95 | USE_FFT_SPECTRUM = 1 # defaults to 10th order LP spectrum 96 | n_fft = (np.power(2, np.ceil(np.log2(2 * winlength)))).astype(int) 97 | n_fftby2 = (np.multiply(0.5, n_fft)).astype(int) # FFT size/2 98 | Kmax = 20.0 # value suggested by Klatt, pg 1280 99 | Klocmax = 1.0 # value suggested by Klatt, pg 1280 100 | 101 | # Critical Band Filter Definitions (Center Frequency and Bandwidths in Hz) 102 | cent_freq = np.array([50.0000, 120.000, 190.000, 260.000, 330.000, 400.000, 470.000, 103 | 540.000, 617.372, 703.378, 798.717, 904.128, 1020.38, 1148.30, 104 | 1288.72, 1442.54, 1610.70, 1794.16, 1993.93, 2211.08, 2446.71, 105 | 2701.97, 2978.04, 3276.17, 3597.63]) 106 | bandwidth = np.array([70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 70.0000, 107 | 77.3724, 86.0056, 95.3398, 105.411, 116.256, 127.914, 140.423, 108 | 153.823, 168.154, 183.457, 199.776, 217.153, 235.631, 255.255, 109 | 276.072, 298.126, 321.465, 346.136]) 110 | 111 | bw_min = bandwidth[0] # minimum critical bandwidth 112 | 113 | # Set up the critical band filters. 114 | # Note here that Gaussianly shaped filters are used. 115 | # Also, the sum of the filter weights are equivalent for each critical band filter. 116 | # Filter less than -30 dB and set to zero. 117 | min_factor = math.exp(-30.0 / (2.0 * 2.303)) # -30 dB point of filter 118 | crit_filter = np.empty((num_crit, n_fftby2)) 119 | for i in range(num_crit): 120 | f0 = (cent_freq[i] / max_freq) * n_fftby2 121 | bw = (bandwidth[i] / max_freq) * n_fftby2 122 | norm_factor = np.log(bw_min) - np.log(bandwidth[i]) 123 | j = np.arange(n_fftby2) 124 | crit_filter[i, :] = np.exp(-11 * np.square(np.divide(j - np.floor(f0), bw)) + norm_factor) 125 | cond = np.greater(crit_filter[i, :], min_factor) 126 | crit_filter[i, :] = np.where(cond, crit_filter[i, :], 0) 127 | # For each frame of input speech, calculate the Weighted Spectral Slope Measure 128 | num_frames = int(clean_length / skiprate - (winlength / skiprate)) # number of frames 129 | start = 0 # starting sample 130 | window = 0.5 * (1 - np.cos(2 * math.pi * np.arange(1, winlength + 1) / (winlength + 1))) 131 | 132 | distortion = np.empty(num_frames) 133 | for frame_count in range(num_frames): 134 | # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window. 
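# The division by 32768 below mirrors the original MATLAB implementation, which assumed
# 16-bit integer samples; since WSS only uses dB differences (spectral slopes and
# peak-relative levels), this constant scaling does not change the final measure.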
135 | clean_frame = clean_speech[start: start + winlength] / 32768 136 | processed_frame = processed_speech[start: start + winlength] / 32768 137 | clean_frame = np.multiply(clean_frame, window) 138 | processed_frame = np.multiply(processed_frame, window) 139 | # (2) Compute the Power Spectrum of Clean and Processed 140 | # if USE_FFT_SPECTRUM: 141 | clean_spec = np.square(np.abs(fft(clean_frame, n_fft))) 142 | processed_spec = np.square(np.abs(fft(processed_frame, n_fft))) 143 | 144 | # (3) Compute Filterbank Output Energies (in dB scale) 145 | clean_energy = np.matmul(crit_filter, clean_spec[0:n_fftby2]) 146 | processed_energy = np.matmul(crit_filter, processed_spec[0:n_fftby2]) 147 | 148 | clean_energy = 10 * np.log10(np.maximum(clean_energy, 1E-10)) 149 | processed_energy = 10 * np.log10(np.maximum(processed_energy, 1E-10)) 150 | 151 | # (4) Compute Spectral Slope (dB[i+1]-dB[i]) 152 | clean_slope = clean_energy[1:num_crit] - clean_energy[0: num_crit - 1] 153 | processed_slope = processed_energy[1:num_crit] - processed_energy[0: num_crit - 1] 154 | 155 | # (5) Find the nearest peak locations in the spectra to each critical band. 156 | # If the slope is negative, we search to the left. If positive, we search to the right. 157 | clean_loc_peak = np.empty(num_crit - 1) 158 | processed_loc_peak = np.empty(num_crit - 1) 159 | 160 | for i in range(num_crit - 1): 161 | # find the peaks in the clean speech signal 162 | if clean_slope[i] > 0: # search to the right 163 | n = i 164 | while (n < num_crit - 1) and (clean_slope[n] > 0): 165 | n = n + 1 166 | clean_loc_peak[i] = clean_energy[n - 1] 167 | else: # search to the left 168 | n = i 169 | while (n >= 0) and (clean_slope[n] <= 0): 170 | n = n - 1 171 | clean_loc_peak[i] = clean_energy[n + 1] 172 | 173 | # find the peaks in the processed speech signal 174 | if processed_slope[i] > 0: # search to the right 175 | n = i 176 | while (n < num_crit - 1) and (processed_slope[n] > 0): 177 | n = n + 1 178 | processed_loc_peak[i] = processed_energy[n - 1] 179 | else: # search to the left 180 | n = i 181 | while (n >= 0) and (processed_slope[n] <= 0): 182 | n = n - 1 183 | processed_loc_peak[i] = processed_energy[n + 1] 184 | 185 | # (6) Compute the WSS Measure for this frame. This includes determination of the weighting function. 186 | dBMax_clean = np.max(clean_energy) 187 | dBMax_processed = np.max(processed_energy) 188 | ''' 189 | The weights are calculated by averaging individual weighting factors from the clean and processed frame. 190 | These weights W_clean and W_processed should range from 0 to 1 and place more emphasis on spectral peaks 191 | and less emphasis on slope differences in spectral valleys. 192 | This procedure is described on page 1280 of Klatt's 1982 ICASSP paper. 
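Concretely, for band i the clean-frame weight is
W_clean[i] = (Kmax / (Kmax + dBMax_clean - clean_energy[i])) * (Klocmax / (Klocmax + clean_loc_peak[i] - clean_energy[i])),
and likewise for the processed frame; the two are averaged into W below, and the frame
distortion is the W-weighted mean of the squared slope differences.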
193 | ''' 194 | Wmax_clean = np.divide(Kmax, Kmax + dBMax_clean - clean_energy[0: num_crit - 1]) 195 | Wlocmax_clean = np.divide(Klocmax, Klocmax + clean_loc_peak - clean_energy[0: num_crit - 1]) 196 | W_clean = np.multiply(Wmax_clean, Wlocmax_clean) 197 | 198 | Wmax_processed = np.divide(Kmax, Kmax + dBMax_processed - processed_energy[0: num_crit - 1]) 199 | Wlocmax_processed = np.divide(Klocmax, Klocmax + processed_loc_peak - processed_energy[0: num_crit - 1]) 200 | W_processed = np.multiply(Wmax_processed, Wlocmax_processed) 201 | 202 | W = np.divide(np.add(W_clean, W_processed), 2.0) 203 | slope_diff = np.subtract(clean_slope, processed_slope)[0: num_crit - 1] 204 | distortion[frame_count] = np.dot(W, np.square(slope_diff)) / np.sum(W) 205 | # this normalization is not part of Klatt's paper, but helps to normalize the measure. 206 | # Here we scale the measure by the sum of the weights. 207 | start = start + skiprate 208 | return distortion 209 | 210 | 211 | def llr(clean_speech, processed_speech,sample_rate): 212 | # Check the length of the clean and processed speech. Must be the same. 213 | clean_length = np.size(clean_speech) 214 | processed_length = np.size(processed_speech) 215 | if clean_length != processed_length: 216 | raise ValueError('Both Speech Files must be same length.') 217 | 218 | # Global Variables 219 | winlength = (np.round(30 * sample_rate / 1000)).astype(int) # window length in samples 220 | skiprate = (np.floor(winlength / 4)).astype(int) # window skip in samples 221 | if sample_rate < 10000: 222 | P = 10 # LPC Analysis Order 223 | else: 224 | P = 16 # this could vary depending on sampling frequency. 225 | 226 | # For each frame of input speech, calculate the Log Likelihood Ratio 227 | num_frames = int((clean_length - winlength) / skiprate) # number of frames 228 | start = 0 # starting sample 229 | window = 0.5 * (1 - np.cos(2 * math.pi * np.arange(1, winlength + 1) / (winlength + 1))) 230 | 231 | distortion = np.empty(num_frames) 232 | for frame_count in range(num_frames): 233 | # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window. 234 | clean_frame = clean_speech[start: start + winlength] 235 | processed_frame = processed_speech[start: start + winlength] 236 | clean_frame = np.multiply(clean_frame, window) 237 | processed_frame = np.multiply(processed_frame, window) 238 | 239 | # (2) Get the autocorrelation lags and LPC parameters used to compute the LLR measure. 
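# lpcoeff() returns the autocorrelation lags R, the reflection coefficients, and the LPC
# polynomial a = [1, -a_1, ..., -a_P]. The frame LLR is then
# log( (a_proc . Toeplitz(R_clean) . a_proc) / (a_clean . Toeplitz(R_clean) . a_clean) ),
# i.e. the log ratio of clean-signal prediction errors under the processed vs. clean LPC models.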
240 | R_clean, Ref_clean, A_clean = lpcoeff(clean_frame, P) 241 | R_processed, Ref_processed, A_processed = lpcoeff(processed_frame, P) 242 | 243 | # (3) Compute the LLR measure 244 | numerator = np.dot(np.matmul(A_processed, toeplitz(R_clean)), A_processed) 245 | denominator = np.dot(np.matmul(A_clean, toeplitz(R_clean)), A_clean) 246 | distortion[frame_count] = math.log(numerator / denominator) 247 | start = start + skiprate 248 | return distortion 249 | 250 | 251 | def lpcoeff(speech_frame, model_order): 252 | # (1) Compute Autocorrelation Lags 253 | winlength = np.size(speech_frame) 254 | R = np.empty(model_order + 1) 255 | E = np.empty(model_order + 1) 256 | for k in range(model_order + 1): 257 | R[k] = np.dot(speech_frame[0:winlength - k], speech_frame[k: winlength]) 258 | 259 | # (2) Levinson-Durbin 260 | a = np.ones(model_order) 261 | a_past = np.empty(model_order) 262 | rcoeff = np.empty(model_order) 263 | E[0] = R[0] 264 | for i in range(model_order): 265 | a_past[0: i] = a[0: i] 266 | sum_term = np.dot(a_past[0: i], R[i:0:-1]) 267 | rcoeff[i] = (R[i + 1] - sum_term) / E[i] 268 | a[i] = rcoeff[i] 269 | if i == 0: 270 | a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1:-1:-1], rcoeff[i]) 271 | else: 272 | a[0: i] = a_past[0: i] - np.multiply(a_past[i - 1::-1], rcoeff[i]) 273 | E[i + 1] = (1 - rcoeff[i] * rcoeff[i]) * E[i] 274 | acorr = R 275 | refcoeff = rcoeff 276 | lpparams = np.concatenate((np.array([1]), -a)) 277 | return acorr, refcoeff, lpparams 278 | 279 | 280 | def snr(clean_speech, processed_speech, sample_rate): 281 | # Check the length of the clean and processed speech. Must be the same. 282 | clean_length = len(clean_speech) 283 | processed_length = len(processed_speech) 284 | if clean_length != processed_length: 285 | raise ValueError('Both Speech Files must be same length.') 286 | 287 | overall_snr = 10 * np.log10(np.sum(np.square(clean_speech)) / np.sum(np.square(clean_speech - processed_speech))) 288 | 289 | # Global Variables 290 | winlength = round(30 * sample_rate / 1000) # window length in samples 291 | skiprate = math.floor(winlength / 4) # window skip in samples 292 | MIN_SNR = -10 # minimum SNR in dB 293 | MAX_SNR = 35 # maximum SNR in dB 294 | 295 | # For each frame of input speech, calculate the Segmental SNR 296 | num_frames = int(clean_length / skiprate - (winlength / skiprate)) # number of frames 297 | start = 0 # starting sample 298 | window = 0.5 * (1 - np.cos(2 * math.pi * np.arange(1, winlength + 1) / (winlength + 1))) 299 | 300 | segmental_snr = np.empty(num_frames) 301 | EPS = np.spacing(1) 302 | for frame_count in range(num_frames): 303 | # (1) Get the Frames for the test and reference speech. Multiply by Hanning Window. 
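# Each frame's segmental SNR below is 10*log10(clean-frame energy / error energy), clamped
# to [MIN_SNR, MAX_SNR] = [-10, 35] dB; compute_metrics() reports the mean over frames as SSNR.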
304 | clean_frame = clean_speech[start:start + winlength] 305 | processed_frame = processed_speech[start:start + winlength] 306 | clean_frame = np.multiply(clean_frame, window) 307 | processed_frame = np.multiply(processed_frame, window) 308 | 309 | # (2) Compute the Segmental SNR 310 | signal_energy = np.sum(np.square(clean_frame)) 311 | noise_energy = np.sum(np.square(clean_frame - processed_frame)) 312 | segmental_snr[frame_count] = 10 * math.log10(signal_energy / (noise_energy + EPS) + EPS) 313 | segmental_snr[frame_count] = max(segmental_snr[frame_count], MIN_SNR) 314 | segmental_snr[frame_count] = min(segmental_snr[frame_count], MAX_SNR) 315 | 316 | start = start + skiprate 317 | 318 | return overall_snr, segmental_snr 319 | 320 | 321 | def stoi(x, y, fs_signal): 322 | if np.size(x) != np.size(y): 323 | raise ValueError('x and y should have the same length') 324 | 325 | # initialization, pay attention to the range of x and y(divide by 32768?) 326 | fs = 10000 # sample rate of proposed intelligibility measure 327 | N_frame = 256 # window support 328 | K = 512 # FFT size 329 | J = 15 # Number of 1/3 octave bands 330 | mn = 150 # Center frequency of first 1/3 octave band in Hz 331 | H, _ = thirdoct(fs, K, J, mn) # Get 1/3 octave band matrix 332 | N = 30 # Number of frames for intermediate intelligibility measure (Length analysis window) 333 | Beta = -15 # lower SDR-bound 334 | dyn_range = 40 # speech dynamic range 335 | 336 | # resample signals if other sample rate is used than fs 337 | if fs_signal != fs: 338 | x = signal.resample_poly(x, fs, fs_signal) 339 | y = signal.resample_poly(y, fs, fs_signal) 340 | 341 | # remove silent frames 342 | x, y = removeSilentFrames(x, y, dyn_range, N_frame, int(N_frame / 2)) 343 | 344 | # apply 1/3 octave band TF-decomposition 345 | x_hat = stdft(x, N_frame, N_frame / 2, K) # apply short-time DFT to clean speech 346 | y_hat = stdft(y, N_frame, N_frame / 2, K) # apply short-time DFT to processed speech 347 | 348 | x_hat = np.transpose(x_hat[:, 0:(int(K / 2) + 1)]) # take clean single-sided spectrum 349 | y_hat = np.transpose(y_hat[:, 0:(int(K / 2) + 1)]) # take processed single-sided spectrum 350 | 351 | X = np.sqrt(np.matmul(H, np.square(np.abs(x_hat)))) # apply 1/3 octave bands as described in Eq.(1) [1] 352 | Y = np.sqrt(np.matmul(H, np.square(np.abs(y_hat)))) 353 | 354 | # loop al segments of length N and obtain intermediate intelligibility measure for all TF-regions 355 | d_interm = np.zeros(np.size(np.arange(N - 1, x_hat.shape[1]))) 356 | # init memory for intermediate intelligibility measure 357 | c = 10 ** (-Beta / 20) 358 | # constant for clipping procedure 359 | 360 | for m in range(N - 1, x_hat.shape[1]): 361 | X_seg = X[:, (m - N + 1): (m + 1)] # region with length N of clean TF-units for all j 362 | Y_seg = Y[:, (m - N + 1): (m + 1)] # region with length N of processed TF-units for all j 363 | # obtain scale factor for normalizing processed TF-region for all j 364 | alpha = np.sqrt(np.divide(np.sum(np.square(X_seg), axis=1, keepdims=True), 365 | np.sum(np.square(Y_seg), axis=1, keepdims=True))) 366 | # obtain \alpha*Y_j(n) from Eq.(2) [1] 367 | aY_seg = np.multiply(Y_seg, alpha) 368 | # apply clipping from Eq.(3) 369 | Y_prime = np.minimum(aY_seg, X_seg + X_seg * c) 370 | # obtain correlation coeffecient from Eq.(4) [1] 371 | d_interm[m - N + 1] = taa_corr(X_seg, Y_prime) / J 372 | 373 | d = d_interm.mean() # combine all intermediate intelligibility measures as in Eq.(4) [1] 374 | return d 375 | 376 | 377 | def thirdoct(fs, N_fft, 
numBands, mn): 378 | """ 379 | [A CF] = THIRDOCT(FS, N_FFT, NUMBANDS, MN) returns 1/3 octave band matrix 380 | inputs: 381 | FS: samplerate 382 | N_FFT: FFT size 383 | NUMBANDS: number of bands 384 | MN: center frequency of first 1/3 octave band 385 | outputs: 386 | A: octave band matrix 387 | CF: center frequencies 388 | """ 389 | f = np.linspace(0, fs, N_fft + 1) 390 | f = f[0:int(N_fft / 2 + 1)] 391 | k = np.arange(numBands) 392 | cf = np.multiply(np.power(2, k / 3), mn) 393 | fl = np.sqrt(np.multiply(np.multiply(np.power(2, k / 3), mn), np.multiply(np.power(2, (k - 1) / 3), mn))) 394 | fr = np.sqrt(np.multiply(np.multiply(np.power(2, k / 3), mn), np.multiply(np.power(2, (k + 1) / 3), mn))) 395 | A = np.zeros((numBands, len(f))) 396 | 397 | for i in range(np.size(cf)): 398 | b = np.argmin((f - fl[i]) ** 2) 399 | fl[i] = f[b] 400 | fl_ii = b 401 | 402 | b = np.argmin((f - fr[i]) ** 2) 403 | fr[i] = f[b] 404 | fr_ii = b 405 | A[i, fl_ii: fr_ii] = 1 406 | 407 | rnk = np.sum(A, axis=1) 408 | end = np.size(rnk) 409 | rnk_back = rnk[1: end] 410 | rnk_before = rnk[0: (end-1)] 411 | for i in range(np.size(rnk_back)): 412 | if (rnk_back[i] >= rnk_before[i]) and (rnk_back[i] != 0): 413 | result = i 414 | numBands = result + 2 415 | A = A[0:numBands, :] 416 | cf = cf[0:numBands] 417 | return A, cf 418 | 419 | 420 | def stdft(x, N, K, N_fft): 421 | """ 422 | X_STDFT = X_STDFT(X, N, K, N_FFT) returns the short-time hanning-windowed dft of X with frame-size N, 423 | overlap K and DFT size N_FFT. The columns and rows of X_STDFT denote the frame-index and dft-bin index, 424 | respectively. 425 | """ 426 | frames_size = int((np.size(x) - N) / K) 427 | w = signal.windows.hann(N+2) 428 | w = w[1: N+1] 429 | 430 | x_stdft = signal.stft(x, window=w, nperseg=N, noverlap=K, nfft=N_fft, return_onesided=False, boundary=None)[2] 431 | x_stdft = np.transpose(x_stdft)[0:frames_size, :] 432 | 433 | return x_stdft 434 | 435 | 436 | def removeSilentFrames(x, y, dyrange, N, K): 437 | """ 438 | [X_SIL Y_SIL] = REMOVESILENTFRAMES(X, Y, RANGE, N, K) X and Y are segmented with frame-length N 439 | and overlap K, where the maximum energy of all frames of X is determined, say X_MAX. 440 | X_SIL and Y_SIL are the reconstructed signals, excluding the frames, where the energy of a frame 441 | of X is smaller than X_MAX-RANGE 442 | """ 443 | 444 | frames = np.arange(0, (np.size(x) - N), K) 445 | w = signal.windows.hann(N+2) 446 | w = w[1: N+1] 447 | 448 | jj_list = np.empty((np.size(frames), N), dtype=int) 449 | for j in range(np.size(frames)): 450 | jj_list[j, :] = np.arange(frames[j] - 1, frames[j] + N - 1) 451 | 452 | msk = 20 * np.log10(np.divide(norm(np.multiply(x[jj_list], w), axis=1), np.sqrt(N))) 453 | 454 | msk = (msk - np.max(msk) + dyrange) > 0 455 | count = 0 456 | 457 | x_sil = np.zeros(np.size(x)) 458 | y_sil = np.zeros(np.size(y)) 459 | 460 | for j in range(np.size(frames)): 461 | if msk[j]: 462 | jj_i = np.arange(frames[j], frames[j] + N) 463 | jj_o = np.arange(frames[count], frames[count] + N) 464 | x_sil[jj_o] = x_sil[jj_o] + np.multiply(x[jj_i], w) 465 | y_sil[jj_o] = y_sil[jj_o] + np.multiply(y[jj_i], w) 466 | count = count + 1 467 | 468 | x_sil = x_sil[0: jj_o[-1] + 1] 469 | y_sil = y_sil[0: jj_o[-1] + 1] 470 | return x_sil, y_sil 471 | 472 | 473 | def taa_corr(x, y): 474 | """ 475 | RHO = TAA_CORR(X, Y) Returns correlation coeffecient between column 476 | vectors x and y. Gives same results as 'corr' from statistics toolbox. 
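Each row of x and y is mean-removed and unit-normalised below, so the trace of xn @ yn.T
is the sum of the per-band (row-wise) Pearson correlations; the caller (stoi) divides the
result by the number of bands J.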
477 | """ 478 | xn = np.subtract(x, np.mean(x, axis=1, keepdims=True)) 479 | xn = np.divide(xn, norm(xn, axis=1, keepdims=True)) 480 | yn = np.subtract(y, np.mean(y, axis=1, keepdims=True)) 481 | yn = np.divide(yn, norm(yn, axis=1, keepdims=True)) 482 | rho = np.trace(np.matmul(xn, np.transpose(yn))) 483 | 484 | return rho 485 | -------------------------------------------------------------------------------- /cal_metrics/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from pesq import pesq 4 | from pystoi import stoi 5 | import sys 6 | import librosa 7 | from scipy.linalg import toeplitz 8 | from tqdm import tqdm 9 | # from .utils import * 10 | 11 | def get_scores(clean, estimate, args): 12 | estimate = estimate.numpy()[:, 0] 13 | clean = clean.numpy()[:, 0] 14 | 15 | pesq_i = get_pesq(clean, estimate, sr=args.sample_rate) 16 | stoi_i = get_stoi(clean, estimate, sr=args.sample_rate) 17 | 18 | return pesq_i, stoi_i 19 | 20 | def get_pesq(ref_sig, out_sig, sr): 21 | """Calculate PESQ. 22 | Args: 23 | ref_sig: numpy.ndarray, [B, T] 24 | out_sig: numpy.ndarray, [B, T] 25 | Returns: 26 | PESQ 27 | """ 28 | pesq_val = 0 29 | for i in range(len(ref_sig)): 30 | try: 31 | pesq_val += pesq(sr, ref_sig[i], out_sig[i], 'wb') 32 | except: 33 | pesq_val = 3.1 34 | continue 35 | return pesq_val 36 | 37 | 38 | def get_stoi(ref_sig, out_sig, sr): 39 | """Calculate STOI. 40 | Args: 41 | ref_sig: numpy.ndarray, [B, T] 42 | out_sig: numpy.ndarray, [B, T] 43 | Returns: 44 | STOI 45 | """ 46 | stoi_val = 0 47 | for i in range(len(ref_sig)): 48 | stoi_val += stoi(ref_sig[i], out_sig[i], sr, extended=False) 49 | return stoi_val 50 | 51 | 52 | # Original copyright: 53 | # The copy right is under the MIT license. 
54 | # SEGAN (https://github.com/santi-pdp/segan_pytorch) / author: santi-pdp 55 | 56 | def eval_composite(ref_wav, deg_wav): 57 | ref_wav = ref_wav.reshape(-1) 58 | deg_wav = deg_wav.reshape(-1) 59 | 60 | alpha = 0.95 61 | len_ = min(ref_wav.shape[0], deg_wav.shape[0]) 62 | ref_wav = ref_wav[:len_] 63 | ref_len = ref_wav.shape[0] 64 | deg_wav = deg_wav[:len_] 65 | 66 | # Compute WSS measure 67 | wss_dist_vec = wss(ref_wav, deg_wav, 16000) 68 | wss_dist_vec = sorted(wss_dist_vec, reverse=False) 69 | wss_dist = np.mean(wss_dist_vec[:int(round(len(wss_dist_vec) * alpha))]) 70 | 71 | # Compute LLR measure 72 | LLR_dist = llr(ref_wav, deg_wav, 16000) 73 | LLR_dist = sorted(LLR_dist, reverse=False) 74 | LLRs = LLR_dist 75 | LLR_len = round(len(LLR_dist) * alpha) 76 | llr_mean = np.mean(LLRs[:LLR_len]) 77 | 78 | # Compute the SSNR 79 | snr_mean, segsnr_mean = SSNR(ref_wav, deg_wav, 16000) 80 | segSNR = np.mean(segsnr_mean) 81 | 82 | # Compute the PESQ 83 | pesq_raw = PESQ(ref_wav, deg_wav) 84 | 85 | Csig = 3.093 - 1.029 * llr_mean + 0.603 * pesq_raw - 0.009 * wss_dist 86 | Csig = trim_mos(Csig) 87 | Cbak = 1.634 + 0.478 * pesq_raw - 0.007 * wss_dist + 0.063 * segSNR 88 | Cbak = trim_mos(Cbak) 89 | Covl = 1.594 + 0.805 * pesq_raw - 0.512 * llr_mean - 0.007 * wss_dist 90 | Covl = trim_mos(Covl) 91 | 92 | return {'csig':Csig, 'cbak':Cbak, 'covl':Covl, 'pesq':pesq_raw} 93 | 94 | # ----------------------------- HELPERS ------------------------------------ # 95 | def trim_mos(val): 96 | return min(max(val, 1), 5) 97 | 98 | def lpcoeff(speech_frame, model_order): 99 | # (1) Compute Autocor lags 100 | winlength = speech_frame.shape[0] 101 | R = [] 102 | for k in range(model_order + 1): 103 | first = speech_frame[:(winlength - k)] 104 | second = speech_frame[k:winlength] 105 | R.append(np.sum(first * second)) 106 | 107 | # (2) Lev-Durbin 108 | a = np.ones((model_order,)) 109 | E = np.zeros((model_order + 1,)) 110 | rcoeff = np.zeros((model_order,)) 111 | E[0] = R[0] 112 | for i in range(model_order): 113 | if i == 0: 114 | sum_term = 0 115 | else: 116 | a_past = a[:i] 117 | sum_term = np.sum(a_past * np.array(R[i:0:-1])) 118 | rcoeff[i] = (R[i+1] - sum_term)/E[i] 119 | a[i] = rcoeff[i] 120 | if i > 0: 121 | a[:i] = a_past[:i] - rcoeff[i] * a_past[::-1] 122 | E[i+1] = (1-rcoeff[i]*rcoeff[i])*E[i] 123 | acorr = np.array(R, dtype=np.float32) 124 | refcoeff = np.array(rcoeff, dtype=np.float32) 125 | a = a * -1 126 | lpparams = np.array([1] + list(a), dtype=np.float32) 127 | acorr = np.array(acorr, dtype=np.float32) 128 | refcoeff = np.array(refcoeff, dtype=np.float32) 129 | lpparams = np.array(lpparams, dtype=np.float32) 130 | 131 | return acorr, refcoeff, lpparams 132 | # -------------------------------------------------------------------------- # 133 | 134 | # ---------------------- Speech Quality Metric ----------------------------- # 135 | def PESQ(ref_wav, deg_wav): 136 | rate = 16000 137 | return pesq(rate, ref_wav, deg_wav, 'wb') 138 | 139 | def SSNR(ref_wav, deg_wav, srate=16000, eps=1e-10): 140 | """ Segmental Signal-to-Noise Ratio Objective Speech Quality Measure 141 | This function implements the segmental signal-to-noise ratio 142 | as defined in [1, p. 45] (see Equation 2.12). 143 | """ 144 | clean_speech = ref_wav 145 | processed_speech = deg_wav 146 | clean_length = ref_wav.shape[0] 147 | processed_length = deg_wav.shape[0] 148 | 149 | # scale both to have same dynamic range. Remove DC too. 
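# NOTE: the in-place operations below also modify the arrays held by the caller, because
# clean_speech and processed_speech are aliases of ref_wav and deg_wav.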
150 | clean_speech -= clean_speech.mean() 151 | processed_speech -= processed_speech.mean() 152 | processed_speech *= (np.max(np.abs(clean_speech)) / np.max(np.abs(processed_speech))) 153 | 154 | # Signal-to-Noise Ratio 155 | dif = ref_wav - deg_wav 156 | overall_snr = 10 * np.log10(np.sum(ref_wav ** 2) / (np.sum(dif ** 2) + 157 | 10e-20)) 158 | # global variables 159 | winlength = int(np.round(30 * srate / 1000)) # 30 msecs 160 | skiprate = winlength // 4 161 | MIN_SNR = -10 162 | MAX_SNR = 35 163 | 164 | # For each frame, calculate SSNR 165 | num_frames = int(clean_length / skiprate - (winlength/skiprate)) 166 | start = 0 167 | time = np.linspace(1, winlength, winlength) / (winlength + 1) 168 | window = 0.5 * (1 - np.cos(2 * np.pi * time)) 169 | segmental_snr = [] 170 | 171 | for frame_count in range(int(num_frames)): 172 | # (1) get the frames for the test and ref speech. 173 | # Apply Hanning Window 174 | clean_frame = clean_speech[start:start+winlength] 175 | processed_frame = processed_speech[start:start+winlength] 176 | clean_frame = clean_frame * window 177 | processed_frame = processed_frame * window 178 | 179 | # (2) Compute Segmental SNR 180 | signal_energy = np.sum(clean_frame ** 2) 181 | noise_energy = np.sum((clean_frame - processed_frame) ** 2) 182 | segmental_snr.append(10 * np.log10(signal_energy / (noise_energy + eps)+ eps)) 183 | segmental_snr[-1] = max(segmental_snr[-1], MIN_SNR) 184 | segmental_snr[-1] = min(segmental_snr[-1], MAX_SNR) 185 | start += int(skiprate) 186 | return overall_snr, segmental_snr 187 | 188 | def wss(ref_wav, deg_wav, srate): 189 | clean_speech = ref_wav 190 | processed_speech = deg_wav 191 | clean_length = ref_wav.shape[0] 192 | processed_length = deg_wav.shape[0] 193 | 194 | assert clean_length == processed_length, clean_length 195 | 196 | winlength = round(30 * srate / 1000.) # 240 wlen in samples 197 | skiprate = np.floor(winlength / 4) 198 | max_freq = srate / 2 199 | num_crit = 25 # num of critical bands 200 | 201 | USE_FFT_SPECTRUM = 1 202 | n_fft = int(2 ** np.ceil(np.log(2*winlength)/np.log(2))) 203 | n_fftby2 = int(n_fft / 2) 204 | Kmax = 20 205 | Klocmax = 1 206 | 207 | # Critical band filter definitions (Center frequency and BW in Hz) 208 | cent_freq = [50., 120, 190, 260, 330, 400, 470, 540, 617.372, 209 | 703.378, 798.717, 904.128, 1020.38, 1148.30, 210 | 1288.72, 1442.54, 1610.70, 1794.16, 1993.93, 211 | 2211.08, 2446.71, 2701.97, 2978.04, 3276.17, 212 | 3597.63] 213 | bandwidth = [70., 70, 70, 70, 70, 70, 70, 77.3724, 86.0056, 214 | 95.3398, 105.411, 116.256, 127.914, 140.423, 215 | 153.823, 168.154, 183.457, 199.776, 217.153, 216 | 235.631, 255.255, 276.072, 298.126, 321.465, 217 | 346.136] 218 | 219 | bw_min = bandwidth[0] # min critical bandwidth 220 | 221 | # set up critical band filters. Note here that Gaussianly shaped filters 222 | # are used. Also, the sum of the filter weights are equivalent for each 223 | # critical band filter. Filter less than -30 dB and set to zero. 224 | min_factor = np.exp(-30. 
/ (2 * 2.303)) # -30 dB point of filter 225 | 226 | crit_filter = np.zeros((num_crit, n_fftby2)) 227 | all_f0 = [] 228 | for i in range(num_crit): 229 | f0 = (cent_freq[i] / max_freq) * (n_fftby2) 230 | all_f0.append(np.floor(f0)) 231 | bw = (bandwidth[i] / max_freq) * (n_fftby2) 232 | norm_factor = np.log(bw_min) - np.log(bandwidth[i]) 233 | j = list(range(n_fftby2)) 234 | crit_filter[i, :] = np.exp(-11 * (((j - np.floor(f0)) / bw) ** 2) + \ 235 | norm_factor) 236 | crit_filter[i, :] = crit_filter[i, :] * (crit_filter[i, :] > \ 237 | min_factor) 238 | 239 | # For each frame of input speech, compute Weighted Spectral Slope Measure 240 | num_frames = int(clean_length / skiprate - (winlength / skiprate)) 241 | start = 0 # starting sample 242 | time = np.linspace(1, winlength, winlength) / (winlength + 1) 243 | window = 0.5 * (1 - np.cos(2 * np.pi * time)) 244 | distortion = [] 245 | 246 | for frame_count in range(num_frames): 247 | # (1) Get the Frames for the test and reference speeech. 248 | # Multiply by Hanning window. 249 | clean_frame = clean_speech[start:start+winlength] 250 | processed_frame = processed_speech[start:start+winlength] 251 | clean_frame = clean_frame * window 252 | processed_frame = processed_frame * window 253 | 254 | # (2) Compuet Power Spectrum of clean and processed 255 | clean_spec = (np.abs(np.fft.fft(clean_frame, n_fft)) ** 2) 256 | processed_spec = (np.abs(np.fft.fft(processed_frame, n_fft)) ** 2) 257 | clean_energy = [None] * num_crit 258 | processed_energy = [None] * num_crit 259 | 260 | # (3) Compute Filterbank output energies (in dB) 261 | for i in range(num_crit): 262 | clean_energy[i] = np.sum(clean_spec[:n_fftby2] * \ 263 | crit_filter[i, :]) 264 | processed_energy[i] = np.sum(processed_spec[:n_fftby2] * \ 265 | crit_filter[i, :]) 266 | clean_energy = np.array(clean_energy).reshape(-1, 1) 267 | eps = np.ones((clean_energy.shape[0], 1)) * 1e-10 268 | clean_energy = np.concatenate((clean_energy, eps), axis=1) 269 | clean_energy = 10 * np.log10(np.max(clean_energy, axis=1)) 270 | processed_energy = np.array(processed_energy).reshape(-1, 1) 271 | processed_energy = np.concatenate((processed_energy, eps), axis=1) 272 | processed_energy = 10 * np.log10(np.max(processed_energy, axis=1)) 273 | 274 | # (4) Compute Spectral Shape (dB[i+1] - dB[i]) 275 | clean_slope = clean_energy[1:num_crit] - clean_energy[:num_crit-1] 276 | processed_slope = processed_energy[1:num_crit] - \ 277 | processed_energy[:num_crit-1] 278 | 279 | # (5) Find the nearest peak locations in the spectra to each 280 | # critical band. If the slope is negative, we search 281 | # to the left. If positive, we search to the right. 282 | clean_loc_peak = [] 283 | processed_loc_peak = [] 284 | for i in range(num_crit - 1): 285 | if clean_slope[i] > 0: 286 | # search to the right 287 | n = i 288 | while n < num_crit - 1 and clean_slope[n] > 0: 289 | n += 1 290 | clean_loc_peak.append(clean_energy[n - 1]) 291 | else: 292 | # search to the left 293 | n = i 294 | while n >= 0 and clean_slope[n] <= 0: 295 | n -= 1 296 | clean_loc_peak.append(clean_energy[n + 1]) 297 | # find the peaks in the processed speech signal 298 | if processed_slope[i] > 0: 299 | n = i 300 | while n < num_crit - 1 and processed_slope[n] > 0: 301 | n += 1 302 | processed_loc_peak.append(processed_energy[n - 1]) 303 | else: 304 | n = i 305 | while n >= 0 and processed_slope[n] <= 0: 306 | n -= 1 307 | processed_loc_peak.append(processed_energy[n + 1]) 308 | 309 | # (6) Compuet the WSS Measure for this frame. 
This includes 310 | # determination of the weighting functino 311 | dBMax_clean = max(clean_energy) 312 | dBMax_processed = max(processed_energy) 313 | 314 | # The weights are calculated by averaging individual 315 | # weighting factors from the clean and processed frame. 316 | # These weights W_clean and W_processed should range 317 | # from 0 to 1 and place more emphasis on spectral 318 | # peaks and less emphasis on slope differences in spectral 319 | # valleys. This procedure is described on page 1280 of 320 | # Klatt's 1982 ICASSP paper. 321 | clean_loc_peak = np.array(clean_loc_peak) 322 | processed_loc_peak = np.array(processed_loc_peak) 323 | Wmax_clean = Kmax / (Kmax + dBMax_clean - clean_energy[:num_crit-1]) 324 | Wlocmax_clean = Klocmax / (Klocmax + clean_loc_peak - \ 325 | clean_energy[:num_crit-1]) 326 | W_clean = Wmax_clean * Wlocmax_clean 327 | Wmax_processed = Kmax / (Kmax + dBMax_processed - \ 328 | processed_energy[:num_crit-1]) 329 | Wlocmax_processed = Klocmax / (Klocmax + processed_loc_peak - \ 330 | processed_energy[:num_crit-1]) 331 | W_processed = Wmax_processed * Wlocmax_processed 332 | W = (W_clean + W_processed) / 2 333 | distortion.append(np.sum(W * (clean_slope[:num_crit - 1] - \ 334 | processed_slope[:num_crit - 1]) ** 2)) 335 | 336 | # this normalization is not part of Klatt's paper, but helps 337 | # to normalize the meaasure. Here we scale the measure by the sum of the 338 | # weights 339 | distortion[frame_count] = distortion[frame_count] / np.sum(W) 340 | start += int(skiprate) 341 | return distortion 342 | 343 | 344 | def llr(ref_wav, deg_wav, srate): 345 | clean_speech = ref_wav 346 | processed_speech = deg_wav 347 | clean_length = ref_wav.shape[0] 348 | processed_length = deg_wav.shape[0] 349 | assert clean_length == processed_length, clean_length 350 | 351 | winlength = round(30 * srate / 1000.) # 240 wlen in samples 352 | skiprate = np.floor(winlength / 4) 353 | if srate < 10000: 354 | # LPC analysis order 355 | P = 10 356 | else: 357 | P = 16 358 | 359 | # For each frame of input speech, calculate the Log Likelihood Ratio 360 | num_frames = int(clean_length / skiprate - (winlength / skiprate)) 361 | start = 0 362 | time = np.linspace(1, winlength, winlength) / (winlength + 1) 363 | window = 0.5 * (1 - np.cos(2 * np.pi * time)) 364 | distortion = [] 365 | 366 | for frame_count in range(num_frames): 367 | # (1) Get the Frames for the test and reference speeech. 368 | # Multiply by Hanning window. 
369 | clean_frame = clean_speech[start:start+winlength] 370 | processed_frame = processed_speech[start:start+winlength] 371 | clean_frame = clean_frame * window 372 | processed_frame = processed_frame * window 373 | 374 | # (2) Get the autocorrelation logs and LPC params used 375 | # to compute the LLR measure 376 | R_clean, Ref_clean, A_clean = lpcoeff(clean_frame, P) 377 | R_processed, Ref_processed, A_processed = lpcoeff(processed_frame, P) 378 | A_clean = A_clean[None, :] 379 | A_processed = A_processed[None, :] 380 | 381 | # (3) Compute the LLR measure 382 | numerator = A_processed.dot(toeplitz(R_clean)).dot(A_processed.T) 383 | denominator = A_clean.dot(toeplitz(R_clean)).dot(A_clean.T) 384 | 385 | if (numerator/denominator) <= 0: 386 | print(f'Numerator: {numerator}') 387 | print(f'Denominator: {denominator}') 388 | 389 | log_ = np.log(numerator / denominator) 390 | distortion.append(np.squeeze(log_)) 391 | start += int(skiprate) 392 | return np.nan_to_num(np.array(distortion)) 393 | # -------------------------------------------------------------------------- # 394 | 395 | -------------------------------------------------------------------------------- /data/make_dataset_json.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | 5 | def list_files_in_directory(directory_path): 6 | # List all files in the directory 7 | files = [] 8 | for root, dirs, filenames in os.walk(directory_path): 9 | for filename in filenames: 10 | if filename.endswith('.wav'): 11 | files.append(os.path.join(root, filename)) 12 | return files 13 | 14 | def save_files_to_json(files, output_file): 15 | with open(output_file, 'w') as json_file: 16 | json.dump(files, json_file, indent=4) 17 | 18 | def make_json(directory_path, output_file): 19 | # Get the list of files and save to JSON 20 | files = list_files_in_directory(directory_path) 21 | save_files_to_json(files, output_file) 22 | 23 | # create training set json 24 | 25 | def main(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--path', default="/Work21/2024/wangjunyu/pythonproject/Data_16k/") 28 | 29 | args = parser.parse_args() 30 | 31 | prepath = args.path if (args.path is not None) else "../" 32 | 33 | ## train_clean 34 | make_json( 35 | os.path.join(prepath, 'clean_train/'), 36 | 'train_clean.json' 37 | ) 38 | 39 | ## train_noisy 40 | make_json( 41 | os.path.join(prepath, 'noisy_train/'), 42 | 'train_noisy.json' 43 | ) 44 | 45 | ## valid_clean 46 | make_json( 47 | os.path.join(prepath, 'clean_valid/'), 48 | 'valid_clean.json' 49 | ) 50 | 51 | ## valid_noisy 52 | make_json( 53 | os.path.join(prepath, 'noisy_valid/'), 54 | 'valid_noisy.json' 55 | ) 56 | 57 | ## test_clean 58 | make_json( 59 | os.path.join(prepath, 'clean_test/'), 60 | 'test_clean.json' 61 | ) 62 | 63 | ## test_noisy 64 | make_json( 65 | os.path.join(prepath, 'noisy_test/'), 66 | 'test_noisy.json' 67 | ) 68 | # ----------------------------------------------------------# 69 | 70 | 71 | if __name__ == '__main__': 72 | main() 73 | -------------------------------------------------------------------------------- /dataloaders/dataloader_vctk.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import torch 5 | import torch.utils.data 6 | import librosa 7 | from models.stfts import mag_phase_stft, mag_phase_istft 8 | from models.pcs400 import cal_pcs 9 | 10 | def list_files_in_directory(directory_path): 11 | 
files = [] 12 | for root, dirs, filenames in os.walk(directory_path): 13 | for filename in filenames: 14 | if filename.endswith('.wav'): # only add .wav files 15 | files.append(os.path.join(root, filename)) 16 | return files 17 | 18 | def load_json_file(file_path): 19 | with open(file_path, 'r') as json_file: 20 | data = json.load(json_file) 21 | return data 22 | 23 | def extract_identifier(file_path): 24 | return os.path.basename(file_path) 25 | 26 | def get_clean_path_for_noisy(noisy_file_path, clean_path_dict): 27 | identifier = extract_identifier(noisy_file_path) 28 | return clean_path_dict.get(identifier, None) 29 | 30 | class VCTKDemandDataset(torch.utils.data.Dataset): 31 | """ 32 | Dataset for loading clean and noisy audio files. 33 | 34 | Args: 35 | clean_wavs_json (str): Directory containing clean audio files. 36 | noisy_wavs_json (str): Directory containing noisy audio files. 37 | audio_index_file (str): File containing audio indexes. 38 | sampling_rate (int, optional): Sampling rate of the audio files. Defaults to 16000. 39 | segment_size (int, optional): Size of the audio segments. Defaults to 32000. 40 | n_fft (int, optional): FFT size. Defaults to 400. 41 | hop_size (int, optional): Hop size. Defaults to 100. 42 | win_size (int, optional): Window size. Defaults to 400. 43 | compress_factor (float, optional): Magnitude compression factor. Defaults to 1.0. 44 | split (bool, optional): Whether to split the audio into segments. Defaults to True. 45 | n_cache_reuse (int, optional): Number of times to reuse cached audio. Defaults to 1. 46 | device (torch.device, optional): Target device. Defaults to None 47 | pcs (bool, optional): Use PCS in training period. Defaults to False 48 | """ 49 | def __init__( 50 | self, 51 | clean_json, 52 | noisy_json, 53 | sampling_rate=16000, 54 | segment_size=32000, 55 | n_fft=400, 56 | hop_size=100, 57 | win_size=400, 58 | compress_factor=1.0, 59 | split=True, 60 | n_cache_reuse=1, 61 | shuffle=True, 62 | device=None, 63 | pcs=False 64 | ): 65 | 66 | self.clean_wavs_path = load_json_file( clean_json ) 67 | self.noisy_wavs_path = load_json_file( noisy_json ) 68 | random.seed(1234) 69 | 70 | if shuffle: 71 | random.shuffle(self.noisy_wavs_path) 72 | self.clean_path_dict = {extract_identifier(clean_path): clean_path for clean_path in self.clean_wavs_path} 73 | 74 | self.sampling_rate = sampling_rate 75 | self.segment_size = segment_size 76 | self.n_fft = n_fft 77 | self.hop_size = hop_size 78 | self.win_size = win_size 79 | self.compress_factor = compress_factor 80 | self.split = split 81 | self.n_cache_reuse = n_cache_reuse 82 | 83 | self.cached_clean_wav = None 84 | self.cached_noisy_wav = None 85 | self._cache_ref_count = 0 86 | self.device = device 87 | self.pcs = pcs 88 | 89 | def __getitem__(self, index): 90 | """ 91 | Get an audio sample by index. 92 | 93 | Args: 94 | index (int): Index of the audio sample. 
95 | 96 | Returns: 97 | tuple: clean audio, clean magnitude, clean phase, clean complex, noisy magnitude, noisy phase 98 | """ 99 | if self._cache_ref_count == 0: 100 | noisy_path = self.noisy_wavs_path[index] 101 | clean_path = get_clean_path_for_noisy(noisy_path, self.clean_path_dict) 102 | noisy_audio, _ = librosa.load( noisy_path, sr=self.sampling_rate) 103 | clean_audio, _ = librosa.load( clean_path, sr=self.sampling_rate) 104 | if self.pcs == True: 105 | clean_audio = cal_pcs(clean_audio) 106 | self.cached_noisy_wav = noisy_audio 107 | self.cached_clean_wav = clean_audio 108 | self._cache_ref_count = self.n_cache_reuse 109 | else: 110 | clean_audio = self.cached_clean_wav 111 | noisy_audio = self.cached_noisy_wav 112 | self._cache_ref_count -= 1 113 | 114 | clean_audio, noisy_audio = torch.FloatTensor(clean_audio), torch.FloatTensor(noisy_audio) 115 | norm_factor = torch.sqrt(len(noisy_audio) / torch.sum(noisy_audio ** 2.0)) 116 | clean_audio = (clean_audio * norm_factor).unsqueeze(0) 117 | noisy_audio = (noisy_audio * norm_factor).unsqueeze(0) 118 | 119 | assert clean_audio.size(1) == noisy_audio.size(1) 120 | 121 | if self.split: 122 | if clean_audio.size(1) >= self.segment_size: 123 | max_audio_start = clean_audio.size(1) - self.segment_size 124 | audio_start = random.randint(0, max_audio_start) 125 | clean_audio = clean_audio[:, audio_start:audio_start + self.segment_size] 126 | noisy_audio = noisy_audio[:, audio_start:audio_start + self.segment_size] 127 | else: 128 | clean_audio = torch.nn.functional.pad(clean_audio, (0, self.segment_size - clean_audio.size(1)), 'constant') 129 | noisy_audio = torch.nn.functional.pad(noisy_audio, (0, self.segment_size - noisy_audio.size(1)), 'constant') 130 | 131 | clean_mag, clean_pha, clean_com = mag_phase_stft(clean_audio, self.n_fft, self.hop_size, self.win_size, self.compress_factor) 132 | noisy_mag, noisy_pha, noisy_com = mag_phase_stft(noisy_audio, self.n_fft, self.hop_size, self.win_size, self.compress_factor) 133 | 134 | return (clean_audio.squeeze(), clean_mag.squeeze(), clean_pha.squeeze(), clean_com.squeeze(), noisy_mag.squeeze(), noisy_pha.squeeze()) 135 | 136 | def __len__(self): 137 | return len(self.noisy_wavs_path) 138 | 139 | class Val_Dataset(torch.utils.data.Dataset): 140 | """ 141 | Dataset for loading clean and noisy audio files. 142 | 143 | Args: 144 | clean_wavs_json (str): Directory containing clean audio files. 145 | noisy_wavs_json (str): Directory containing noisy audio files. 146 | audio_index_file (str): File containing audio indexes. 147 | sampling_rate (int, optional): Sampling rate of the audio files. Defaults to 16000. 148 | segment_size (int, optional): Size of the audio segments. Defaults to 32000. 149 | n_fft (int, optional): FFT size. Defaults to 400. 150 | hop_size (int, optional): Hop size. Defaults to 100. 151 | win_size (int, optional): Window size. Defaults to 400. 152 | compress_factor (float, optional): Magnitude compression factor. Defaults to 1.0. 153 | split (bool, optional): Whether to split the audio into segments. Defaults to True. 154 | n_cache_reuse (int, optional): Number of times to reuse cached audio. Defaults to 1. 155 | device (torch.device, optional): Target device. Defaults to None 156 | pcs (bool, optional): Use PCS in training period. 
Defaults to False 157 | """ 158 | def __init__( 159 | self, 160 | clean_json, 161 | noisy_json, 162 | sampling_rate=16000, 163 | segment_size=32000, 164 | n_fft=400, 165 | hop_size=100, 166 | win_size=400, 167 | compress_factor=1.0, 168 | split=True, 169 | n_cache_reuse=1, 170 | shuffle=True, 171 | device=None, 172 | pcs=False 173 | ): 174 | 175 | self.clean_wavs_path = load_json_file( clean_json ) 176 | self.noisy_wavs_path = load_json_file( noisy_json ) 177 | random.seed(1234) 178 | 179 | if shuffle: 180 | random.shuffle(self.noisy_wavs_path) 181 | self.clean_path_dict = {extract_identifier(clean_path): clean_path for clean_path in self.clean_wavs_path} 182 | 183 | self.sampling_rate = sampling_rate 184 | self.segment_size = segment_size 185 | self.n_fft = n_fft 186 | self.hop_size = hop_size 187 | self.win_size = win_size 188 | self.compress_factor = compress_factor 189 | self.split = split 190 | self.n_cache_reuse = n_cache_reuse 191 | 192 | self.cached_clean_wav = None 193 | self.cached_noisy_wav = None 194 | self._cache_ref_count = 0 195 | self.device = device 196 | self.pcs = pcs 197 | 198 | def __getitem__(self, index): 199 | """ 200 | Get an audio sample by index. 201 | 202 | Args: 203 | index (int): Index of the audio sample. 204 | 205 | Returns: 206 | tuple: clean audio, clean magnitude, clean phase, clean complex, noisy magnitude, noisy phase 207 | """ 208 | if self._cache_ref_count == 0: 209 | noisy_path = self.noisy_wavs_path[index] 210 | clean_path = get_clean_path_for_noisy(noisy_path, self.clean_path_dict) 211 | noisy_audio, _ = librosa.load( noisy_path, sr=self.sampling_rate) 212 | clean_audio, _ = librosa.load( clean_path, sr=self.sampling_rate) 213 | if self.pcs == True: 214 | clean_audio = cal_pcs(clean_audio) 215 | self.cached_noisy_wav = noisy_audio 216 | self.cached_clean_wav = clean_audio 217 | self._cache_ref_count = self.n_cache_reuse 218 | else: 219 | clean_audio = self.cached_clean_wav 220 | noisy_audio = self.cached_noisy_wav 221 | self._cache_ref_count -= 1 222 | 223 | clean_audio, noisy_audio = torch.FloatTensor(clean_audio), torch.FloatTensor(noisy_audio) 224 | norm_factor = torch.sqrt(len(noisy_audio) / torch.sum(noisy_audio ** 2.0)) 225 | clean_audio = (clean_audio * norm_factor).unsqueeze(0) 226 | noisy_audio = (noisy_audio * norm_factor).unsqueeze(0) 227 | 228 | assert clean_audio.size(1) == noisy_audio.size(1) 229 | 230 | if self.split: 231 | if clean_audio.size(1) >= self.segment_size: 232 | max_audio_start = clean_audio.size(1) - self.segment_size 233 | audio_start = random.randint(0, max_audio_start) 234 | clean_audio = clean_audio[:, audio_start:audio_start + self.segment_size] 235 | noisy_audio = noisy_audio[:, audio_start:audio_start + self.segment_size] 236 | else: 237 | clean_audio = torch.nn.functional.pad(clean_audio, (0, self.segment_size - clean_audio.size(1)), 'constant') 238 | noisy_audio = torch.nn.functional.pad(noisy_audio, (0, self.segment_size - noisy_audio.size(1)), 'constant') 239 | 240 | clean_mag, clean_pha, clean_com = mag_phase_stft(clean_audio, self.n_fft, self.hop_size, self.win_size, self.compress_factor) 241 | 242 | return (clean_audio.squeeze(), clean_mag.squeeze(), clean_pha.squeeze(), clean_com.squeeze(), noisy_audio.squeeze()) 243 | 244 | def __len__(self): 245 | return len(self.noisy_wavs_path) -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import 
torch 4 | import torch.utils.data 5 | import librosa 6 | 7 | def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True): 8 | 9 | hann_window = torch.hann_window(win_size).to(y.device) 10 | stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, 11 | center=center, pad_mode='reflect', normalized=False, return_complex=True) 12 | mag = torch.abs(stft_spec) 13 | pha = torch.angle(stft_spec) 14 | # Magnitude Compression 15 | mag = torch.pow(mag, compress_factor) 16 | com = torch.stack((mag*torch.cos(pha), mag*torch.sin(pha)), dim=-1) 17 | 18 | return mag, pha, com 19 | 20 | 21 | def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True): 22 | # Magnitude Decompression 23 | mag = torch.pow(mag, (1.0/compress_factor)) 24 | com = torch.complex(mag*torch.cos(pha), mag*torch.sin(pha)) 25 | hann_window = torch.hann_window(win_size).to(com.device) 26 | wav = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center) 27 | 28 | return wav 29 | 30 | 31 | def get_dataset_filelist(a): 32 | with open(a.input_training_file, 'r', encoding='utf-8') as fi: 33 | training_indexes = [x.split('|')[0] for x in fi.read().split('\n') if len(x) > 0] 34 | 35 | with open(a.input_validation_file, 'r', encoding='utf-8') as fi: 36 | validation_indexes = [x.split('|')[0] for x in fi.read().split('\n') if len(x) > 0] 37 | 38 | return training_indexes, validation_indexes 39 | 40 | 41 | class Val_Dataset(torch.utils.data.Dataset): 42 | def __init__(self, training_indexes, clean_wavs_dir, noisy_wavs_dir, segment_size, n_fft, hop_size, win_size, 43 | sampling_rate, compress_factor, split=True, shuffle=True, n_cache_reuse=1, device=None): 44 | self.audio_indexes = training_indexes 45 | random.seed(1234) 46 | if shuffle: 47 | random.shuffle(self.audio_indexes) 48 | self.clean_wavs_dir = clean_wavs_dir 49 | self.noisy_wavs_dir = noisy_wavs_dir 50 | self.segment_size = segment_size 51 | self.sampling_rate = sampling_rate 52 | self.split = split 53 | self.n_fft = n_fft 54 | self.hop_size = hop_size 55 | self.win_size = win_size 56 | self.compress_factor = compress_factor 57 | self.cached_clean_wav = None 58 | self.cached_noisy_wav = None 59 | self.n_cache_reuse = n_cache_reuse 60 | self._cache_ref_count = 0 61 | self.device = device 62 | 63 | def __getitem__(self, index): 64 | filename = self.audio_indexes[index] 65 | if self._cache_ref_count == 0: 66 | clean_audio, _ = librosa.load(os.path.join(self.clean_wavs_dir, filename + '.wav'), self.sampling_rate) 67 | noisy_audio, _ = librosa.load(os.path.join(self.noisy_wavs_dir, filename + '.wav'), self.sampling_rate) 68 | self.cached_clean_wav = clean_audio 69 | self.cached_noisy_wav = noisy_audio 70 | self._cache_ref_count = self.n_cache_reuse 71 | else: 72 | clean_audio = self.cached_clean_wav 73 | noisy_audio = self.cached_noisy_wav 74 | self._cache_ref_count -= 1 75 | 76 | clean_audio, noisy_audio = torch.FloatTensor(clean_audio), torch.FloatTensor(noisy_audio) 77 | norm_factor = torch.sqrt(len(noisy_audio) / torch.sum(noisy_audio ** 2.0)) 78 | clean_audio = (clean_audio * norm_factor).unsqueeze(0) 79 | noisy_audio = (noisy_audio * norm_factor).unsqueeze(0) 80 | 81 | assert clean_audio.size(1) == noisy_audio.size(1) 82 | 83 | return (clean_audio.squeeze(), noisy_audio.squeeze()) 84 | 85 | def __len__(self): 86 | return len(self.audio_indexes) 87 | 88 | 89 | 90 | class Dataset(torch.utils.data.Dataset): 91 | def __init__(self, training_indexes, 
clean_wavs_dir, noisy_wavs_dir, segment_size, n_fft, hop_size, win_size, 92 | sampling_rate, compress_factor, split=True, shuffle=True, n_cache_reuse=1, device=None): 93 | self.audio_indexes = training_indexes 94 | random.seed(1234) 95 | if shuffle: 96 | random.shuffle(self.audio_indexes) 97 | self.clean_wavs_dir = clean_wavs_dir 98 | self.noisy_wavs_dir = noisy_wavs_dir 99 | self.segment_size = segment_size 100 | self.sampling_rate = sampling_rate 101 | self.split = split 102 | self.n_fft = n_fft 103 | self.hop_size = hop_size 104 | self.win_size = win_size 105 | self.compress_factor = compress_factor 106 | self.cached_clean_wav = None 107 | self.cached_noisy_wav = None 108 | self.n_cache_reuse = n_cache_reuse 109 | self._cache_ref_count = 0 110 | self.device = device 111 | 112 | def __getitem__(self, index): 113 | filename = self.audio_indexes[index] 114 | if self._cache_ref_count == 0: 115 | clean_audio, _ = librosa.load(os.path.join(self.clean_wavs_dir, filename + '.wav'), sr=self.sampling_rate) 116 | noisy_audio, _ = librosa.load(os.path.join(self.noisy_wavs_dir, filename + '.wav'), sr=self.sampling_rate) 117 | self.cached_clean_wav = clean_audio 118 | self.cached_noisy_wav = noisy_audio 119 | self._cache_ref_count = self.n_cache_reuse 120 | else: 121 | clean_audio = self.cached_clean_wav 122 | noisy_audio = self.cached_noisy_wav 123 | self._cache_ref_count -= 1 124 | 125 | clean_audio, noisy_audio = torch.FloatTensor(clean_audio), torch.FloatTensor(noisy_audio) 126 | norm_factor = torch.sqrt(len(noisy_audio) / torch.sum(noisy_audio ** 2.0)) 127 | clean_audio = (clean_audio * norm_factor).unsqueeze(0) 128 | noisy_audio = (noisy_audio * norm_factor).unsqueeze(0) 129 | 130 | assert clean_audio.size(1) == noisy_audio.size(1) 131 | 132 | if self.split: 133 | if clean_audio.size(1) >= self.segment_size: 134 | max_audio_start = clean_audio.size(1) - self.segment_size 135 | rand_num = random.random() 136 | # ~2% of crops anchor at the utterance start, ~2% at its end, and the rest are placed randomly 137 | if rand_num < 0.02: 138 | audio_start = 0 139 | elif rand_num > 0.98: 140 | audio_start = max_audio_start 141 | else: 142 | audio_start = random.randint(0, max_audio_start) 143 | 144 | clean_audio = clean_audio[:, audio_start: audio_start + self.segment_size] 145 | noisy_audio = noisy_audio[:, audio_start: audio_start + self.segment_size] 146 | else: 147 | clean_audio = torch.nn.functional.pad(clean_audio, (0, self.segment_size - clean_audio.size(1)), 148 | 'constant') 149 | noisy_audio = torch.nn.functional.pad(noisy_audio, (0, self.segment_size - noisy_audio.size(1)), 150 | 'constant') 151 | 152 | clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, self.n_fft, self.hop_size, self.win_size, self.compress_factor) #[1, n_fft/2+1, frames] 153 | noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, self.n_fft, self.hop_size, self.win_size, self.compress_factor) #[1, n_fft/2+1, frames] 154 | 155 | return (clean_audio.squeeze(), clean_mag.squeeze(), clean_pha.squeeze(), clean_com.squeeze(), noisy_audio.squeeze(), noisy_mag.squeeze(), noisy_pha.squeeze()) 156 | 157 | def __len__(self): 158 | return len(self.audio_indexes) 159 | -------------------------------------------------------------------------------- /flops.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.simplefilter(action='ignore', category=FutureWarning) 3 | import os 4 | import time 5 | import argparse 6 | import json 7 | import torch 8 | from models.generator import MambaSEUNet 9 | from thop import profile 10 | from utils.util import load_config 11 |
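# This script builds the generator from the YAML recipe, counts its parameters, and profiles
# it with thop.profile. The (1, 256, 256) dummy tensors below stand in for (batch, frequency-bin,
# frame) magnitude and phase spectrograms, so the reported FLOPs correspond to that input size.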
12 | torch.backends.cudnn.benchmark = True 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--exp_folder', default='exp') 17 | parser.add_argument('--exp_name', default='MambaSEUNet_emb_32') 18 | parser.add_argument('--config', default='recipes/Mamba-SEUNet/Mamba-SEUNet.yaml') 19 | args = parser.parse_args() 20 | 21 | cfg = load_config(args.config) 22 | 23 | device = torch.device('cuda:{:d}'.format(0)) 24 | 25 | model = MambaSEUNet(cfg).to(device) 26 | 27 | num_params = sum(p.numel() for p in model.parameters()) 28 | 29 | print(f"Manual calculation of parameters: {num_params}") 30 | 31 | with torch.no_grad(): 32 | dummy_input1 = torch.rand(1, 256, 256).to(device) 33 | dummy_input2 = torch.rand(1, 256, 256).to(device) 34 | flops, params = profile(model, inputs=(dummy_input1, dummy_input2)) 35 | print('FLOPs = ' + str(flops / 1000 ** 3) + 'G') 36 | print('Params = ' + str(params / 1000 ** 2) + 'M') 37 | 38 | if __name__ == '__main__': 39 | main() 40 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import argparse 4 | import json 5 | import torch 6 | import librosa 7 | from models.stfts import mag_phase_stft, mag_phase_istft 8 | from datasets.dataset import mag_pha_stft, mag_pha_istft 9 | from cal_metrics.compute_metrics import compute_metrics 10 | from models.generator import MambaSEUNet 11 | from models.pcs400 import cal_pcs 12 | import soundfile as sf 13 | 14 | import numpy as np 15 | 16 | from utils.util import load_config 17 | 18 | 19 | def load_checkpoint(filepath, device): 20 | print("Loading '{}'".format(filepath)) 21 | checkpoint_dict = torch.load(filepath, map_location=device) 22 | print("Complete.") 23 | return checkpoint_dict 24 | 25 | def scan_checkpoint(cp_dir, prefix): 26 | pattern = os.path.join(cp_dir, prefix + '*') 27 | cp_list = glob.glob(pattern) 28 | if len(cp_list) == 0: 29 | return '' 30 | return sorted(cp_list)[-1] 31 | 32 | # handle audio slicing 33 | def process_audio_segment(noisy_wav, model, device, n_fft, hop_size, win_size, compress_factor, sampling_rate, segment_size): 34 | segment_size = segment_size 35 | n_fft = n_fft 36 | hop_size = hop_size 37 | win_size = win_size 38 | compress_factor = compress_factor 39 | sampling_rate = sampling_rate 40 | 41 | norm_factor = torch.sqrt(len(noisy_wav) / torch.sum(noisy_wav ** 2.0)).to(device) 42 | noisy_wav = (noisy_wav * norm_factor).unsqueeze(0) 43 | orig_size = noisy_wav.size(1) 44 | 45 | # whether zeros need to be padded 46 | if noisy_wav.size(1) >= segment_size: 47 | num_segments = noisy_wav.size(1) // segment_size 48 | last_segment_size = noisy_wav.size(1) % segment_size 49 | if last_segment_size > 0: 50 | last_segment = noisy_wav[:, -segment_size:] 51 | noisy_wav = noisy_wav[:, :-last_segment_size] 52 | segments = torch.split(noisy_wav, segment_size, dim=1) 53 | segments = list(segments) 54 | segments.append(last_segment) 55 | reshapelast=1 56 | else: 57 | segments = torch.split(noisy_wav, segment_size, dim=1) 58 | reshapelast = 0 59 | 60 | else: 61 | # padding 62 | padded_zeros = torch.zeros(1, segment_size - noisy_wav.size(1)).to(device) 63 | noisy_wav = torch.cat((noisy_wav, padded_zeros), dim=1) 64 | segments = [noisy_wav] 65 | reshapelast = 0 66 | 67 | processed_segments = [] 68 | 69 | for i, segment in enumerate(segments): 70 | 71 | noisy_amp, noisy_pha, noisy_com = mag_pha_stft(segment, n_fft, hop_size, 
win_size, compress_factor)
 72 |         amp_g, pha_g, com_g = model(noisy_amp.to(device, non_blocking=True), noisy_pha.to(device, non_blocking=True))
 73 |         audio_g = mag_pha_istft(amp_g, pha_g, n_fft, hop_size, win_size, compress_factor)
 74 |         audio_g = audio_g / norm_factor
 75 |         audio_g = audio_g.squeeze()
 76 |         if reshapelast == 1 and i == len(segments) - 2:
 77 |             audio_g = audio_g[:-(segment_size - last_segment_size)]
 78 | 
 79 |         processed_segments.append(audio_g)
 80 | 
 81 |     processed_audio = torch.cat(processed_segments, dim=-1)
 82 |     print(processed_audio.size())
 83 | 
 84 |     processed_audio = processed_audio[:orig_size]
 85 |     print(processed_audio.size())
 86 |     print(orig_size)
 87 | 
 88 |     return processed_audio
 89 | 
 90 | def inference(args, device):
 91 |     cfg = load_config(args.config)
 92 |     n_fft, hop_size, win_size = cfg['stft_cfg']['n_fft'], cfg['stft_cfg']['hop_size'], cfg['stft_cfg']['win_size']
 93 |     compress_factor = cfg['model_cfg']['compress_factor']
 94 |     sampling_rate = cfg['stft_cfg']['sampling_rate']
 95 |     segment_size = cfg['training_cfg']['segment_size']
 96 | 
 97 |     model = MambaSEUNet(cfg).to(device)
 98 |     state_dict = load_checkpoint(args.checkpoint_file, device)
 99 |     model.load_state_dict(state_dict['generator'])
100 | 
101 |     os.makedirs(args.output_folder, exist_ok=True)
102 | 
103 |     model.eval()
104 | 
105 |     metrics_total = np.zeros(6)
106 |     count = 0
107 | 
108 |     with torch.no_grad():
109 |         for i, fname in enumerate(os.listdir(args.input_clean_wavs_dir)):
110 |             print(fname, args.input_clean_wavs_dir)
111 |             noisy_wav, _ = librosa.load(os.path.join(args.input_noisy_wavs_dir, fname), sr=sampling_rate)
112 |             noisy_wav = torch.FloatTensor(noisy_wav).to(device)
113 |             output_audio = process_audio_segment(noisy_wav, model, device, n_fft, hop_size, win_size, compress_factor, sampling_rate, segment_size)
114 |             if args.post_processing_PCS:
115 |                 output_audio = torch.from_numpy(cal_pcs(output_audio.squeeze().cpu().numpy()))  # cal_pcs returns a NumPy array; wrap it back into a tensor
116 |             output_file = os.path.join(args.output_folder, fname)
117 |             sf.write(output_file, output_audio.cpu().numpy(), sampling_rate, 'PCM_16')
118 | 
119 |             clean_wav, sr = librosa.load(os.path.join(args.input_clean_wavs_dir, fname), sr=sampling_rate)
120 |             out1 = output_audio.cpu()
121 |             output_audio = out1.numpy()
122 | 
123 |             metrics = compute_metrics(clean_wav, output_audio, sr, 0)
124 |             metrics = np.array(metrics)
125 |             metrics_total += metrics
126 |             count += 1
127 | 
128 |     metrics_avg = metrics_total / count
129 |     print('pesq: ', metrics_avg[0], 'csig: ', metrics_avg[1], 'cbak: ', metrics_avg[2],
130 |           'covl: ', metrics_avg[3], 'ssnr: ', metrics_avg[4], 'stoi: ', metrics_avg[5])
131 | 
132 | 
133 | def main():
134 |     print('Initializing Inference Process..')
135 |     parser = argparse.ArgumentParser()
136 |     parser.add_argument('--input_clean_wavs_dir', default='../Data_16k/clean_test')
137 |     parser.add_argument('--input_noisy_wavs_dir', default='../Data_16k/noisy_test')
138 |     parser.add_argument('--output_folder', default='results/g_best')
139 |     parser.add_argument('--config', default='recipes/Mamba-SEUNet/Mamba-SEUNet.yaml')
140 |     parser.add_argument('--checkpoint_file', default='ckpts/g_best.pth')
141 |     parser.add_argument('--post_processing_PCS', action='store_true')
142 |     args = parser.parse_args()
143 | 
144 |     global device
145 |     if torch.cuda.is_available():
146 |         device = torch.device('cuda')
147 |     else:
148 |         #device = torch.device('cpu')
149 |         raise RuntimeError("Currently, CPU mode is not supported.")
150 | 
151 |     inference(args, device)
152 | 
153 | if __name__ == '__main__':
154 |     main()
155 | 
156 | 
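`process_audio_segment` above slices an utterance into fixed `segment_size` chunks, adds one extra chunk taken from the signal's tail when the length is not a multiple of the segment size, and trims the duplicated overlap when stitching the enhanced chunks back together. The following standalone sketch reproduces just that slicing/stitching scheme with an identity "enhancement" to show it is lossless; the helper names are illustrative and not part of the repository API:

```python
import torch

def split_into_segments(wav: torch.Tensor, segment_size: int):
    """Split a [1, T] waveform into full-length chunks; the last chunk is taken
    from the tail of the signal, so it may overlap the preceding one."""
    total = wav.size(1)
    if total <= segment_size:
        pad = torch.zeros(1, segment_size - total)
        return [torch.cat((wav, pad), dim=1)], 0            # no overlap, only padding
    remainder = total % segment_size
    chunks = list(torch.split(wav[:, :total - remainder], segment_size, dim=1))
    overlap = 0
    if remainder > 0:
        chunks.append(wav[:, -segment_size:])               # tail chunk of full length
        overlap = segment_size - remainder                  # samples shared with the previous chunk
    return chunks, overlap

def stitch(chunks, overlap, total):
    """Concatenate processed chunks, dropping the overlapped samples once."""
    if overlap > 0:
        chunks = chunks[:-2] + [chunks[-2][:, :-overlap], chunks[-1]]
    return torch.cat(chunks, dim=1)[:, :total]

wav = torch.randn(1, 25_000)
chunks, overlap = split_into_segments(wav, segment_size=10_000)
restored = stitch(chunks, overlap, wav.size(1))             # identity stands in for the model
assert restored.shape == wav.shape and torch.allclose(restored, wav)
```

Like the script above, the overlap is removed from the end of the second-to-last chunk, so every sample of the output comes from exactly one processed chunk.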
-------------------------------------------------------------------------------- /models/codec_module.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange 5 | from .lsigmoid import LearnableSigmoid2D 6 | 7 | def get_padding(kernel_size, dilation=1): 8 | """ 9 | Calculate the padding size for a convolutional layer. 10 | 11 | Args: 12 | - kernel_size (int): Size of the convolutional kernel. 13 | - dilation (int, optional): Dilation rate of the convolution. Defaults to 1. 14 | 15 | Returns: 16 | - int: Calculated padding size. 17 | """ 18 | return int((kernel_size * dilation - dilation) / 2) 19 | 20 | def get_padding_2d(kernel_size, dilation=(1, 1)): 21 | """ 22 | Calculate the padding size for a 2D convolutional layer. 23 | 24 | Args: 25 | - kernel_size (tuple): Size of the convolutional kernel (height, width). 26 | - dilation (tuple, optional): Dilation rate of the convolution (height, width). Defaults to (1, 1). 27 | 28 | Returns: 29 | - tuple: Calculated padding size (height, width). 30 | """ 31 | return (int((kernel_size[0] * dilation[0] - dilation[0]) / 2), 32 | int((kernel_size[1] * dilation[1] - dilation[1]) / 2)) 33 | 34 | class DenseBlock(nn.Module): 35 | """ 36 | DenseBlock module consisting of multiple convolutional layers with dilation. 37 | """ 38 | def __init__(self, cfg, kernel_size=(3, 3), depth=4): 39 | super(DenseBlock, self).__init__() 40 | self.cfg = cfg 41 | self.depth = depth 42 | self.dense_block = nn.ModuleList() 43 | self.hid_feature = cfg['model_cfg']['hid_feature'] 44 | 45 | for i in range(depth): 46 | dil = 2 ** i 47 | dense_conv = nn.Sequential( 48 | nn.Conv2d(self.hid_feature * (i + 1), self.hid_feature, kernel_size, 49 | dilation=(dil, 1), padding=get_padding_2d(kernel_size, (dil, 1))), 50 | nn.InstanceNorm2d(self.hid_feature, affine=True), 51 | nn.PReLU(self.hid_feature) 52 | ) 53 | self.dense_block.append(dense_conv) 54 | 55 | def forward(self, x): 56 | """ 57 | Forward pass for the DenseBlock module. 58 | 59 | Args: 60 | - x (torch.Tensor): Input tensor. 61 | 62 | Returns: 63 | - torch.Tensor: Output tensor after processing through the dense block. 64 | """ 65 | skip = x 66 | for i in range(self.depth): 67 | x = self.dense_block[i](skip) 68 | skip = torch.cat([x, skip], dim=1) 69 | return x 70 | 71 | class DenseEncoder(nn.Module): 72 | """ 73 | DenseEncoder module consisting of initial convolution, dense block, and a final convolution. 74 | """ 75 | def __init__(self, cfg): 76 | super(DenseEncoder, self).__init__() 77 | self.cfg = cfg 78 | self.input_channel = cfg['model_cfg']['input_channel'] 79 | self.hid_feature = cfg['model_cfg']['hid_feature'] 80 | 81 | self.dense_conv_1 = nn.Sequential( 82 | nn.Conv2d(self.input_channel, self.hid_feature, (1, 1)), 83 | nn.InstanceNorm2d(self.hid_feature, affine=True), 84 | nn.PReLU(self.hid_feature) 85 | ) 86 | 87 | self.dense_block = DenseBlock(cfg, depth=4) 88 | 89 | self.dense_conv_2 = nn.Sequential( 90 | nn.Conv2d(self.hid_feature, self.hid_feature, (1, 3), stride=(1, 2), padding=(0, 1)), 91 | nn.InstanceNorm2d(self.hid_feature, affine=True), 92 | nn.PReLU(self.hid_feature) 93 | ) 94 | 95 | def forward(self, x): 96 | """ 97 | Forward pass for the DenseEncoder module. 98 | 99 | Args: 100 | - x (torch.Tensor): Input tensor. 101 | 102 | Returns: 103 | - torch.Tensor: Encoded tensor. 
104 | """ 105 | x = self.dense_conv_1(x) # [batch, hid_feature, time, freq] 106 | x = self.dense_block(x) # [batch, hid_feature, time, freq] 107 | x = self.dense_conv_2(x) # [batch, hid_feature, time, freq//2] 108 | return x 109 | 110 | class MagDecoder(nn.Module): 111 | """ 112 | MagDecoder module for decoding magnitude information. 113 | """ 114 | def __init__(self, cfg): 115 | super(MagDecoder, self).__init__() 116 | self.dense_block = DenseBlock(cfg, depth=4) 117 | self.hid_feature = cfg['model_cfg']['hid_feature'] 118 | self.output_channel = cfg['model_cfg']['output_channel'] 119 | self.n_fft = cfg['stft_cfg']['n_fft'] 120 | self.beta = cfg['model_cfg']['beta'] 121 | 122 | self.mask_conv = nn.Sequential( 123 | nn.Conv2d(self.hid_feature, self.hid_feature * 4, 1, 1, 0, bias=False), 124 | nn.PixelShuffle(2), 125 | nn.Conv2d(self.hid_feature, self.hid_feature, kernel_size=(1, 3), stride=(2, 1), padding=(0, 1), 126 | groups=self.hid_feature, bias=False), 127 | nn.Conv2d(self.hid_feature, self.output_channel, (1, 1)), 128 | nn.InstanceNorm2d(self.output_channel, affine=True), 129 | nn.PReLU(self.output_channel), 130 | nn.Conv2d(self.output_channel, self.output_channel, (1, 1)) 131 | ) 132 | self.lsigmoid = LearnableSigmoid2D(self.n_fft // 2 + 1, beta=self.beta) 133 | 134 | def forward(self, x): 135 | """ 136 | Forward pass for the MagDecoder module. 137 | 138 | Args: 139 | - x (torch.Tensor): Input tensor. 140 | 141 | Returns: 142 | - torch.Tensor: Decoded tensor with magnitude information. 143 | """ 144 | x = self.dense_block(x) 145 | x = self.mask_conv(x) 146 | x = rearrange(x, 'b c t f -> b f t c').squeeze(-1) 147 | x = self.lsigmoid(x) 148 | x = rearrange(x, 'b f t -> b t f').unsqueeze(1) 149 | return x 150 | 151 | class PhaseDecoder(nn.Module): 152 | """ 153 | PhaseDecoder module for decoding phase information. 154 | """ 155 | def __init__(self, cfg): 156 | super(PhaseDecoder, self).__init__() 157 | self.dense_block = DenseBlock(cfg, depth=4) 158 | self.hid_feature = cfg['model_cfg']['hid_feature'] 159 | self.output_channel = cfg['model_cfg']['output_channel'] 160 | 161 | self.phase_conv = nn.Sequential( 162 | nn.Conv2d(self.hid_feature, self.hid_feature * 4, 1, 1, 0, bias=False), 163 | nn.PixelShuffle(2), 164 | nn.Conv2d(self.hid_feature, self.hid_feature, kernel_size=(1, 3), stride=(2, 1), padding=(0, 1), 165 | groups=self.hid_feature, bias=False), 166 | nn.InstanceNorm2d(self.hid_feature, affine=True), 167 | nn.PReLU(self.hid_feature) 168 | ) 169 | 170 | self.phase_conv_r = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1)) 171 | self.phase_conv_i = nn.Conv2d(self.hid_feature, self.output_channel, (1, 1)) 172 | 173 | def forward(self, x): 174 | """ 175 | Forward pass for the PhaseDecoder module. 176 | 177 | Args: 178 | - x (torch.Tensor): Input tensor. 179 | 180 | Returns: 181 | - torch.Tensor: Decoded tensor with phase information. 
182 | """ 183 | x = self.dense_block(x) 184 | x = self.phase_conv(x) 185 | x_r = self.phase_conv_r(x) 186 | x_i = self.phase_conv_i(x) 187 | x = torch.atan2(x_i, x_r) 188 | return x 189 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | # References: https://github.com/yxlu-0102/MP-SENet/blob/main/models/discriminator.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from pesq import pesq 7 | from joblib import Parallel, delayed 8 | from models.lsigmoid import LearnableSigmoid1D 9 | 10 | def pesq_loss(clean, noisy, sr=16000): 11 | try: 12 | pesq_score = pesq(sr, clean, noisy, 'wb') 13 | except: 14 | # error can happen due to silent period 15 | pesq_score = -1 16 | return pesq_score 17 | 18 | 19 | def batch_pesq(clean, noisy, cfg): 20 | num_worker = cfg['env_setting']['num_workers'] 21 | pesq_score = Parallel(n_jobs=num_worker)(delayed(pesq_loss)(c, n) for c, n in zip(clean, noisy)) 22 | pesq_score = np.array(pesq_score) 23 | if -1 in pesq_score: 24 | return None 25 | pesq_score = (pesq_score - 1) / 3.5 26 | return torch.FloatTensor(pesq_score) 27 | 28 | 29 | class MetricDiscriminator(nn.Module): 30 | def __init__(self, dim=16, in_channel=2): 31 | super(MetricDiscriminator, self).__init__() 32 | self.layers = nn.Sequential( 33 | nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4,4), (2,2), (1,1), bias=False)), 34 | nn.InstanceNorm2d(dim, affine=True), 35 | nn.PReLU(dim), 36 | nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)), 37 | nn.InstanceNorm2d(dim*2, affine=True), 38 | nn.PReLU(dim*2), 39 | nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)), 40 | nn.InstanceNorm2d(dim*4, affine=True), 41 | nn.PReLU(dim*4), 42 | nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)), 43 | nn.InstanceNorm2d(dim*8, affine=True), 44 | nn.PReLU(dim*8), 45 | nn.AdaptiveMaxPool2d(1), 46 | nn.Flatten(), 47 | nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)), 48 | nn.Dropout(0.3), 49 | nn.PReLU(dim*4), 50 | nn.utils.spectral_norm(nn.Linear(dim*4, 1)), 51 | LearnableSigmoid1D(1) 52 | ) 53 | 54 | def forward(self, x, y): 55 | xy = torch.stack((x, y), dim=1) 56 | return self.layers(xy) 57 | -------------------------------------------------------------------------------- /models/generator.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/huaidanquede/MUSE-Speech-Enhancement/tree/main/models/generator 2 | 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | from torchvision.ops.deform_conv import DeformConv2d 7 | from einops import rearrange 8 | from .mamba_block import TFMambaBlock 9 | from .codec_module import DenseEncoder, MagDecoder, PhaseDecoder 10 | 11 | ##################################### 12 | class DWConv2d_BN(nn.Module): 13 | 14 | def __init__( 15 | self, 16 | in_ch, 17 | out_ch, 18 | kernel_size=1, 19 | stride=1, 20 | norm_layer=nn.BatchNorm2d, 21 | act_layer=nn.Hardswish, 22 | bn_weight_init=1, 23 | offset_clamp=(-1, 1) 24 | ): 25 | super().__init__() 26 | 27 | self.offset_clamp = offset_clamp 28 | self.offset_generator = nn.Sequential(nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=3, 29 | stride=1, padding=1, bias=False, groups=in_ch), 30 | nn.Conv2d(in_channels=in_ch, out_channels=18, 31 | kernel_size=1, 32 | stride=1, padding=0, bias=False) 33 | ) 34 | 
self.dcn = DeformConv2d( 35 | in_channels=in_ch, 36 | out_channels=in_ch, 37 | kernel_size=3, 38 | stride=1, 39 | padding=1, 40 | bias=False, 41 | groups=in_ch 42 | ) 43 | self.pwconv = nn.Conv2d(in_ch, out_ch, 1, 1, 0, bias=False) 44 | self.act = act_layer() if act_layer is not None else nn.Identity() 45 | for m in self.modules(): 46 | if isinstance(m, nn.Conv2d): 47 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 48 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 49 | if m.bias is not None: 50 | m.bias.data.zero_() 51 | 52 | def forward(self, x): 53 | offset = self.offset_generator(x) 54 | 55 | if self.offset_clamp: 56 | offset = torch.clamp(offset, min=self.offset_clamp[0], max=self.offset_clamp[1]) 57 | x = self.dcn(x, offset) 58 | 59 | x = self.pwconv(x) 60 | x = self.act(x) 61 | return x 62 | 63 | 64 | class MB_Deform_Embedding(nn.Module): 65 | 66 | def __init__(self, 67 | in_chans=3, 68 | embed_dim=768, 69 | patch_size=16, 70 | stride=1, 71 | act_layer=nn.Hardswish, 72 | offset_clamp=(-1, 1)): 73 | super().__init__() 74 | 75 | self.patch_conv = DWConv2d_BN( 76 | in_chans, 77 | embed_dim, 78 | kernel_size=patch_size, 79 | stride=stride, 80 | act_layer=act_layer, 81 | offset_clamp=offset_clamp 82 | ) 83 | 84 | def forward(self, x): 85 | """foward function""" 86 | x = self.patch_conv(x) 87 | 88 | return x 89 | 90 | 91 | class Patch_Embed_stage(nn.Module): 92 | """Depthwise Convolutional Patch Embedding stage comprised of 93 | `DWCPatchEmbed` layers.""" 94 | 95 | def __init__(self, in_chans, embed_dim, isPool=False, offset_clamp=(-1, 1)): 96 | super(Patch_Embed_stage, self).__init__() 97 | 98 | self.patch_embeds = MB_Deform_Embedding( 99 | in_chans=in_chans, 100 | embed_dim=embed_dim, 101 | patch_size=3, 102 | stride=1, 103 | offset_clamp=offset_clamp) 104 | 105 | def forward(self, x): 106 | """foward function""" 107 | 108 | att_inputs = self.patch_embeds(x) 109 | 110 | return att_inputs 111 | 112 | ##################################### 113 | class Downsample(nn.Module): 114 | def __init__(self, input_feat, out_feat): 115 | super(Downsample, self).__init__() 116 | 117 | self.body = nn.Sequential( 118 | # dw 119 | nn.Conv2d(input_feat, input_feat, kernel_size=3, stride=1, padding=1, groups=input_feat, bias=False), 120 | # pw-linear 121 | nn.Conv2d(input_feat, out_feat // 4, 1, 1, 0, bias=False), 122 | nn.PixelUnshuffle(2)) 123 | 124 | def forward(self, x): 125 | return self.body(x) 126 | 127 | 128 | class Upsample(nn.Module): 129 | def __init__(self, input_feat, out_feat): 130 | super(Upsample, self).__init__() 131 | 132 | self.body = nn.Sequential( 133 | # dw 134 | nn.Conv2d(input_feat, input_feat, kernel_size=3, stride=1, padding=1, groups=input_feat, bias=False), 135 | # pw-linear 136 | nn.Conv2d(input_feat, out_feat * 4, 1, 1, 0, bias=False), 137 | nn.PixelShuffle(2)) 138 | 139 | def forward(self, x): 140 | return self.body(x) 141 | 142 | class MambaSEUNet(nn.Module): 143 | """ 144 | SEMamba model for speech enhancement using Mamba blocks. 145 | 146 | This model uses a dense encoder, multiple Mamba blocks, and separate magnitude 147 | and phase decoders to process noisy magnitude and phase inputs. 148 | """ 149 | def __init__(self, cfg): 150 | """ 151 | Initialize the SEMamba model. 152 | 153 | Args: 154 | - cfg: Configuration object containing model parameters. 
155 | """ 156 | super(MambaSEUNet, self).__init__() 157 | self.cfg = cfg 158 | self.num_tscblocks = cfg['model_cfg']['num_tfmamba'] if cfg['model_cfg']['num_tfmamba'] is not None else 4 # default tfmamba: 4 159 | 160 | self.dim = [cfg['model_cfg']['hid_feature'], cfg['model_cfg']['hid_feature'] * 2, cfg['model_cfg']['hid_feature'] * 3] 161 | dim = self.dim 162 | 163 | # Initialize dense encoder 164 | self.dense_encoder = DenseEncoder(cfg) 165 | 166 | # Initialize Mamba blocks 167 | self.patch_embed_encoder_level1 = Patch_Embed_stage(dim[0], dim[0]) 168 | 169 | self.TSMamba1_encoder = nn.ModuleList([TFMambaBlock(cfg, dim[0]) for _ in range(self.num_tscblocks)]) 170 | 171 | self.down1_2 = Downsample(dim[0], dim[1]) 172 | 173 | self.patch_embed_encoder_level2 = Patch_Embed_stage(dim[1], dim[1]) 174 | 175 | self.TSMamba2_encoder = nn.ModuleList([TFMambaBlock(cfg, dim[1]) for _ in range(self.num_tscblocks)]) 176 | 177 | self.down2_3 = Downsample(dim[1], dim[2]) 178 | 179 | self.patch_embed_middle = Patch_Embed_stage(dim[2], dim[2]) 180 | 181 | self.TSMamba_middle = nn.ModuleList([TFMambaBlock(cfg, dim[2]) for _ in range(self.num_tscblocks)]) 182 | 183 | ########### 184 | 185 | self.up3_2 = Upsample(int(dim[2]), dim[1]) 186 | 187 | self.concat_level2 = nn.Sequential( 188 | nn.Conv2d(dim[1] * 2, dim[1], 1, 1, 0, bias=False), 189 | ) 190 | 191 | self.patch_embed_decoder_level2 = Patch_Embed_stage(dim[1], dim[1]) 192 | 193 | self.TSMamba2_decoder = nn.ModuleList([TFMambaBlock(cfg, dim[1]) for _ in range(self.num_tscblocks)]) 194 | 195 | self.up2_1 = Upsample(int(dim[1]), dim[0]) 196 | 197 | self.concat_level1 = nn.Sequential( 198 | nn.Conv2d(dim[0] * 2, dim[0], 1, 1, 0, bias=False), 199 | ) 200 | 201 | self.patch_embed_decoder_level1 = Patch_Embed_stage(dim[0], dim[0]) 202 | 203 | self.TSMamba1_decoder = nn.ModuleList([TFMambaBlock(cfg, dim[0]) for _ in range(self.num_tscblocks)]) 204 | 205 | # 幅度 206 | self.mag_patch_embed_refinement = Patch_Embed_stage(dim[0], dim[0]) 207 | 208 | self.mag_refinement = nn.ModuleList([TFMambaBlock(cfg, dim[0]) for _ in range(self.num_tscblocks)]) 209 | 210 | self.mag_output = nn.Sequential( 211 | nn.Conv2d(dim[0], dim[0], kernel_size=3, stride=1, padding=1, bias=False), 212 | 213 | ) 214 | 215 | # 相位 216 | self.pha_patch_embed_refinement = Patch_Embed_stage(dim[0], dim[0]) 217 | 218 | self.pha_refinement = nn.ModuleList([TFMambaBlock(cfg, dim[0]) for _ in range(self.num_tscblocks)]) 219 | 220 | self.pha_output = nn.Sequential( 221 | nn.Conv2d(dim[0], dim[0], kernel_size=3, stride=1, padding=1, bias=False), 222 | 223 | ) 224 | 225 | # Initialize decoders 226 | self.mask_decoder = MagDecoder(cfg) 227 | self.phase_decoder = PhaseDecoder(cfg) 228 | 229 | def forward(self, noisy_mag, noisy_pha): 230 | """ 231 | Forward pass for the SEMamba model. 232 | 233 | Args: 234 | - noisy_mag (torch.Tensor): Noisy magnitude input tensor [B, F, T]. 235 | - noisy_pha (torch.Tensor): Noisy phase input tensor [B, F, T]. 236 | 237 | Returns: 238 | - denoised_mag (torch.Tensor): Denoised magnitude tensor [B, F, T]. 239 | - denoised_pha (torch.Tensor): Denoised phase tensor [B, F, T]. 240 | - denoised_com (torch.Tensor): Denoised complex tensor [B, F, T, 2]. 
241 | """ 242 | # Reshape inputs 243 | noisy_mag = rearrange(noisy_mag, 'b f t -> b t f').unsqueeze(1) # [B, 1, T, F] 244 | noisy_pha = rearrange(noisy_pha, 'b f t -> b t f').unsqueeze(1) # [B, 1, T, F] 245 | 246 | # Concatenate magnitude and phase inputs 247 | x = torch.cat((noisy_mag, noisy_pha), dim=1) # [B, 2, T, F] 248 | 249 | # Encode input 250 | x1 = self.dense_encoder(x) 251 | 252 | # Apply U-Net Mamba blocks 253 | copy1 = x1 254 | x1 = self.patch_embed_encoder_level1(x1) 255 | for block in self.TSMamba1_encoder: 256 | x1 = block(x1) 257 | x1 = copy1 + x1 258 | 259 | x2 = self.down1_2(x1) 260 | 261 | copy2 = x2 262 | x2 = self.patch_embed_encoder_level2(x2) 263 | for block in self.TSMamba2_encoder: 264 | x2 = block(x2) 265 | x2 = copy2 + x2 266 | 267 | x3 = self.down2_3(x2) 268 | 269 | copy3 = x3 270 | x3 = self.patch_embed_middle(x3) 271 | for block in self.TSMamba_middle: 272 | x3 = block(x3) 273 | x3 = copy3 + x3 274 | 275 | y2 = self.up3_2(x3) 276 | y2 = torch.cat([y2, x2], 1) 277 | y2 = self.concat_level2(y2) 278 | 279 | copy_de2 = y2 280 | y2 = self.patch_embed_decoder_level2(y2) 281 | for block in self.TSMamba2_decoder: 282 | y2 = block(y2) 283 | y2 = copy_de2 + y2 284 | 285 | y1 = self.up2_1(y2) 286 | y1 = torch.cat([y1, x1], 1) 287 | y1 = self.concat_level1(y1) 288 | 289 | copy_de1 = y1 290 | y1 = self.patch_embed_decoder_level1(y1) 291 | for block in self.TSMamba1_decoder: 292 | y1 = block(y1) 293 | y1 = copy_de1 + y1 294 | 295 | mag_input = y1 296 | pha_input = y1 297 | 298 | # magnitude 299 | copy_mag = mag_input 300 | mag_input = self.mag_patch_embed_refinement(mag_input) 301 | for block in self.mag_refinement: 302 | mag_input = block(mag_input) 303 | mag = copy_mag + mag_input 304 | mag = self.mag_output(mag) + copy1 305 | 306 | # phase 307 | copy_pha = pha_input 308 | pha_input = self.pha_patch_embed_refinement(pha_input) 309 | for block in self.pha_refinement: 310 | pha_input = block(pha_input) 311 | pha = copy_pha + pha_input 312 | pha = self.pha_output(pha) + copy1 313 | 314 | # Decode magnitude and phase 315 | denoised_mag = rearrange(self.mask_decoder(mag) * noisy_mag, 'b c t f -> b f t c').squeeze(-1) 316 | denoised_pha = rearrange(self.phase_decoder(pha), 'b c t f -> b f t c').squeeze(-1) 317 | 318 | # Combine denoised magnitude and phase into a complex representation 319 | denoised_com = torch.stack( 320 | (denoised_mag * torch.cos(denoised_pha), denoised_mag * torch.sin(denoised_pha)), 321 | dim=-1 322 | ) 323 | 324 | return denoised_mag, denoised_pha, denoised_com 325 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from pesq import pesq 6 | from joblib import Parallel, delayed 7 | 8 | def phase_losses(phase_r, phase_g, cfg): 9 | """ 10 | Calculate phase losses including in-phase loss, gradient delay loss, 11 | and integrated absolute frequency loss between reference and generated phases. 12 | 13 | Args: 14 | phase_r (torch.Tensor): Reference phase tensor of shape (batch, freq, time). 15 | phase_g (torch.Tensor): Generated phase tensor of shape (batch, freq, time). 16 | h (object): Configuration object containing parameters like n_fft. 17 | 18 | Returns: 19 | tuple: Tuple containing in-phase loss, gradient delay loss, and integrated absolute frequency loss. 
20 | """ 21 | dim_freq = cfg['stft_cfg']['n_fft'] // 2 + 1 # Calculate frequency dimension 22 | dim_time = phase_r.size(-1) # Calculate time dimension 23 | 24 | # Construct gradient delay matrix 25 | gd_matrix = (torch.triu(torch.ones(dim_freq, dim_freq), diagonal=1) - 26 | torch.triu(torch.ones(dim_freq, dim_freq), diagonal=2) - 27 | torch.eye(dim_freq)).to(phase_g.device) 28 | 29 | # Apply gradient delay matrix to reference and generated phases 30 | gd_r = torch.matmul(phase_r.permute(0, 2, 1), gd_matrix) 31 | gd_g = torch.matmul(phase_g.permute(0, 2, 1), gd_matrix) 32 | 33 | # Construct integrated absolute frequency matrix 34 | iaf_matrix = (torch.triu(torch.ones(dim_time, dim_time), diagonal=1) - 35 | torch.triu(torch.ones(dim_time, dim_time), diagonal=2) - 36 | torch.eye(dim_time)).to(phase_g.device) 37 | 38 | # Apply integrated absolute frequency matrix to reference and generated phases 39 | iaf_r = torch.matmul(phase_r, iaf_matrix) 40 | iaf_g = torch.matmul(phase_g, iaf_matrix) 41 | 42 | # Calculate losses 43 | ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g)) 44 | gd_loss = torch.mean(anti_wrapping_function(gd_r - gd_g)) 45 | iaf_loss = torch.mean(anti_wrapping_function(iaf_r - iaf_g)) 46 | 47 | return ip_loss, gd_loss, iaf_loss 48 | 49 | def anti_wrapping_function(x): 50 | """ 51 | Anti-wrapping function to adjust phase values within the range of -pi to pi. 52 | 53 | Args: 54 | x (torch.Tensor): Input tensor representing phase differences. 55 | 56 | Returns: 57 | torch.Tensor: Adjusted tensor with phase values wrapped within -pi to pi. 58 | """ 59 | return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi) 60 | 61 | def compute_stft(y: torch.Tensor, n_fft: int, hop_size: int, win_size: int, center: bool, compress_factor: float = 1.0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 62 | """ 63 | Compute the Short-Time Fourier Transform (STFT) and return magnitude, phase, and complex components. 64 | 65 | Args: 66 | y (torch.Tensor): Input signal tensor. 67 | n_fft (int): Number of FFT points. 68 | hop_size (int): Hop size for STFT. 69 | win_size (int): Window size for STFT. 70 | center (bool): Whether to pad the input on both sides. 71 | compress_factor (float, optional): Compression factor for magnitude. Defaults to 1.0. 72 | 73 | Returns: 74 | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Magnitude, phase, and complex components. 75 | """ 76 | eps = torch.finfo(y.dtype).eps 77 | hann_window = torch.hann_window(win_size).to(y.device) 78 | 79 | stft_spec = torch.stft( 80 | y, 81 | n_fft=n_fft, 82 | hop_length=hop_size, 83 | win_length=win_size, 84 | window=hann_window, 85 | center=center, 86 | pad_mode='reflect', 87 | normalized=False, 88 | return_complex=True 89 | ) 90 | 91 | real_part = stft_spec.real 92 | imag_part = stft_spec.imag 93 | 94 | mag = torch.sqrt( real_part.pow(2) * imag_part.pow(2) + eps ) 95 | pha = torch.atan2( real_part + eps, imag_part + eps ) 96 | 97 | mag = torch.pow(mag, compress_factor) 98 | com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1) 99 | 100 | return mag, pha, com 101 | 102 | def pesq_score(utts_r, utts_g, cfg): 103 | """ 104 | Calculate PESQ (Perceptual Evaluation of Speech Quality) score for pairs of reference and generated utterances. 105 | 106 | Args: 107 | utts_r (list of torch.Tensor): List of reference utterances. 108 | utts_g (list of torch.Tensor): List of generated utterances. 109 | h (object): Configuration object containing parameters like sampling_rate. 
110 | 111 | Returns: 112 | float: Mean PESQ score across all pairs of utterances. 113 | """ 114 | def eval_pesq(clean_utt, esti_utt, sr): 115 | """ 116 | Evaluate PESQ score for a single pair of clean and estimated utterances. 117 | 118 | Args: 119 | clean_utt (np.ndarray): Clean reference utterance. 120 | esti_utt (np.ndarray): Estimated generated utterance. 121 | sr (int): Sampling rate. 122 | 123 | Returns: 124 | float: PESQ score or -1 in case of an error. 125 | """ 126 | try: 127 | pesq_score = pesq(sr, clean_utt, esti_utt) 128 | except Exception as e: 129 | # Error can happen due to silent period or other issues 130 | print(f"Error computing PESQ score: {e}") 131 | pesq_score = -1 132 | return pesq_score 133 | 134 | # Parallel processing of PESQ score computation 135 | pesq_scores = Parallel(n_jobs=30)(delayed(eval_pesq)( 136 | utts_r[i].squeeze().cpu().numpy(), 137 | utts_g[i].squeeze().cpu().numpy(), 138 | cfg['stft_cfg']['sampling_rate'] 139 | ) for i in range(len(utts_r))) 140 | 141 | # Calculate mean PESQ score 142 | pesq_score = np.mean(pesq_scores) 143 | return pesq_score 144 | 145 | -------------------------------------------------------------------------------- /models/lsigmoid.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/utils.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | class LearnableSigmoid1D(nn.Module): 7 | """ 8 | Learnable Sigmoid Activation Function for 1D inputs. 9 | 10 | This module applies a learnable slope parameter to the sigmoid activation function. 11 | """ 12 | def __init__(self, in_features, beta=1): 13 | """ 14 | Initialize the LearnableSigmoid1D module. 15 | 16 | Args: 17 | - in_features (int): Number of input features. 18 | - beta (float, optional): Scaling factor for the sigmoid function. Defaults to 1. 19 | """ 20 | super(LearnableSigmoid1D, self).__init__() 21 | self.beta = beta 22 | self.slope = nn.Parameter(torch.ones(in_features)) 23 | self.slope.requires_grad = True 24 | 25 | def forward(self, x): 26 | """ 27 | Forward pass for the LearnableSigmoid1D module. 28 | 29 | Args: 30 | - x (torch.Tensor): Input tensor. 31 | 32 | Returns: 33 | - torch.Tensor: Output tensor after applying the learnable sigmoid activation. 34 | """ 35 | return self.beta * torch.sigmoid(self.slope * x) 36 | 37 | class LearnableSigmoid2D(nn.Module): 38 | """ 39 | Learnable Sigmoid Activation Function for 2D inputs. 40 | 41 | This module applies a learnable slope parameter to the sigmoid activation function for 2D inputs. 42 | """ 43 | def __init__(self, in_features, beta=1): 44 | """ 45 | Initialize the LearnableSigmoid2D module. 46 | 47 | Args: 48 | - in_features (int): Number of input features. 49 | - beta (float, optional): Scaling factor for the sigmoid function. Defaults to 1. 50 | """ 51 | super(LearnableSigmoid2D, self).__init__() 52 | self.beta = beta 53 | self.slope = nn.Parameter(torch.ones(in_features, 1)) 54 | self.slope.requires_grad = True 55 | 56 | def forward(self, x): 57 | """ 58 | Forward pass for the LearnableSigmoid2D module. 59 | 60 | Args: 61 | - x (torch.Tensor): Input tensor. 62 | 63 | Returns: 64 | - torch.Tensor: Output tensor after applying the learnable sigmoid activation. 
65 | """ 66 | return self.beta * torch.sigmoid(self.slope * x) 67 | -------------------------------------------------------------------------------- /models/mamba_block.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/RoyChao19477/SEMamba/models/mamba_block 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | from torch.nn.parameter import Parameter 8 | from functools import partial 9 | from einops import rearrange 10 | 11 | from mamba_ssm.modules.mamba_simple import Mamba, Block 12 | from mamba_ssm.models.mixer_seq_simple import _init_weights 13 | from mamba_ssm.ops.triton.layernorm import RMSNorm 14 | 15 | # github: https://github.com/state-spaces/mamba/blob/9127d1f47f367f5c9cc49c73ad73557089d02cb8/mamba_ssm/models/mixer_seq_simple.py 16 | def create_block( 17 | d_model, cfg, layer_idx=0, rms_norm=True, fused_add_norm=False, residual_in_fp32=False, 18 | ): 19 | d_state = cfg['model_cfg']['d_state'] # 16 20 | d_conv = cfg['model_cfg']['d_conv'] # 4 21 | expand = cfg['model_cfg']['expand'] # 4 22 | norm_epsilon = cfg['model_cfg']['norm_epsilon'] # 0.00001 23 | 24 | mixer_cls = partial(Mamba, layer_idx=layer_idx, d_state=d_state, d_conv=d_conv, expand=expand) 25 | norm_cls = partial( 26 | nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon 27 | ) 28 | block = Block( 29 | d_model, 30 | mixer_cls, 31 | norm_cls=norm_cls, 32 | fused_add_norm=fused_add_norm, 33 | residual_in_fp32=residual_in_fp32, 34 | ) 35 | block.layer_idx = layer_idx 36 | return block 37 | 38 | class MambaBlock(nn.Module): 39 | def __init__(self, in_channels, cfg): 40 | super(MambaBlock, self).__init__() 41 | n_layer = 1 42 | self.forward_blocks = nn.ModuleList( create_block(in_channels, cfg) for i in range(n_layer) ) 43 | self.backward_blocks = nn.ModuleList( create_block(in_channels, cfg) for i in range(n_layer) ) 44 | 45 | self.apply( 46 | partial( 47 | _init_weights, 48 | n_layer=n_layer, 49 | ) 50 | ) 51 | 52 | def forward(self, x): 53 | x_forward, x_backward = x.clone(), torch.flip(x, [1]) 54 | resi_forward, resi_backward = None, None 55 | 56 | # Forward 57 | for layer in self.forward_blocks: 58 | x_forward, resi_forward = layer(x_forward, resi_forward) 59 | y_forward = (x_forward + resi_forward) if resi_forward is not None else x_forward 60 | 61 | # Backward 62 | for layer in self.backward_blocks: 63 | x_backward, resi_backward = layer(x_backward, resi_backward) 64 | y_backward = torch.flip((x_backward + resi_backward), [1]) if resi_backward is not None else torch.flip(x_backward, [1]) 65 | 66 | return torch.cat([y_forward, y_backward], -1) 67 | 68 | class TFMambaBlock(nn.Module): 69 | """ 70 | Temporal-Frequency Mamba block for sequence modeling. 71 | 72 | Attributes: 73 | cfg (Config): Configuration for the block. 74 | time_mamba (MambaBlock): Mamba block for temporal dimension. 75 | freq_mamba (MambaBlock): Mamba block for frequency dimension. 76 | tlinear (ConvTranspose1d): ConvTranspose1d layer for temporal dimension. 77 | flinear (ConvTranspose1d): ConvTranspose1d layer for frequency dimension. 
78 | """ 79 | def __init__(self, cfg, inchannels): 80 | super(TFMambaBlock, self).__init__() 81 | self.cfg = cfg 82 | self.hid_feature = inchannels 83 | 84 | # Initialize Mamba blocks 85 | self.time_mamba = MambaBlock(in_channels=self.hid_feature, cfg=cfg) 86 | self.freq_mamba = MambaBlock(in_channels=self.hid_feature, cfg=cfg) 87 | 88 | # Initialize ConvTranspose1d layers 89 | self.tlinear = nn.ConvTranspose1d(self.hid_feature * 2, self.hid_feature, 1, stride=1) 90 | self.flinear = nn.ConvTranspose1d(self.hid_feature * 2, self.hid_feature, 1, stride=1) 91 | 92 | def forward(self, x): 93 | """ 94 | Forward pass of the TFMamba block. 95 | 96 | Parameters: 97 | x (Tensor): Input tensor with shape (batch, channels, time, freq). 98 | 99 | Returns: 100 | Tensor: Output tensor after applying temporal and frequency Mamba blocks. 101 | """ 102 | b, c, t, f = x.size() 103 | 104 | x = x.permute(0, 3, 2, 1).contiguous().view(b*f, t, c) 105 | x = self.tlinear( self.time_mamba(x).permute(0,2,1) ).permute(0,2,1) + x 106 | x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b*t, f, c) 107 | x = self.flinear( self.freq_mamba(x).permute(0,2,1) ).permute(0,2,1) + x 108 | x = x.view(b, t, f, c).permute(0, 3, 1, 2) 109 | return x 110 | 111 | -------------------------------------------------------------------------------- /models/pcs400.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/RoyChao19477/SEMamba/models/pcs400 2 | 3 | import os 4 | import torch 5 | import torchaudio 6 | import numpy as np 7 | import argparse 8 | import librosa 9 | import scipy 10 | 11 | # PCS400 parameters 12 | PCS400 = np.ones(201) 13 | PCS400[0:3] = 1 14 | PCS400[3:5] = 1.070175439 15 | PCS400[5:8] = 1.182456140 16 | PCS400[8:10] = 1.287719298 17 | PCS400[10:110] = 1.4 # Pre Set 18 | PCS400[110:130] = 1.322807018 19 | PCS400[130:160] = 1.238596491 20 | PCS400[160:190] = 1.161403509 21 | PCS400[190:202] = 1.077192982 22 | 23 | maxv = np.iinfo(np.int16).max 24 | 25 | def Sp_and_phase(signal): 26 | signal_length = signal.shape[0] 27 | n_fft = 400 28 | hop_length = 100 29 | y_pad = librosa.util.fix_length(signal, size=signal_length + n_fft // 2) 30 | 31 | F = librosa.stft(y_pad, n_fft=400, hop_length=100, win_length=400, window=scipy.signal.windows.hamming(400)) 32 | Lp = PCS400 * np.transpose(np.log1p(np.abs(F)), (1, 0)) 33 | phase = np.angle(F) 34 | 35 | NLp = np.transpose(Lp, (1, 0)) 36 | 37 | return NLp, phase, signal_length 38 | 39 | 40 | def SP_to_wav(mag, phase, signal_length): 41 | mag = np.expm1(mag) 42 | Rec = np.multiply(mag, np.exp(1j*phase)) 43 | result = librosa.istft(Rec, 44 | hop_length=100, 45 | win_length=400, 46 | window=scipy.signal.windows.hamming(400), 47 | length=signal_length) 48 | return result 49 | 50 | def cal_pcs(signal_wav): 51 | noisy_LP, Nphase, signal_length = Sp_and_phase(signal_wav.squeeze()) 52 | enhanced_wav = SP_to_wav(noisy_LP, Nphase, signal_length) 53 | enhanced_wav = enhanced_wav/np.max(abs(enhanced_wav)) 54 | 55 | return enhanced_wav 56 | -------------------------------------------------------------------------------- /models/stfts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def mag_phase_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True, addeps=False): 5 | """ 6 | Compute magnitude and phase using STFT. 7 | 8 | Args: 9 | y (torch.Tensor): Input audio signal. 10 | n_fft (int): FFT size. 
11 | hop_size (int): Hop size. 12 | win_size (int): Window size. 13 | compress_factor (float, optional): Magnitude compression factor. Defaults to 1.0. 14 | center (bool, optional): Whether to center the signal before padding. Defaults to True. 15 | eps (bool, optional): Whether adding epsilon to magnitude and phase or not. Defaults to False. 16 | 17 | Returns: 18 | tuple: Magnitude, phase, and complex representation of the STFT. 19 | """ 20 | #eps = torch.finfo(y.dtype).eps 21 | eps = 1e-10 22 | hann_window = torch.hann_window(win_size).to(y.device) 23 | stft_spec = torch.stft( 24 | y, n_fft, 25 | hop_length=hop_size, 26 | win_length=win_size, 27 | window=hann_window, 28 | center=center, 29 | pad_mode='reflect', 30 | normalized=False, 31 | return_complex=True) 32 | 33 | if addeps==False: 34 | mag = torch.abs(stft_spec) 35 | pha = torch.angle(stft_spec) 36 | else: 37 | real_part = stft_spec.real 38 | imag_part = stft_spec.imag 39 | mag = torch.sqrt(real_part.pow(2) + imag_part.pow(2) + eps) 40 | pha = torch.atan2(imag_part + eps, real_part + eps) 41 | # Compress the magnitude 42 | mag = torch.pow(mag, compress_factor) 43 | com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1) 44 | return mag, pha, com 45 | 46 | 47 | def mag_phase_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True): 48 | """ 49 | Inverse STFT to reconstruct the audio signal from magnitude and phase. 50 | 51 | Args: 52 | mag (torch.Tensor): Magnitude of the STFT. 53 | pha (torch.Tensor): Phase of the STFT. 54 | n_fft (int): FFT size. 55 | hop_size (int): Hop size. 56 | win_size (int): Window size. 57 | compress_factor (float, optional): Magnitude compression factor. Defaults to 1.0. 58 | center (bool, optional): Whether to center the signal before padding. Defaults to True. 59 | 60 | Returns: 61 | torch.Tensor: Reconstructed audio signal. 62 | """ 63 | mag = torch.pow(mag, 1.0 / compress_factor) 64 | com = torch.complex(mag * torch.cos(pha), mag * torch.sin(pha)) 65 | hann_window = torch.hann_window(win_size).to(com.device) 66 | wav = torch.istft( 67 | com, 68 | n_fft, 69 | hop_length=hop_size, 70 | win_length=win_size, 71 | window=hann_window, 72 | center=center) 73 | return wav 74 | -------------------------------------------------------------------------------- /recipes/Mamba-SEUNet-PCS/Mamba-SEUNet-PCS.yaml: -------------------------------------------------------------------------------- 1 | # Environment Settings 2 | # These settings specify the hardware and distributed setup for the model training. 3 | # Adjust `num_gpus` and `dist_config` according to your distributed training environment. 4 | env_setting: 5 | num_gpus: 1 # Number of GPUs. Now we don't support CPU mode. 6 | num_workers: 12 # Number of worker threads for data loading. 7 | seed: 1234 # Seed for random number generators to ensure reproducibility. 8 | stdout_interval: 200 9 | checkpoint_interval: 2000 # save model to ckpt every N steps 10 | validation_interval: 2000 11 | summary_interval: 200 12 | dist_cfg: 13 | dist_backend: nccl # Distributed training backend, 'nccl' for NVIDIA GPUs. 14 | dist_url: tcp://localhost:13407 # URL for initializing distributed training. 15 | world_size: 1 # Total number of processes in the distributed training. 
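A quick sanity check for the `mag_phase_stft` / `mag_phase_istft` pair defined in `models/stfts.py` above. The STFT settings and the 30600-sample input length below are taken from the recipes in this repository and used here only for illustration; the reconstruction error figure is indicative, not a guarantee:

```python
import torch
from models.stfts import mag_phase_stft, mag_phase_istft

# STFT settings as configured in the recipes (assumed here for illustration).
n_fft, hop_size, win_size, compress_factor = 510, 120, 510, 0.3

wav = torch.randn(1, 30600)   # one training segment at 16 kHz
mag, pha, com = mag_phase_stft(wav, n_fft, hop_size, win_size, compress_factor)
print(mag.shape)              # torch.Size([1, 256, 256]): (n_fft // 2 + 1) bins x 256 frames

rec = mag_phase_istft(mag, pha, n_fft, hop_size, win_size, compress_factor)
print((rec - wav).abs().max())  # small residual from the compress/uncompress round trip
```

The 256 x 256 spectrogram shape is also why flops.py profiles the generator with dummy inputs of shape (1, 256, 256).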
16 | 17 | # Datapath Configuratoin 18 | data_cfg: 19 | train_clean_json: data/train_clean.json 20 | train_noisy_json: data/train_noisy.json 21 | valid_clean_json: data/valid_clean.json 22 | valid_noisy_json: data/valid_noisy.json 23 | test_clean_json: data/test_clean.json 24 | test_noisy_json: data/test_noisy.json 25 | 26 | # Training Configuration 27 | # This section details parameters that directly influence the training process, 28 | # including batch sizes, learning rates, and optimizer specifics. 29 | training_cfg: 30 | training_epochs: 200 # Training epoch. 31 | batch_size: 2 # Training batch size. 32 | learning_rate: 0.0005 # Initial learning rate. 33 | adam_b1: 0.8 # Beta1 hyperparameter for the AdamW optimizer. 34 | adam_b2: 0.99 # Beta2 hyperparameter for the AdamW optimizer. 35 | lr_decay: 0.99 # Learning rate decay per epoch. 36 | segment_size: 30600 # Audio segment size used during training, dependent on sampling rate. 37 | loss: 38 | metric: 0.05 39 | magnitude: 0.9 40 | phase: 0.3 41 | complex: 0.1 42 | time: 0.2 43 | consistancy: 0.1 44 | use_PCS400: True # Use PCS or not 45 | 46 | # STFT Configuration 47 | # Configuration for Short-Time Fourier Transform (STFT), crucial for audio processing models. 48 | stft_cfg: 49 | sampling_rate: 16000 # Audio sampling rate in Hz. 50 | n_fft: 510 # FFT components for transforming audio signals. 51 | hop_size: 120 # Samples between successive frames. 52 | win_size: 510 # Window size used in FFT. 53 | 54 | # Model Configuration 55 | # Defines the architecture specifics of the model, including layer configurations and feature compression. 56 | model_cfg: 57 | hid_feature: 32 # Channels in dense layers. 58 | compress_factor: 0.3 # Compression factor applied to extracted features. 59 | num_tfmamba: 4 # Number of Time-Frequency Mamba (TFMamba) blocks in the model. 60 | d_state: 16 # Dimensionality of the state vector in Mamba blocks. 61 | d_conv: 4 # Convolutional layer dimensionality within Mamba blocks. 62 | expand: 4 # Expansion factor for the layers within the Mamba blocks. 63 | norm_epsilon: 0.00001 # Numerical stability in normalization layers within the Mamba blocks. 64 | beta: 2.0 # Hyperparameter for the Learnable Sigmoid function. 65 | input_channel: 2 # Magnitude and Phase 66 | output_channel: 1 # Single Channel Speech Enhancement 67 | -------------------------------------------------------------------------------- /recipes/Mamba-SEUNet/Mamba-SEUNet.yaml: -------------------------------------------------------------------------------- 1 | # Environment Settings 2 | # These settings specify the hardware and distributed setup for the model training. 3 | # Adjust `num_gpus` and `dist_config` according to your distributed training environment. 4 | env_setting: 5 | num_gpus: 1 # Number of GPUs. Now we don't support CPU mode. 6 | num_workers: 12 # Number of worker threads for data loading. 7 | seed: 1234 # Seed for random number generators to ensure reproducibility. 8 | stdout_interval: 200 9 | checkpoint_interval: 2000 # save model to ckpt every N steps 10 | validation_interval: 2000 11 | summary_interval: 200 12 | dist_cfg: 13 | dist_backend: nccl # Distributed training backend, 'nccl' for NVIDIA GPUs. 14 | dist_url: tcp://localhost:13407 # URL for initializing distributed training. 15 | world_size: 1 # Total number of processes in the distributed training. 
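Both recipes are plain YAML, so their sections can be inspected directly; as listed above and below, they share their data, training, STFT, and model settings, with `use_PCS400` being the visible difference. A minimal sketch using PyYAML (the repository's `utils.util.load_config` is assumed to yield an equivalent nested dict, which is an assumption rather than a documented guarantee):

```python
import yaml

def load_recipe(path):
    with open(path, 'r') as f:
        return yaml.safe_load(f)

base = load_recipe('recipes/Mamba-SEUNet/Mamba-SEUNet.yaml')
pcs = load_recipe('recipes/Mamba-SEUNet-PCS/Mamba-SEUNet-PCS.yaml')

# The PCS recipe enables Perceptual Contrast Stretching on the training targets.
print(base['training_cfg']['use_PCS400'])   # False
print(pcs['training_cfg']['use_PCS400'])    # True

# Generator loss weights that train.py combines into loss_gen_all.
print(base['training_cfg']['loss'])         # {'metric': 0.05, 'magnitude': 0.9, ...}
```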
16 | 17 | # Datapath Configuratoin 18 | data_cfg: 19 | train_clean_json: data/train_clean.json 20 | train_noisy_json: data/train_noisy.json 21 | valid_clean_json: data/valid_clean.json 22 | valid_noisy_json: data/valid_noisy.json 23 | test_clean_json: data/test_clean.json 24 | test_noisy_json: data/test_noisy.json 25 | 26 | # Training Configuration 27 | # This section details parameters that directly influence the training process, 28 | # including batch sizes, learning rates, and optimizer specifics. 29 | training_cfg: 30 | training_epochs: 200 # Training epoch. 31 | batch_size: 2 # Training batch size. 32 | learning_rate: 0.0005 # Initial learning rate. 33 | adam_b1: 0.8 # Beta1 hyperparameter for the AdamW optimizer. 34 | adam_b2: 0.99 # Beta2 hyperparameter for the AdamW optimizer. 35 | lr_decay: 0.99 # Learning rate decay per epoch. 36 | segment_size: 30600 # Audio segment size used during training, dependent on sampling rate. 37 | loss: 38 | metric: 0.05 39 | magnitude: 0.9 40 | phase: 0.3 41 | complex: 0.1 42 | time: 0.2 43 | consistancy: 0.1 44 | use_PCS400: False # Use PCS or not 45 | 46 | # STFT Configuration 47 | # Configuration for Short-Time Fourier Transform (STFT), crucial for audio processing models. 48 | stft_cfg: 49 | sampling_rate: 16000 # Audio sampling rate in Hz. 50 | n_fft: 510 # FFT components for transforming audio signals. 51 | hop_size: 120 # Samples between successive frames. 52 | win_size: 510 # Window size used in FFT. 53 | 54 | # Model Configuration 55 | # Defines the architecture specifics of the model, including layer configurations and feature compression. 56 | model_cfg: 57 | hid_feature: 32 # Channels in dense layers. 58 | compress_factor: 0.3 # Compression factor applied to extracted features. 59 | num_tfmamba: 4 # Number of Time-Frequency Mamba (TFMamba) blocks in the model. 60 | d_state: 16 # Dimensionality of the state vector in Mamba blocks. 61 | d_conv: 4 # Convolutional layer dimensionality within Mamba blocks. 62 | expand: 4 # Expansion factor for the layers within the Mamba blocks. 63 | norm_epsilon: 0.00001 # Numerical stability in normalization layers within the Mamba blocks. 64 | beta: 2.0 # Hyperparameter for the Learnable Sigmoid function. 
65 | input_channel: 2 # Magnitude and Phase 66 | output_channel: 1 # Single Channel Speech Enhancement 67 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.2+cu118 2 | torchaudio==2.1.2+cu118 3 | numpy==1.24.3 4 | librosa==0.10.2.post1 5 | numba==0.60.0 6 | scipy==1.14.0 7 | tensorboard==2.17.0 8 | SoundFile==0.12.1 9 | einops==0.8.0 10 | joblib==1.4.2 11 | pesq==0.0.4 12 | mamba-ssm==1.0.1 13 | pyyaml==6.0.1 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/RoyChao19477/SEMamba/train.py 2 | # Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/train.py 3 | 4 | import warnings 5 | warnings.simplefilter(action='ignore', category=FutureWarning) 6 | import os 7 | import time 8 | import argparse 9 | import json 10 | import yaml 11 | import torch 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | from torch.utils.tensorboard import SummaryWriter 15 | from torch.utils.data import DistributedSampler, DataLoader 16 | import torch.multiprocessing as mp 17 | from torch.nn.parallel import DistributedDataParallel 18 | 19 | from dataloaders.dataloader_vctk import VCTKDemandDataset, Val_Dataset 20 | from models.stfts import mag_phase_stft, mag_phase_istft 21 | from models.generator import MambaSEUNet 22 | from models.loss import pesq_score, phase_losses 23 | from models.discriminator import MetricDiscriminator, batch_pesq 24 | from utils.util import ( 25 | load_ckpts, load_optimizer_states, save_checkpoint, 26 | build_env, load_config, initialize_seed, 27 | print_gpu_info, log_model_info, initialize_process_group, 28 | ) 29 | 30 | torch.backends.cudnn.benchmark = True 31 | 32 | def setup_optimizers(models, cfg): 33 | """Set up optimizers for the models.""" 34 | generator, discriminator = models 35 | learning_rate = cfg['training_cfg']['learning_rate'] 36 | betas = (cfg['training_cfg']['adam_b1'], cfg['training_cfg']['adam_b2']) 37 | 38 | optim_g = optim.AdamW(generator.parameters(), lr=learning_rate, betas=betas) 39 | optim_d = optim.AdamW(discriminator.parameters(), lr=learning_rate, betas=betas) 40 | 41 | return optim_g, optim_d 42 | 43 | def setup_schedulers(optimizers, cfg, last_epoch): 44 | """Set up learning rate schedulers.""" 45 | optim_g, optim_d = optimizers 46 | lr_decay = cfg['training_cfg']['lr_decay'] 47 | 48 | scheduler_g = optim.lr_scheduler.ExponentialLR(optim_g, gamma=lr_decay, last_epoch=last_epoch) 49 | scheduler_d = optim.lr_scheduler.ExponentialLR(optim_d, gamma=lr_decay, last_epoch=last_epoch) 50 | 51 | return scheduler_g, scheduler_d 52 | 53 | 54 | def create_val_dataset(cfg, train=True, split=True, device='cuda:0'): 55 | """Create dataset based on cfguration.""" 56 | clean_json = cfg['data_cfg']['train_clean_json'] if train else cfg['data_cfg']['valid_clean_json'] 57 | noisy_json = cfg['data_cfg']['train_noisy_json'] if train else cfg['data_cfg']['valid_noisy_json'] 58 | shuffle = (cfg['env_setting']['num_gpus'] <= 1) if train else False 59 | pcs = cfg['training_cfg']['use_PCS400'] if train else False 60 | 61 | return Val_Dataset( 62 | clean_json=clean_json, 63 | noisy_json=noisy_json, 64 | sampling_rate=cfg['stft_cfg']['sampling_rate'], 65 | segment_size=cfg['training_cfg']['segment_size'], 66 | n_fft=cfg['stft_cfg']['n_fft'], 67 | 
hop_size=cfg['stft_cfg']['hop_size'], 68 | win_size=cfg['stft_cfg']['win_size'], 69 | compress_factor=cfg['model_cfg']['compress_factor'], 70 | split=split, 71 | n_cache_reuse=0, 72 | shuffle=shuffle, 73 | device=device, 74 | pcs=pcs 75 | ) 76 | 77 | 78 | def create_dataset(cfg, train=True, split=True, device='cuda:0'): 79 | """Create dataset based on cfguration.""" 80 | clean_json = cfg['data_cfg']['train_clean_json'] if train else cfg['data_cfg']['valid_clean_json'] 81 | noisy_json = cfg['data_cfg']['train_noisy_json'] if train else cfg['data_cfg']['valid_noisy_json'] 82 | shuffle = (cfg['env_setting']['num_gpus'] <= 1) if train else False 83 | pcs = cfg['training_cfg']['use_PCS400'] if train else False 84 | 85 | return VCTKDemandDataset( 86 | clean_json=clean_json, 87 | noisy_json=noisy_json, 88 | sampling_rate=cfg['stft_cfg']['sampling_rate'], 89 | segment_size=cfg['training_cfg']['segment_size'], 90 | n_fft=cfg['stft_cfg']['n_fft'], 91 | hop_size=cfg['stft_cfg']['hop_size'], 92 | win_size=cfg['stft_cfg']['win_size'], 93 | compress_factor=cfg['model_cfg']['compress_factor'], 94 | split=split, 95 | n_cache_reuse=0, 96 | shuffle=shuffle, 97 | device=device, 98 | pcs=pcs 99 | ) 100 | 101 | def create_dataloader(dataset, cfg, train=True): 102 | """Create dataloader based on dataset and configuration.""" 103 | if cfg['env_setting']['num_gpus'] > 1: 104 | sampler = DistributedSampler(dataset) 105 | sampler.set_epoch(cfg['training_cfg']['training_epochs']) 106 | batch_size = (cfg['training_cfg']['batch_size'] // cfg['env_setting']['num_gpus']) if train else 1 107 | else: 108 | sampler = None 109 | batch_size = cfg['training_cfg']['batch_size'] if train else 1 110 | num_workers = cfg['env_setting']['num_workers'] 111 | 112 | return DataLoader( 113 | dataset, 114 | num_workers=num_workers, 115 | shuffle=(sampler is None) and train, 116 | sampler=sampler, 117 | batch_size=batch_size, 118 | pin_memory=True, 119 | drop_last=True 120 | ) 121 | 122 | 123 | def train(rank, args, cfg): 124 | num_gpus = cfg['env_setting']['num_gpus'] 125 | n_fft, hop_size, win_size = cfg['stft_cfg']['n_fft'], cfg['stft_cfg']['hop_size'], cfg['stft_cfg']['win_size'] 126 | compress_factor = cfg['model_cfg']['compress_factor'] 127 | batch_size = cfg['training_cfg']['batch_size'] // cfg['env_setting']['num_gpus'] 128 | if num_gpus >= 1: 129 | initialize_process_group(cfg, rank) 130 | device = torch.device('cuda:{:d}'.format(rank)) 131 | else: 132 | raise RuntimeError("Mamba needs GPU acceleration") 133 | 134 | generator = MambaSEUNet(cfg).to(device) 135 | discriminator = MetricDiscriminator().to(device) 136 | 137 | if rank == 0: 138 | log_model_info(rank, generator, args.exp_path) 139 | 140 | state_dict_g, state_dict_do, steps, last_epoch = load_ckpts(args, device) 141 | if state_dict_g is not None: 142 | generator.load_state_dict(state_dict_g['generator'], strict=False) 143 | discriminator.load_state_dict(state_dict_do['discriminator'], strict=False) 144 | 145 | if num_gpus > 1 and torch.cuda.is_available(): 146 | generator = DistributedDataParallel(generator, device_ids=[rank]).to(device) 147 | discriminator = DistributedDataParallel(discriminator, device_ids=[rank]).to(device) 148 | 149 | # Create optimizer and schedulers 150 | optimizers = setup_optimizers((generator, discriminator), cfg) 151 | load_optimizer_states(optimizers, state_dict_do) 152 | optim_g, optim_d = optimizers 153 | scheduler_g, scheduler_d = setup_schedulers(optimizers, cfg, last_epoch) 154 | 155 | # Create trainset and train_loader 156 | trainset = 
create_dataset(cfg, train=True, split=True, device=device) 157 | train_loader = create_dataloader(trainset, cfg, train=True) 158 | 159 | # Create validset and validation_loader if rank is 0 160 | if rank == 0: 161 | validset = create_val_dataset(cfg, train=False, split=False, device=device) 162 | validation_loader = create_dataloader(validset, cfg, train=False) 163 | sw = SummaryWriter(os.path.join(args.exp_path, 'logs')) 164 | 165 | generator.train() 166 | discriminator.train() 167 | 168 | best_pesq, best_pesq_step = 0.0, 0 169 | for epoch in range(max(0, last_epoch), cfg['training_cfg']['training_epochs']): 170 | if rank == 0: 171 | start = time.time() 172 | print("Epoch: {}".format(epoch+1)) 173 | 174 | for i, batch in enumerate(train_loader): 175 | if rank == 0: 176 | start_b = time.time() 177 | clean_audio, clean_mag, clean_pha, clean_com, noisy_mag, noisy_pha = batch # [B, 1, F, T], F = nfft // 2+ 1, T = nframes 178 | clean_audio = torch.autograd.Variable(clean_audio.to(device, non_blocking=True)) 179 | clean_mag = torch.autograd.Variable(clean_mag.to(device, non_blocking=True)) 180 | clean_pha = torch.autograd.Variable(clean_pha.to(device, non_blocking=True)) 181 | clean_com = torch.autograd.Variable(clean_com.to(device, non_blocking=True)) 182 | noisy_mag = torch.autograd.Variable(noisy_mag.to(device, non_blocking=True)) 183 | noisy_pha = torch.autograd.Variable(noisy_pha.to(device, non_blocking=True)) 184 | one_labels = torch.ones(batch_size).to(device, non_blocking=True) 185 | 186 | mag_g, pha_g, com_g = generator(noisy_mag, noisy_pha) 187 | 188 | audio_g = mag_phase_istft(mag_g, pha_g, n_fft, hop_size, win_size, compress_factor) 189 | audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy()) 190 | batch_pesq_score = batch_pesq(audio_list_r, audio_list_g, cfg) 191 | 192 | # Discriminator 193 | # ------------------------------------------------------- # 194 | optim_d.zero_grad() 195 | metric_r = discriminator(clean_mag, clean_mag) 196 | metric_g = discriminator(clean_mag, mag_g.detach()) 197 | loss_disc_r = F.mse_loss(one_labels, metric_r.flatten()) 198 | 199 | if batch_pesq_score is not None: 200 | loss_disc_g = F.mse_loss(batch_pesq_score.to(device), metric_g.flatten()) 201 | else: 202 | loss_disc_g = 0 203 | 204 | loss_disc_all = loss_disc_r + loss_disc_g 205 | 206 | loss_disc_all.backward() 207 | optim_d.step() 208 | # ------------------------------------------------------- # 209 | 210 | # Generator 211 | # ------------------------------------------------------- # 212 | optim_g.zero_grad() 213 | 214 | # Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/train.py 215 | # L2 Magnitude Loss 216 | loss_mag = F.mse_loss(clean_mag, mag_g) 217 | # Anti-wrapping Phase Loss 218 | loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g, cfg) 219 | loss_pha = loss_ip + loss_gd + loss_iaf 220 | # L2 Complex Loss 221 | loss_com = F.mse_loss(clean_com, com_g) * 2 222 | # Time Loss 223 | loss_time = F.l1_loss(clean_audio, audio_g) 224 | # Metric Loss 225 | metric_g = discriminator(clean_mag, mag_g) 226 | loss_metric = F.mse_loss(metric_g.flatten(), one_labels) 227 | # Consistancy Loss 228 | _, _, rec_com = mag_phase_stft(audio_g, n_fft, hop_size, win_size, compress_factor, addeps=True) 229 | loss_con = F.mse_loss(com_g, rec_com) * 2 230 | 231 | loss_gen_all = ( 232 | loss_metric * cfg['training_cfg']['loss']['metric'] + 233 | loss_mag * cfg['training_cfg']['loss']['magnitude'] + 234 | loss_pha * cfg['training_cfg']['loss']['phase'] + 235 | 
loss_com * cfg['training_cfg']['loss']['complex'] + 236 | loss_time * cfg['training_cfg']['loss']['time'] + 237 | loss_con * cfg['training_cfg']['loss']['consistancy'] 238 | ) 239 | 240 | loss_gen_all.backward() 241 | optim_g.step() 242 | # ------------------------------------------------------- # 243 | 244 | if rank == 0: 245 | # STDOUT logging 246 | if steps % cfg['env_setting']['stdout_interval'] == 0: 247 | with torch.no_grad(): 248 | metric_error = F.mse_loss(metric_g.flatten(), one_labels).item() 249 | mag_error = F.mse_loss(clean_mag, mag_g).item() 250 | ip_error, gd_error, iaf_error = phase_losses(clean_pha, pha_g, cfg) 251 | pha_error = (ip_error + gd_error + iaf_error).item() 252 | com_error = F.mse_loss(clean_com, com_g).item() 253 | time_error = F.l1_loss(clean_audio, audio_g).item() 254 | con_error = F.mse_loss(com_g, rec_com).item() 255 | 256 | print( 257 | 'Steps : {:d}, Gen Loss: {:4.3f}, Disc Loss: {:4.3f}, Metric Loss: {:4.3f}, ' 258 | 'Mag Loss: {:4.3f}, Pha Loss: {:4.3f}, Com Loss: {:4.3f}, Time Loss: {:4.3f}, Cons Loss: {:4.3f}, s/b : {:4.3f}'.format( 259 | steps, loss_gen_all, loss_disc_all, metric_error, mag_error, pha_error, com_error, time_error, con_error, time.time() - start_b 260 | ) 261 | ) 262 | 263 | # Checkpointing 264 | if steps % cfg['env_setting']['checkpoint_interval'] == 0 and steps != 0: 265 | exp_name = f"{args.exp_path}/g_{steps:08d}.pth" 266 | save_checkpoint( 267 | exp_name, 268 | { 269 | 'generator': (generator.module if num_gpus > 1 else generator).state_dict() 270 | } 271 | ) 272 | exp_name = f"{args.exp_path}/do_{steps:08d}.pth" 273 | save_checkpoint( 274 | exp_name, 275 | { 276 | 'discriminator': (discriminator.module if num_gpus > 1 else discriminator).state_dict(), 277 | 'optim_g': optim_g.state_dict(), 278 | 'optim_d': optim_d.state_dict(), 279 | 'steps': steps, 280 | 'epoch': epoch 281 | } 282 | ) 283 | 284 | # Tensorboard summary logging 285 | if steps % cfg['env_setting']['summary_interval'] == 0: 286 | sw.add_scalar("Training/Generator Loss", loss_gen_all, steps) 287 | sw.add_scalar("Training/Discriminator Loss", loss_disc_all, steps) 288 | sw.add_scalar("Training/Metric Loss", metric_error, steps) 289 | sw.add_scalar("Training/Magnitude Loss", mag_error, steps) 290 | sw.add_scalar("Training/Phase Loss", pha_error, steps) 291 | sw.add_scalar("Training/Complex Loss", com_error, steps) 292 | sw.add_scalar("Training/Time Loss", time_error, steps) 293 | sw.add_scalar("Training/Consistancy Loss", con_error, steps) 294 | 295 | # Raise an error if NaN values appear during training 296 | if torch.isnan(loss_gen_all).any(): 297 | raise ValueError("NaN values found in loss_gen_all") 298 | 299 | # Validation 300 | if steps % cfg['env_setting']['validation_interval'] == 0 and steps != 0: 301 | generator.eval() 302 | torch.cuda.empty_cache() 303 | audios_r, audios_g = [], [] 304 | val_mag_err_tot = 0 305 | val_pha_err_tot = 0 306 | val_com_err_tot = 0 307 | with torch.no_grad(): 308 | for j, batch in enumerate(validation_loader): 309 | clean_audio, clean_mag, clean_pha, clean_com, noisy_audio = batch # [B, 1, F, T], F = n_fft // 2 + 1, T = nframes 310 | clean_audio = torch.autograd.Variable(clean_audio.to(device, non_blocking=True)) 311 | clean_mag = torch.autograd.Variable(clean_mag.to(device, non_blocking=True)) 312 | clean_pha = torch.autograd.Variable(clean_pha.to(device, non_blocking=True)) 313 | clean_com = torch.autograd.Variable(clean_com.to(device, non_blocking=True)) 314 | noisy_audio = torch.autograd.Variable(noisy_audio.to(device,
non_blocking=True)) 315 | 316 | orig_size = noisy_audio.size(1) 317 | 318 | # Check whether zero-padding is needed 319 | if noisy_audio.size(1) >= cfg['training_cfg']['segment_size']: 320 | num_segments = noisy_audio.size(1) // cfg['training_cfg']['segment_size'] 321 | last_segment_size = noisy_audio.size(1) % cfg['training_cfg']['segment_size'] 322 | if last_segment_size > 0: 323 | last_segment = noisy_audio[:, -cfg['training_cfg']['segment_size']:] 324 | noisy_audio = noisy_audio[:, :-last_segment_size] 325 | segments = torch.split(noisy_audio, cfg['training_cfg']['segment_size'], dim=1) 326 | segments = list(segments) 327 | segments.append(last_segment) 328 | reshapelast = 1 329 | else: 330 | segments = torch.split(noisy_audio, cfg['training_cfg']['segment_size'], dim=1) 331 | reshapelast = 0 332 | 333 | else: 334 | # If the utterance is shorter than one segment_size, zero-pad it directly 335 | padded_zeros = torch.zeros(1, cfg['training_cfg']['segment_size'] - noisy_audio.size(1)).to(device) 336 | noisy_audio = torch.cat((noisy_audio, padded_zeros), dim=1) 337 | segments = [noisy_audio] 338 | reshapelast = 0 339 | 340 | # Process each audio segment and concatenate the results 341 | processed_segments = [] 342 | audio_g = [] 343 | 344 | for i, segment in enumerate(segments): 345 | 346 | noisy_amp, noisy_pha, noisy_com = mag_phase_stft(segment, n_fft, hop_size, win_size, 347 | compress_factor) 348 | amp_g, pha_g, com_g = generator(noisy_amp.to(device, non_blocking=True), 349 | noisy_pha.to(device, non_blocking=True)) 350 | audio_g = mag_phase_istft(amp_g, pha_g, n_fft, hop_size, win_size, compress_factor) 351 | 352 | audio_g = audio_g.squeeze() 353 | if reshapelast == 1 and i == len(segments) - 2: 354 | audio_g = audio_g[:-(cfg['training_cfg']['segment_size'] - last_segment_size)] 355 | # print(orig_size) 356 | 357 | processed_segments.append(audio_g) 358 | 359 | # Concatenate all processed segments into one complete utterance 360 | 361 | processed_audio = torch.cat(processed_segments, dim=-1) 362 | 363 | # Trim the tail so only the original noisy_wav length remains 364 | audio_g = processed_audio[:orig_size] 365 | 366 | mag_g, pha_g, com_g = mag_phase_stft(audio_g, n_fft, hop_size, win_size, 367 | compress_factor) 368 | 369 | mag_g = torch.autograd.Variable(mag_g.to(device, non_blocking=True)) 370 | pha_g = torch.autograd.Variable(pha_g.to(device, non_blocking=True)) 371 | 372 | com_g = torch.autograd.Variable(com_g.to(device, non_blocking=True)) 373 | 374 | mag_g = mag_g.squeeze() 375 | pha_g = torch.unsqueeze(pha_g, dim=0) 376 | 377 | # com_g = com_g.squeeze() 378 | clean_mag = clean_mag.squeeze() 379 | # clean_pha = clean_pha.squeeze() 380 | 381 | clean_com = clean_com.squeeze() 382 | audios_r += torch.split(clean_audio, 1, dim=0) # [1, T] * B 383 | # print(clean_audio.size()) 384 | # # print(len(audios_r)) 385 | audio_g = torch.unsqueeze(audio_g, dim=0) 386 | audios_g += torch.split(audio_g, 1, dim=0) 387 | 388 | 389 | val_mag_err_tot += F.mse_loss(clean_mag, mag_g).item() 390 | val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g, cfg) 391 | val_pha_err_tot += (val_ip_err + val_gd_err + val_iaf_err).item() 392 | val_com_err_tot += F.mse_loss(clean_com, com_g).item() 393 | 394 | val_mag_err = val_mag_err_tot / (j+1) 395 | val_pha_err = val_pha_err_tot / (j+1) 396 | val_com_err = val_com_err_tot / (j+1) 397 | val_pesq_score = pesq_score(audios_r, audios_g, cfg).item() 398 | print('Steps : {:d}, PESQ Score: {:4.3f}, s/b : {:4.3f}'.
399 | format(steps, val_pesq_score, time.time() - start_b)) 400 | sw.add_scalar("Validation/PESQ Score", val_pesq_score, steps) 401 | sw.add_scalar("Validation/Magnitude Loss", val_mag_err, steps) 402 | sw.add_scalar("Validation/Phase Loss", val_pha_err, steps) 403 | sw.add_scalar("Validation/Complex Loss", val_com_err, steps) 404 | 405 | generator.train() 406 | 407 | # Track the best validation PESQ score and print it to the terminal 408 | if val_pesq_score >= best_pesq: 409 | best_pesq = val_pesq_score 410 | best_pesq_step = steps 411 | print(f"valid: PESQ {val_pesq_score}, Mag_loss {val_mag_err}, Phase_loss {val_pha_err}. Best_PESQ: {best_pesq} at step {best_pesq_step}") 412 | 413 | steps += 1 414 | 415 | scheduler_g.step() 416 | scheduler_d.step() 417 | 418 | if rank == 0: 419 | print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start))) 420 | 421 | def main(): 422 | parser = argparse.ArgumentParser() 423 | parser.add_argument('--exp_folder', default='exp') 424 | parser.add_argument('--exp_name', default='MambaSEUNet_emb_32') 425 | parser.add_argument('--config', default='recipes/Mamba-SEUNet/Mamba-SEUNet.yaml') 426 | args = parser.parse_args() 427 | 428 | cfg = load_config(args.config) 429 | seed = cfg['env_setting']['seed'] 430 | num_gpus = cfg['env_setting']['num_gpus'] 431 | available_gpus = torch.cuda.device_count() 432 | 433 | if num_gpus > available_gpus: 434 | warnings.warn( 435 | f"The number of available GPUs ({available_gpus}) is less than the number requested in the .yaml config ({num_gpus}). Automatically resetting num_gpus to {available_gpus}.", 436 | UserWarning 437 | ) 438 | cfg['env_setting']['num_gpus'] = available_gpus 439 | num_gpus = available_gpus 440 | time.sleep(5) 441 | 442 | 443 | initialize_seed(seed) 444 | args.exp_path = os.path.join(args.exp_folder, args.exp_name) 445 | build_env(args.config, 'config.yaml', args.exp_path) 446 | 447 | if torch.cuda.is_available(): 448 | num_available_gpus = torch.cuda.device_count() 449 | print(f"Number of GPUs available: {num_available_gpus}") 450 | print_gpu_info(num_available_gpus, cfg) 451 | else: 452 | print("CUDA is not available.") 453 | 454 | if num_gpus > 1: 455 | mp.spawn(train, nprocs=num_gpus, args=(args, cfg)) 456 | else: 457 | train(0, args, cfg) 458 | 459 | if __name__ == '__main__': 460 | main() 461 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import torch 3 | import os 4 | import shutil 5 | import glob 6 | from torch.distributed import init_process_group 7 | 8 | def load_config(config_path): 9 | """Load configuration from a YAML file.""" 10 | with open(config_path, 'r') as file: 11 | return yaml.safe_load(file) 12 | 13 | def initialize_seed(seed): 14 | """Initialize the random seed for both CPU and GPU.""" 15 | torch.manual_seed(seed) 16 | if torch.cuda.is_available(): 17 | torch.cuda.manual_seed(seed) 18 | 19 | def print_gpu_info(num_gpus, cfg): 20 | """Print information about available GPUs and batch size per GPU.""" 21 | for i in range(num_gpus): 22 | gpu_name = torch.cuda.get_device_name(i) 23 | print(f"GPU {i}: {gpu_name}") 24 | print('Batch size per GPU:', int(cfg['training_cfg']['batch_size'] / num_gpus)) 25 | 26 | def initialize_process_group(cfg, rank): 27 | """Initialize the process group for distributed training.""" 28 | init_process_group( 29 | backend=cfg['env_setting']['dist_cfg']['dist_backend'], 30 |
init_method=cfg['env_setting']['dist_cfg']['dist_url'], 31 | world_size=cfg['env_setting']['dist_cfg']['world_size'] * cfg['env_setting']['num_gpus'], 32 | rank=rank 33 | ) 34 | 35 | def log_model_info(rank, model, exp_path): 36 | """Log model information and create necessary directories.""" 37 | print(model) 38 | num_params = sum(p.numel() for p in model.parameters()) 39 | print("Generator Parameters :", num_params) 40 | os.makedirs(exp_path, exist_ok=True) 41 | os.makedirs(os.path.join(exp_path, 'logs'), exist_ok=True) 42 | print("checkpoints directory :", exp_path) 43 | 44 | def load_ckpts(args, device): 45 | """Load checkpoints if available.""" 46 | if os.path.isdir(args.exp_path): 47 | cp_g = scan_checkpoint(args.exp_path, 'g_') 48 | cp_do = scan_checkpoint(args.exp_path, 'do_') 49 | if cp_g is None or cp_do is None: 50 | return None, None, 0, -1 51 | state_dict_g = load_checkpoint(cp_g, device) 52 | state_dict_do = load_checkpoint(cp_do, device) 53 | return state_dict_g, state_dict_do, state_dict_do['steps'] + 1, state_dict_do['epoch'] 54 | return None, None, 0, -1 55 | 56 | def load_checkpoint(filepath, device): 57 | assert os.path.isfile(filepath) 58 | print("Loading '{}'".format(filepath)) 59 | checkpoint_dict = torch.load(filepath, map_location=device) 60 | print("Complete.") 61 | return checkpoint_dict 62 | 63 | 64 | def save_checkpoint(filepath, obj): 65 | print("Saving checkpoint to {}".format(filepath)) 66 | torch.save(obj, filepath) 67 | print("Complete.") 68 | 69 | 70 | def scan_checkpoint(cp_dir, prefix): 71 | pattern = os.path.join(cp_dir, prefix + '????????' + '.pth') 72 | cp_list = glob.glob(pattern) 73 | if len(cp_list) == 0: 74 | return None 75 | return sorted(cp_list)[-1] 76 | 77 | def build_env(config, config_name, exp_path): 78 | os.makedirs(exp_path, exist_ok=True) 79 | t_path = os.path.join(exp_path, config_name) 80 | if config != t_path: 81 | shutil.copyfile(config, t_path) 82 | 83 | def load_optimizer_states(optimizers, state_dict_do): 84 | """Load optimizer states from checkpoint.""" 85 | if state_dict_do is not None: 86 | optim_g, optim_d = optimizers 87 | optim_g.load_state_dict(state_dict_do['optim_g']) 88 | optim_d.load_state_dict(state_dict_do['optim_d']) 89 | --------------------------------------------------------------------------------
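The checkpoint helpers in utils/util.py are what train.py uses to resume interrupted runs: scan_checkpoint picks the newest g_XXXXXXXX.pth / do_XXXXXXXX.pth pair, load_ckpts returns the two state dicts together with the step and epoch to resume from, and load_optimizer_states restores both optimizers. The snippet below is a minimal sketch of that round-trip outside of train.py; the Linear modules and AdamW optimizers are hypothetical placeholders standing in for the real generator, discriminator, and their optimizers, and exp/MambaSEUNet_emb_32 is simply the default experiment path.
```
# Minimal checkpoint round-trip sketch (hypothetical placeholder modules; the real
# ones come from models/generator.py and models/discriminator.py).
import argparse
import os
import torch
from utils.util import load_ckpts, load_optimizer_states, save_checkpoint

args = argparse.Namespace(exp_path=os.path.join('exp', 'MambaSEUNet_emb_32'))
os.makedirs(args.exp_path, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = torch.nn.Linear(4, 4).to(device)      # placeholder for the real generator
discriminator = torch.nn.Linear(4, 4).to(device)  # placeholder for the real discriminator
optim_g = torch.optim.AdamW(generator.parameters())
optim_d = torch.optim.AdamW(discriminator.parameters())

# Resume from the newest g_XXXXXXXX.pth / do_XXXXXXXX.pth pair, if one exists.
state_dict_g, state_dict_do, steps, last_epoch = load_ckpts(args, device)
if state_dict_g is not None:
    generator.load_state_dict(state_dict_g['generator'])
    discriminator.load_state_dict(state_dict_do['discriminator'])
load_optimizer_states((optim_g, optim_d), state_dict_do)  # no-op when nothing was found

# Save a new pair using the zero-padded step number that scan_checkpoint expects.
save_checkpoint(os.path.join(args.exp_path, f"g_{steps:08d}.pth"),
                {'generator': generator.state_dict()})
save_checkpoint(os.path.join(args.exp_path, f"do_{steps:08d}.pth"),
                {'discriminator': discriminator.state_dict(),
                 'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(),
                 'steps': steps, 'epoch': last_epoch})
```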