├── README.md
└── metric.py

/README.md:
--------------------------------------------------------------------------------
# Speech-Enhancement-Metrics-SNR-SDRi-SISDRi

Objective speech-enhancement metrics implemented in `metric.py`:

- SNR (signal-to-noise ratio)
- SDRi (source-to-distortion ratio improvement)
- SISDRi (scale-invariant SDR improvement)
- SNRseg (segmental SNR)
--------------------------------------------------------------------------------
/metric.py:
--------------------------------------------------------------------------------
import numpy as np


class AudioMetrics:
    """Bundle the speech-enhancement metrics for one (reference, estimation, mix) triple."""

    def __init__(self, reference, estimation, mix, sr):
        self.SISDR = sisdr(reference, estimation)
        self.SNR = snr(reference, estimation)
        self.SDRi = cal_SDRi(reference, estimation, mix)
        self.SISDRi = cal_SISNRi(reference, estimation, mix)
        self.SNRseg = SNRseg(reference, estimation, sr)


def sisdr(reference, estimation):
    """
    Scale-Invariant Signal-to-Distortion Ratio (SI-SDR).
    Args:
        reference: numpy.ndarray, [..., T]
        estimation: numpy.ndarray, [..., T]
    Returns:
        mean SI-SDR in dB
    """
    estimation, reference = np.broadcast_arrays(estimation, reference)
    reference_energy = np.sum(reference ** 2, axis=-1, keepdims=True)

    # Rescale the reference so that it is the orthogonal projection of the estimate onto it.
    optimal_scaling = np.sum(reference * estimation, axis=-1, keepdims=True) / reference_energy
    projection = optimal_scaling * reference
    noise = estimation - projection

    ratio = np.sum(projection ** 2, axis=-1) / np.sum(noise ** 2, axis=-1)
    return np.mean(10 * np.log10(ratio))


def snr(reference, estimation):
    """Plain signal-to-noise ratio in dB, averaged over the leading axes."""
    numerator = np.sum(reference ** 2, axis=-1, keepdims=True)
    denominator = np.sum((estimation - reference) ** 2, axis=-1, keepdims=True)
    return np.mean(10 * np.log10(numerator / denominator))


def cal_SDR(reference, estimation, eps=1e-8):
    """
    Calculate the (projection-based) Source-to-Distortion Ratio per item.
    Args:
        reference: numpy.ndarray, [B, T]
        estimation: numpy.ndarray, [B, T]
    Returns:
        SDR in dB per item, numpy.ndarray, [B]
    """
    origin_power = np.sum(reference ** 2, 1, keepdims=True) + eps             # [B, 1]
    scale = np.sum(reference * estimation, 1, keepdims=True) / origin_power   # [B, 1]

    est_true = scale * reference      # target component, [B, T]
    est_res = estimation - est_true   # residual (distortion + noise), [B, T]

    true_power = np.sum(est_true ** 2, 1)
    res_power = np.sum(est_res ** 2, 1)

    return 10 * np.log10(true_power + eps) - 10 * np.log10(res_power + eps)


def cal_SDRi(src_ref, src_est, mix):
    """Calculate Source-to-Distortion Ratio improvement (SDRi).
    NOTE: mir_eval's bss_eval_sources is very slow, so the projection-based
    cal_SDR above is used instead.
    Args:
        src_ref: numpy.ndarray, [N, T]
        src_est: numpy.ndarray, [N, T]
        mix: numpy.ndarray, [N, T]
        N is normally the number of separated sources; here the batch
        dimension is used in its place.
    Returns:
        average_SDRi
    """
    counter = mix.shape[0]
    sdr = cal_SDR(src_ref, src_est)   # SDR of the enhanced estimate
    sdr0 = cal_SDR(src_ref, mix)      # SDR of the unprocessed mixture (baseline)
    avg_SDRi = np.sum(sdr - sdr0) / counter
    return avg_SDRi


def cal_SISNRi(src_ref, src_est, mix):
    """Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi).
    Args:
        src_ref: numpy.ndarray, [N, T]
        src_est: numpy.ndarray, [N, T]
        mix: numpy.ndarray, [N, T]
    Returns:
        average_SISNRi
    """
    sisnr_est = sisdr(src_ref, src_est)   # SI-SNR of the enhanced estimate
    sisnr_mix = sisdr(src_ref, mix)       # SI-SNR of the unprocessed mixture (baseline)
    avg_SISNRi = sisnr_est - sisnr_mix
    return avg_SISNRi


# Reference: https://github.com/schmiph2/pysepm

def SNRseg(clean_speech, processed_speech, fs, frameLen=0.03, overlap=0.75):
    """Segmental SNR in dB, averaged over frames and clipped to [-15, 35] dB per frame."""
    eps = np.finfo(np.float64).eps

    winlength = round(frameLen * fs)                          # window length in samples
    skiprate = int(np.floor((1 - overlap) * frameLen * fs))   # window skip in samples
    MIN_SNR = -15   # minimum SNR in dB
    MAX_SNR = 35    # maximum SNR in dB

    hannWin = 0.5 * (1 - np.cos(2 * np.pi * np.arange(1, winlength + 1) / (winlength + 1)))
    clean_speech_framed = extract_overlapped_windows(clean_speech, winlength, winlength - skiprate, hannWin)
    processed_speech_framed = extract_overlapped_windows(processed_speech, winlength, winlength - skiprate, hannWin)

    signal_energy = np.power(clean_speech_framed, 2).sum(-1)
    noise_energy = np.power(clean_speech_framed - processed_speech_framed, 2).sum(-1)

    segmental_snr = 10 * np.log10(signal_energy / (noise_energy + eps) + eps)
    segmental_snr[segmental_snr < MIN_SNR] = MIN_SNR
    segmental_snr[segmental_snr > MAX_SNR] = MAX_SNR
    segmental_snr = segmental_snr[..., :-1]   # drop the last (possibly partial) frame
    return np.mean(segmental_snr)


def extract_overlapped_windows(x, nperseg, noverlap, window=None):
    """Slice x into overlapping frames of length nperseg with noverlap samples of overlap."""
    step = nperseg - noverlap
    shape = x.shape[:-1] + ((x.shape[-1] - noverlap) // step, nperseg)
    strides = x.strides[:-1] + (step * x.strides[-1], x.strides[-1])
    result = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
    if window is not None:
        result = window * result
    return result


if __name__ == '__main__':
    # Smoke test on random data.
    a = np.random.randn(3, 4)
    b = np.random.randn(3, 4)
    mix = np.random.randn(3, 4)
    print(cal_SDRi(a, b, mix))
--------------------------------------------------------------------------------
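As a quick illustration of how the `AudioMetrics` class above might be called, here is a minimal usage sketch. The synthetic sine-plus-noise signals, the 16 kHz sample rate, and the assumption that `metric.py` is importable as `metric` from the working directory are illustrative only and not part of the repository.

```python
# Hypothetical usage sketch (not part of the repo): assumes metric.py is on the
# path and that signals are mono, 16 kHz, batched as [B, T].
import numpy as np
from metric import AudioMetrics

sr = 16000
t = np.arange(sr) / sr                                    # one second of audio
clean = np.stack([np.sin(2 * np.pi * 440.0 * t),
                  np.sin(2 * np.pi * 220.0 * t)])         # reference signals, [B, T]
mix = clean + 0.30 * np.random.randn(*clean.shape)        # noisy input to an enhancer
enhanced = clean + 0.05 * np.random.randn(*clean.shape)   # stand-in for a model output

m = AudioMetrics(reference=clean, estimation=enhanced, mix=mix, sr=sr)
print(f"SI-SDR : {m.SISDR:6.2f} dB")
print(f"SNR    : {m.SNR:6.2f} dB")
print(f"SDRi   : {m.SDRi:6.2f} dB")
print(f"SI-SDRi: {m.SISDRi:6.2f} dB")
print(f"SNRseg : {m.SNRseg:6.2f} dB")
```

Since `enhanced` is closer to `clean` than `mix` is, the improvement metrics (SDRi, SI-SDRi) should come out positive in this sketch.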