├── .gitignore ├── README.md ├── acoustic_scaler.py ├── apc_snr.py ├── constant.py ├── conv_stft.py ├── pmsqe ├── __init__.py ├── bark_matrix_16k.mat └── pmsqe.py └── results ├── different_loss_compare ├── README.md ├── analysis.py └── csvs │ ├── dcrn │ ├── APC-SNR.csv │ ├── MSE.csv │ ├── PMSQE.csv │ ├── PMSQE1+APC-SNR.csv │ ├── PMSQE1.csv │ ├── SI-SNR.csv │ └── STOI.csv │ └── model_in_paper │ ├── APC-SNR.csv │ ├── MSE.csv │ ├── PMSQE.csv │ ├── PMSQE1+APC-SNR.csv │ ├── PMSQE1.csv │ ├── SI-SNR.csv │ └── STOI.csv └── eps_and_theta ├── README.md ├── analysis.py ├── csvs ├── eps │ ├── eps_0.100_theta_0.010.csv │ ├── eps_0.500_theta_0.010.csv │ ├── eps_1.000_theta_0.010.csv │ ├── eps_1.500_theta_0.010.csv │ └── eps_2.000_theta_0.010.csv └── theta │ ├── eps_1.000_theta_0.0001.csv │ ├── eps_1.000_theta_0.0005.csv │ ├── eps_1.000_theta_0.001.csv │ ├── eps_1.000_theta_0.005.csv │ ├── eps_1.000_theta_0.010.csv │ ├── eps_1.000_theta_0.050.csv │ ├── eps_1.000_theta_0.100.csv │ └── eps_1.000_theta_0.500.csv └── pics ├── eps.png └── theta.png /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | *.out 7 | # C extensions 8 | *.so 9 | 10 | 11 | *.svg 12 | *.npy 13 | *.jpg 14 | *.wav 15 | *.mp3 16 | *.ckpt 17 | *.json 18 | logs/ 19 | lightning_logs/ 20 | # pickle 21 | *.log 22 | .csv 23 | *.data 24 | # torch 25 | *.pkl 26 | *.pth 27 | *.onnx 28 | version_*/ 29 | *.yaml 30 | *.0 31 | # Distribution / packaging 32 | .Python 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | pip-wheel-metadata/ 46 | share/python-wheels/ 47 | *.egg-info/ 48 | .installed.cfg 49 | *.egg 50 | MANIFEST 51 | 52 | # PyInstaller 53 | # Usually these files are written by a python script from a template 54 | # before PyInstaller builds the exe, so as to inject date/other 
infos into it. 55 | *.manifest 56 | *.spec 57 | 58 | # Installer logs 59 | pip-log.txt 60 | pip-delete-this-directory.txt 61 | 62 | # Unit test / coverage reports 63 | htmlcov/ 64 | .tox/ 65 | .nox/ 66 | .coverage 67 | .coverage.* 68 | .cache 69 | nosetests.xml 70 | coverage.xml 71 | *.cover 72 | *.py,cover 73 | .hypothesis/ 74 | .pytest_cache/ 75 | 76 | # Translations 77 | *.mo 78 | *.pot 79 | 80 | # Django stuff: 81 | *.log 82 | local_settings.py 83 | db.sqlite3 84 | db.sqlite3-journal 85 | 86 | # Flask stuff: 87 | instance/ 88 | .webassets-cache 89 | 90 | # Scrapy stuff: 91 | .scrapy 92 | 93 | # Sphinx documentation 94 | docs/_build/ 95 | 96 | # PyBuilder 97 | target/ 98 | 99 | # Jupyter Notebook 100 | .ipynb_checkpoints 101 | 102 | # IPython 103 | profile_default/ 104 | ipython_config.py 105 | 106 | # pyenv 107 | .python-version 108 | 109 | # pipenv 110 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 111 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 112 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 113 | # install all needed dependencies. 
114 | #Pipfile.lock 115 | 116 | # celery beat schedule file 117 | celerybeat-schedule 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # MODIFIED SECTION 150 | #PyCharm project 151 | .idea/ 152 | .ipynb_checkpoints/ 153 | 154 | [user] 155 | name=ScorpioMiku 156 | email=1056992492@qq.com 157 | [credential] 158 | helper = store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | The official repository for "A Deep Learning Loss Function based on Auditory Power Compression for Speech Enhancement". 2 | 3 | > This is a loss function for wide-band signals. Although the manuscript was rejected by ICASSP, the method is effective (the main problems were in the writing). We open-source it to ensure the reproducibility of other papers that rely on it (HGCN).
4 | 5 | ```text 6 | @misc{wang2021deep, 7 | title={A Deep Learning Loss Function based on Auditory Power Compression for Speech Enhancement}, 8 | author={Tianrui Wang and Weibin Zhu}, 9 | year={2021}, 10 | eprint={2108.11877}, 11 | archivePrefix={arXiv}, 12 | primaryClass={eess.AS} 13 | } 14 | ``` 15 | 16 | -------------------------------------------------------------------------------- /acoustic_scaler.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author:WangTianRui 3 | # Date :2021/3/30 9:47 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | import constant as constant 8 | 9 | 10 | class AcousticScaler(nn.Module): 11 | def __init__(self, theta=0.01, eps=1.0, mag_bins=256, no_grad=True): 12 | self.no_grad = no_grad 13 | self.eps = eps 14 | super(AcousticScaler, self).__init__() 15 | self.register_buffer( 16 | "zwicker_power", 17 | torch.tensor(constant.get_zwicker_power_to_nftt(mag_bins), dtype=torch.float).unsqueeze(0).unsqueeze(0) 18 | ) 19 | self.theta = theta 20 | self.mag_bins = mag_bins 21 | 22 | def forward(self, est, clean=None): 23 | """` 24 | :param est: B,2*F,T 25 | :param clean:B,2*F,T 26 | :return: B,2*F,T 27 | """ 28 | need_pad = 0 29 | if self.no_grad: 30 | with torch.no_grad(): 31 | est = est.permute(0, 2, 1) # B,T,2*F 32 | clean = clean.permute(0, 2, 1) 33 | est_real, est_imag = torch.chunk(est, 2, dim=-1) 34 | clean_real, clean_imag = torch.chunk(clean, 2, dim=-1) 35 | if est_real.size(-1) != self.mag_bins: 36 | est_real, est_imag = est_real[..., :self.mag_bins], est_imag[..., :self.mag_bins] 37 | clean_real, clean_imag = clean_real[..., :self.mag_bins], clean_imag[..., :self.mag_bins] 38 | need_pad = est_real.size(-1) - self.mag_bins 39 | est_power = (est_real ** 2 + est_imag ** 2) # B,T,F 40 | clean_power = (clean_real ** 2 + clean_imag ** 2) # B,T,F 41 | est_scales = ((est_power + self.eps) ** ((self.zwicker_power - 1) * 0.5)) 42 | clean_scales = ((clean_power 
+ self.eps) ** ((self.zwicker_power - 1) * 0.5)) 43 | return self.double_size(est_scales, need_pad).clamp_(self.theta, 1), \ 44 | self.double_size(clean_scales, need_pad).clamp_(self.theta, 1) 45 | else: 46 | est = est.permute(0, 2, 1) # B,T,2*F 47 | clean = clean.permute(0, 2, 1) 48 | est_real, est_imag = torch.chunk(est, 2, dim=-1) 49 | clean_real, clean_imag = torch.chunk(clean, 2, dim=-1) 50 | if est_real.size(-1) != self.mag_bins: 51 | est_real, est_imag = est_real[..., :self.mag_bins], est_imag[..., :self.mag_bins] 52 | clean_real, clean_imag = clean_real[..., :self.mag_bins], clean_imag[..., :self.mag_bins] 53 | need_pad = est_real.size(-1) - self.mag_bins 54 | est_power = (est_real ** 2 + est_imag ** 2) # B,T,F 55 | clean_power = (clean_real ** 2 + clean_imag ** 2) # B,T,F 56 | est_scales = ((est_power + self.eps) ** ((self.zwicker_power - 1) * 0.5)) 57 | clean_scales = ((clean_power + self.eps) ** ((self.zwicker_power - 1) * 0.5)) 58 | return self.double_size(est_scales, need_pad).clamp_(self.theta, 1), \ 59 | self.double_size(clean_scales, need_pad).clamp_(self.theta, 1) 60 | 61 | 62 | def double_size(inp, need_pad): 63 | """ 64 | :param inp:B,T,F 65 | :return: B,2*F,T 66 | """ 67 | if need_pad != 0: 68 | return torch.nn.functional.pad( 69 | inp, [0, need_pad, 0, 0], value=1e-8 70 | ).repeat(1, 1, 2).permute(0, 2, 1) 71 | else: 72 | return inp.repeat(1, 1, 2).permute(0, 2, 1) 73 | -------------------------------------------------------------------------------- /apc_snr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author:WangTianRui 3 | # Date :2021/4/4 20:29 4 | import acoustic_scaler as acoustic_scaler 5 | import torch 6 | import pmsqe.pmsqe as pmsqe 7 | import conv_stft as conv_stft 8 | 9 | 10 | def power(x): 11 | """ 12 | :param x:B,2*F,T 13 | :return: 14 | """ 15 | return torch.stack(torch.chunk(x, 2, dim=-2), dim=-1).pow(2).sum(dim=-1) 16 | 17 | 18 | def stft_snr(est, clean, eps=1e-8): 
19 | s1 = est.reshape(est.size(0), -1) 20 | s2 = clean.reshape(clean.size(0), -1) 21 | s1_s2_norm = torch.sum(s1 * s2, -1, keepdim=True) 22 | s2_s2_norm = torch.sum(s2 * s2, -1, keepdim=True) 23 | s_target = s1_s2_norm / (s2_s2_norm + eps) * s2 24 | e_nosie = s1 - s_target 25 | target_norm = torch.sum(s_target * s_target, -1, keepdim=True) 26 | noise_norm = torch.sum(e_nosie * e_nosie, -1, keepdim=True) 27 | snr = 10 * torch.log10(target_norm / (noise_norm + eps) + eps) 28 | return -(torch.mean(snr)) 29 | 30 | 31 | class APC_SNR_multi_filter(torch.nn.Module): 32 | # APC_SNR_multi_filter is better than APC_SNR with more GPU memory. 33 | # criterion = APC_SNR_multi_filter(model_hop=128, model_winlen=512, mag_bins=256, theta=0.01, hops=[8, 16, 32, 64]) 34 | def __init__(self, model_hop, model_winlen, theta, mag_bins, hops=()): 35 | super(APC_SNR_multi_filter, self).__init__() 36 | self.multi_hop = hops 37 | self.scaler = acoustic_scaler.AcousticScaler(theta=theta, mag_bins=mag_bins, no_grad=True) 38 | self.eps = 1e-8 39 | self.pmsqe = pmsqe.SingleSrcPMSQE(sample_rate=16000) 40 | self.model_stft = conv_stft.ConvSTFT(model_winlen, model_hop, mag_bins * 2, "hanning", 'complex') 41 | self.model_istft = conv_stft.ConviSTFT(model_winlen, model_hop, mag_bins * 2, "hanning", 'complex') 42 | if model_hop in self.multi_hop: 43 | self.multi_hop.remove(model_hop) 44 | self.stfts = torch.nn.ModuleList() 45 | for hop in self.multi_hop: 46 | self.stfts.append(conv_stft.ConvSTFT(model_winlen, hop, mag_bins * 2, "hanning", 'complex')) 47 | 48 | def forward(self, est, clean): 49 | """ 50 | :param est: B,2*F,T 51 | :param clean: B,T 52 | :return: 53 | """ 54 | est_time_domain = self.model_istft(est).squeeze(1) # B,T 55 | clean_stft = self.model_stft(clean) # B,2*F,T 56 | 57 | pmsqe_score = self.pmsqe_score(est, clean_stft) 58 | 59 | est_scales, clean_scales = self.scaler(est, clean_stft) 60 | est = est * est_scales 61 | clean_stft = clean_stft * clean_scales 62 | stft_scaled_snr = 
stft_snr(est, clean_stft) 63 | 64 | for filter in self.stfts: 65 | clean_stft = filter(clean) 66 | est_stft = filter(est_time_domain) 67 | 68 | pmsqe_score += self.pmsqe_score(est_stft, clean_stft) 69 | est_scales, clean_scales = self.scaler(est_stft, clean_stft) 70 | stft_scaled_snr += stft_snr(est_stft * est_scales, clean_stft * clean_scales) 71 | 72 | return stft_scaled_snr / (len(self.stfts) + 1) + pmsqe_score / (len(self.stfts) + 1) 73 | 74 | def pmsqe_score(self, est_stft, clean_stft): 75 | mag_est = power(est_stft) 76 | mag_clean = power(clean_stft) 77 | pmsqe_score = self.pmsqe(mag_est, mag_clean) 78 | return pmsqe_score.mean() 79 | 80 | 81 | class APC_SNR(torch.nn.Module): 82 | def __init__(self, model_hop, model_winlen, theta, mag_bins): 83 | super(APC_SNR, self).__init__() 84 | self.scaler = acoustic_scaler.AcousticScaler(theta=theta, mag_bins=mag_bins, no_grad=True) 85 | self.eps = 1e-8 86 | self.pmsqe = pmsqe.SingleSrcPMSQE(sample_rate=16000) 87 | self.model_stft = conv_stft.ConvSTFT(model_winlen, model_hop, mag_bins * 2, "hanning", 'complex') 88 | 89 | def forward(self, est, clean): 90 | """ 91 | :param est: B,2*F,T 92 | :param clean: B,T 93 | :return: 94 | """ 95 | clean_stft = self.model_stft(clean) # B,2*F,T 96 | pmsqe_score = self.pmsqe_score(est, clean_stft) 97 | est_scales, clean_scales = self.scaler(est, clean_stft) 98 | est = est * est_scales 99 | clean_stft = clean_stft * clean_scales 100 | stft_scaled_snr = stft_snr(est, clean_stft) 101 | return stft_scaled_snr + pmsqe_score 102 | 103 | def pmsqe_score(self, est_stft, clean_stft): 104 | mag_est = power(est_stft) 105 | mag_clean = power(clean_stft) 106 | pmsqe_score = self.pmsqe(mag_est, mag_clean) 107 | return pmsqe_score.mean() 108 | -------------------------------------------------------------------------------- /constant.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author:WangTianRui 3 | # Date :2021/1/6 10:59 4 | import 
numpy as np 5 | 6 | np.set_printoptions(threshold=1000000) # 全部输出 7 | # bark域的频域跨度 8 | nr_of_hz_bands_per_bark_band_16k = [1, 1, 1, 1, 1, 9 | 1, 1, 1, 2, 1, 10 | 1, 1, 1, 1, 2, 11 | 1, 1, 2, 2, 2, 12 | 2, 2, 2, 2, 2, 13 | 3, 3, 3, 3, 4, 14 | 3, 4, 5, 4, 5, 15 | 6, 6, 7, 8, 9, 16 | 9, 12, 12, 15, 16, 17 | 18, 21, 25, 20] 18 | 19 | # bark域的中心 20 | centre_of_band_bark_16k = [0.078672, 0.316341, 0.636559, 0.961246, 1.290450, 21 | 1.624217, 1.962597, 2.305636, 2.653383, 3.005889, 22 | 3.363201, 3.725371, 4.092449, 4.464486, 4.841533, 23 | 5.223642, 5.610866, 6.003256, 6.400869, 6.803755, 24 | 7.211971, 7.625571, 8.044611, 8.469146, 8.899232, 25 | 9.334927, 9.776288, 10.223374, 10.676242, 11.134952, 26 | 11.599563, 12.070135, 12.546731, 13.029408, 13.518232, 27 | 14.013264, 14.514566, 15.022202, 15.536238, 16.056736, 28 | 16.583761, 17.117382, 17.657663, 18.204674, 18.758478, 29 | 19.319147, 19.886751, 20.461355, 21.043034] 30 | 31 | # # 在频域上bark域的中心 32 | centre_of_band_hz_16k = [7.867213, 31.634144, 63.655895, 96.124611, 129.044968, 33 | 162.421738, 196.259659, 230.563568, 265.338348, 300.588867, 34 | 336.320129, 372.537140, 409.244934, 446.448578, 484.568604, 35 | 526.600586, 570.303833, 619.423340, 672.121643, 728.525696, 36 | 785.675964, 846.835693, 909.691650, 977.063293, 1049.861694, 37 | 1129.635986, 1217.257568, 1312.109497, 1412.501465, 1517.999390, 38 | 1628.894165, 1746.194336, 1871.568848, 2008.776123, 2158.979248, 39 | 2326.743164, 2513.787109, 2722.488770, 2952.586670, 3205.835449, 40 | 3492.679932, 3820.219238, 4193.938477, 4619.846191, 5100.437012, 41 | 5636.199219, 6234.313477, 6946.734863, 7796.473633] 42 | 43 | # # 每个bark在频域上对应的宽度 44 | width_of_band_hz_16k = [15.734426, 31.799433, 32.244064, 32.693359, 33.147385, 45 | 33.606140, 34.069702, 34.538116, 35.011429, 35.489655, 46 | 35.972870, 36.461121, 36.954407, 37.452911, 40.269653, 47 | 42.311859, 45.992554, 51.348511, 55.040527, 56.775208, 48 | 58.699402, 62.445862, 64.820923, 69.195374, 76.745667, 49 
| 84.016235, 90.825684, 97.931152, 103.348877, 107.801880, 50 | 113.552246, 121.490601, 130.420410, 143.431763, 158.486816, 51 | 176.872803, 198.314697, 219.549561, 240.600098, 268.702393, 52 | 306.060059, 349.937012, 398.686279, 454.713867, 506.841797, 53 | 564.863770, 637.261230, 794.717285, 931.068359] 54 | 55 | 56 | def get_nr_of_hz_bands_per_bark_bank_16k(n_fft): 57 | if n_fft == 256: 58 | return nr_of_hz_bands_per_bark_band_16k 59 | else: 60 | aranges = [] 61 | for index in range(len(centre_of_band_hz_16k)): 62 | width = width_of_band_hz_16k[index] / 2 63 | arange = [centre_of_band_hz_16k[index] - width, centre_of_band_hz_16k[index] + width] 64 | aranges.append(arange) 65 | result = np.zeros(len(aranges)) 66 | hz_per_fft_bin = 8000 / n_fft 67 | for fft_bin in range(0, n_fft): 68 | hz_now = fft_bin * hz_per_fft_bin 69 | # print(hz_now) 70 | for index in range(len(aranges)): 71 | if aranges[index][0] <= hz_now < aranges[index][1]: 72 | result[index] += 1 73 | break 74 | return result.astype(np.int) 75 | 76 | 77 | def get_zwicker_power_to_nftt(n_fft=256): 78 | """ 79 | :param n_fft: <=256 80 | :return: 81 | """ 82 | zp = zwicker_power(49) 83 | zp_nfft = np.zeros(n_fft) 84 | nr_of_hz_bands_per_bark_band_16k_ = get_nr_of_hz_bands_per_bark_bank_16k(n_fft) 85 | current_inx_of_f = 0 86 | for inx, item in enumerate(nr_of_hz_bands_per_bark_band_16k_): 87 | zp_nfft[current_inx_of_f:current_inx_of_f + item] += zp[inx] 88 | current_inx_of_f += item 89 | return zp_nfft 90 | 91 | 92 | def zwicker_power(nb): 93 | modified_zwicker_power = np.zeros(nb) 94 | for band in range(nb): 95 | if centre_of_band_bark_16k[band] < 4: 96 | h = 6 / (centre_of_band_bark_16k[band] + 2) 97 | else: 98 | h = 1 99 | if h > 2: 100 | h = 2 101 | h = h ** 0.15 102 | modified_zwicker_power[band] = 0.23 * h 103 | return modified_zwicker_power 104 | 105 | 106 | if __name__ == '__main__': 107 | print(get_zwicker_power_to_nftt(256)) 108 | 
-------------------------------------------------------------------------------- /conv_stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as torchF 5 | from scipy.signal import get_window 6 | 7 | 8 | def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False): 9 | if win_type == 'None' or win_type is None: 10 | window = np.ones(win_len) 11 | else: 12 | if win_type == "hanning sqrt": 13 | window = get_window("hanning", win_len, fftbins=True) # win_len 14 | window = np.sqrt(window) 15 | else: 16 | window = get_window(win_type, win_len, fftbins=True) # win_len 17 | 18 | N = fft_len 19 | fourier_basis = np.fft.rfft(np.eye(N))[:win_len] 20 | # print(fourier_basis.shape) 21 | real_kernel = np.real(fourier_basis) 22 | imag_kernel = np.imag(fourier_basis) 23 | kernel = np.concatenate([real_kernel, imag_kernel], 1).T # 514,400 24 | # print(kernel.shape) 25 | 26 | if invers: 27 | kernel = np.linalg.pinv(kernel).T 28 | # np.set_printoptions(threshold=1000000) # 全部输出 29 | # print(kernel.shape) 30 | # print(kernel[:5]) 31 | 32 | kernel = kernel * window 33 | kernel = kernel[:, None, :] 34 | return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None, :, None].astype(np.float32)) 35 | 36 | 37 | class ConvSTFT(nn.Module): 38 | 39 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real'): 40 | super(ConvSTFT, self).__init__() 41 | 42 | if fft_len == None: 43 | self.fft_len = np.int(2 ** np.ceil(np.log2(win_len))) 44 | else: 45 | self.fft_len = fft_len 46 | 47 | kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type) 48 | # self.weight = nn.Parameter(kernel, requires_grad=(not fix)) 49 | self.register_buffer('weight', kernel) 50 | self.feature_type = feature_type 51 | self.stride = win_inc 52 | self.win_len = win_len 53 | self.dim = self.fft_len 54 | 55 | def 
forward(self, inputs): 56 | if inputs.dim() == 2: 57 | inputs = torch.unsqueeze(inputs, 1) 58 | # inputs = F.pad(inputs, [(self.win_len - self.stride), (self.win_len - self.stride)]) 59 | inputs = torchF.pad(inputs, [(self.win_len - self.stride), (self.win_len - self.stride)]) 60 | outputs = torchF.conv1d(inputs, self.weight, stride=self.stride) 61 | 62 | if self.feature_type == 'complex': 63 | return outputs # B,F,T 64 | else: 65 | dim = self.dim // 2 + 1 66 | real = outputs[:, :dim, :] 67 | imag = outputs[:, dim:, :] 68 | mags = torch.sqrt(real ** 2 + imag ** 2) 69 | phase = torch.atan2(imag, real) 70 | return mags, phase 71 | 72 | 73 | class ConviSTFT(nn.Module): 74 | 75 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True): 76 | super(ConviSTFT, self).__init__() 77 | if fft_len == None: 78 | self.fft_len = np.int(2 ** np.ceil(np.log2(win_len))) 79 | else: 80 | self.fft_len = fft_len 81 | kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True) 82 | # self.weight = nn.Parameter(kernel, requires_grad=(not fix)) 83 | self.register_buffer('weight', kernel) 84 | self.feature_type = feature_type 85 | self.win_type = win_type 86 | self.win_len = win_len 87 | self.stride = win_inc 88 | self.stride = win_inc 89 | self.dim = self.fft_len 90 | self.register_buffer('window', window) 91 | self.register_buffer('enframe', torch.eye(win_len)[:, None, :]) 92 | # t = self.window.repeat(1, 1, 10) ** 2 93 | # coff = F.conv_transpose1d(t, self.enframe, stride=self.stride) 94 | # print("coff", coff.size()) 95 | # drawer.plot_mesh(t[0]) 96 | # drawer.plot(coff[0][0][self.win_len - self.stride:-(self.win_len - self.stride)], "coff") 97 | 98 | def forward(self, inputs, phase=None): 99 | """ 100 | inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags) 101 | phase: [B, N//2+1, T] (if not none) 102 | """ 103 | 104 | if phase is not None: 105 | real = inputs * torch.cos(phase) 106 | imag = inputs * 
torch.sin(phase) 107 | inputs = torch.cat([real, imag], 1) 108 | outputs = torchF.conv_transpose1d(inputs, self.weight, stride=self.stride) 109 | 110 | # this is from torch-stft: https://github.com/pseeth/torch-stft 111 | t = self.window.repeat(1, 1, inputs.size(-1)) ** 2 112 | coff = torchF.conv_transpose1d(t, self.enframe, stride=self.stride) 113 | outputs = outputs / (coff + 1e-8) 114 | # outputs = torch.where(coff == 0, outputs, outputs/coff) 115 | # print(outputs.size()) 116 | outputs_ = outputs[..., (self.win_len - self.stride):-(self.win_len - self.stride)] 117 | # outputs_ = outputs[..., (self.win_len - self.stride):-(self.win_len - self.stride)] 118 | # 119 | return outputs_ 120 | 121 | 122 | if __name__ == '__main__': 123 | # stft = ConvSTFT(400, 100, 512, "hanning", 'complex') 124 | stft = ConvSTFT(512, 128, 512, "hanning", 'complex') 125 | # istft = ConviSTFT(400, 100, 512, "hanning", 'complex') 126 | istft = ConviSTFT(512, 128, 512, "hanning", 'complex') 127 | test_inp = torch.randn(2, 512) 128 | stft_ = stft(test_inp) 129 | print(stft_.size()) 130 | out = istft(stft_) 131 | -------------------------------------------------------------------------------- /pmsqe/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author:WangTianRui 3 | # Date :2021-08-19 14:15 -------------------------------------------------------------------------------- /pmsqe/bark_matrix_16k.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangtianrui/APC-SNR/938ff22bc7e10ccc38021defc1bdb9eb96031eba/pmsqe/bark_matrix_16k.mat -------------------------------------------------------------------------------- /pmsqe/pmsqe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import tensor 4 | import torch.nn as nn 5 | import pathlib 6 | import os 7 | 8 | 9 | class 
SingleSrcPMSQE(nn.Module): 10 | """Computes the Perceptual Metric for Speech Quality Evaluation (PMSQE) 11 | as described in [1]. 12 | This version is only designed for 16 kHz (512 length DFT). 13 | Adaptation to 8 kHz could be done by changing the parameters of the 14 | class. 15 | The SLL, frequency and gain equalization are applied in each 16 | sequence independently. 17 | Parameters: 18 | window_name (str): Select the used window function for the correct 19 | factor to be applied. Defaults to sqrt hanning window. 20 | Among ['rect', 'hann', 'sqrt_hann', 'hamming', 'flatTop']. 21 | window_weight (float, optional): Correction to the window factor 22 | applied. 23 | bark_eq (bool, optional): Whether to apply bark equalization. 24 | gain_eq (bool, optional): Whether to apply gain equalization. 25 | sample_rate (int): Sample rate of the input audio. 26 | References 27 | [1] J.M.Martin, A.M.Gomez, J.A.Gonzalez, A.M.Peinado 'A Deep Learning 28 | Loss Function based on the Perceptual Evaluation of the 29 | Speech Quality', IEEE Signal Processing Letters, 2018. 30 | Implemented by Juan M. Martin. Contact: mdjuamart@ugr.es 31 | .. note:: Inspired on the Perceptual Evaluation of the Speech Quality (PESQ) 32 | algorithm, this function consists of two regularization factors : 33 | the symmetrical and asymmetrical distortion in the loudness domain. 
34 | Examples 35 | >>> import torch 36 | >>> stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256)) 37 | >>> # Usage by itself 38 | >>> ref, est = torch.randn(2, 1, 16000), torch.randn(2, 1, 16000) 39 | >>> ref_spec = transforms.mag(stft(ref)) 40 | >>> est_spec = transforms.mag(stft(est)) 41 | >>> loss_func = SingleSrcPMSQE() 42 | >>> loss_value = loss_func(est_spec, ref_spec) 43 | """ 44 | 45 | def __init__( 46 | self, 47 | window_name="sqrt_hann", 48 | window_weight=1.0, 49 | bark_eq=True, 50 | gain_eq=True, 51 | sample_rate=16000, 52 | ): 53 | super().__init__() 54 | self.window_name = window_name 55 | self.window_weight = window_weight 56 | self.bark_eq = bark_eq 57 | self.gain_eq = gain_eq 58 | 59 | if sample_rate not in [16000, 8000]: 60 | raise ValueError("Unsupported sample rate {}".format(sample_rate)) 61 | self.sample_rate = sample_rate 62 | if sample_rate == 16000: 63 | self.Sp = 6.910853e-006 64 | self.Sl = 1.866055e-001 65 | self.nbins = 512 66 | self.nbark = 49 67 | else: 68 | self.Sp = 2.764344e-5 69 | self.Sl = 1.866055e-1 70 | self.nbins = 256 71 | self.nbark = 42 72 | # As described in [1] and used in the TF implementation. 73 | self.alpha = 0.1 74 | self.beta = 0.309 * self.alpha 75 | 76 | pow_correc_factor = self.get_correction_factor(window_name) 77 | self.pow_correc_factor = pow_correc_factor * self.window_weight 78 | # Initialize to None and populate as a function of sample rate. 79 | self.abs_thresh_power = None 80 | self.modified_zwicker_power = None 81 | self.width_of_band_bark = None 82 | self.bark_matrix = None 83 | self.mask_sll = None 84 | self.populate_constants(self.sample_rate) 85 | self.sqrt_total_width = torch.sqrt(torch.sum(self.width_of_band_bark)) 86 | self.EPS = 1e-8 87 | 88 | def forward(self, est_targets, targets, pad_mask=None): 89 | """ 90 | Args 91 | est_targets (torch.Tensor): Dimensions (B, T, F). 92 | Padded degraded power spectrum in time-frequency domain. 93 | targets (torch.Tensor): Dimensions (B, T, F). 
94 | Zero-Padded reference power spectrum in time-frequency domain. 95 | pad_mask (torch.Tensor, optional): Dimensions (B, T, 1). Mask 96 | to indicate the padding frames. Defaults to all ones. 97 | Dimensions 98 | B: Number of sequences in the batch. 99 | T: Number of time frames. 100 | F: Number of frequency bins. 101 | Returns 102 | torch.tensor of shape (B, ), wD + 0.309 * wDA 103 | ..note:: Dimensions (B, F, T) are also supported by SingleSrcPMSQE but are 104 | less efficient because input tensors are transposed (not inplace). 105 | """ 106 | assert est_targets.shape == targets.shape, str(est_targets.shape) + str(targets.shape) 107 | # Need transpose? Find it out 108 | try: 109 | freq_idx = est_targets.shape.index(self.nbins // 2 + 1) 110 | except ValueError: 111 | raise ValueError( 112 | "Could not find dimension with {} elements in " 113 | "input tensors, verify your inputs" 114 | "".format(self.nbins // 2 + 1) 115 | ) 116 | if freq_idx == 1: 117 | est_targets = est_targets.transpose(1, 2) 118 | targets = targets.transpose(1, 2) 119 | if pad_mask is not None: 120 | # Transpose the pad mask as well if needed. 121 | pad_mask = pad_mask.transpose(1, 2) if freq_idx == 1 else pad_mask 122 | else: 123 | # Suppose no padding if no pad_mask is provided. 
124 | pad_mask = torch.ones( 125 | est_targets.shape[0], est_targets.shape[1], 1, device=est_targets.device 126 | ) 127 | # SLL equalization 128 | ref_spectra = self.magnitude_at_sll(targets, pad_mask) 129 | deg_spectra = self.magnitude_at_sll(est_targets, pad_mask) 130 | 131 | # Bark spectra computation 132 | ref_bark_spectra = self.bark_computation(ref_spectra) 133 | deg_bark_spectra = self.bark_computation(deg_spectra) 134 | 135 | # (Optional) frequency and gain equalization 136 | if self.bark_eq: 137 | deg_bark_spectra = self.bark_freq_equalization(ref_bark_spectra, deg_bark_spectra) 138 | 139 | if self.gain_eq: 140 | deg_bark_spectra = self.bark_gain_equalization(ref_bark_spectra, deg_bark_spectra) 141 | 142 | # Distortion matrix computation 143 | sym_d, asym_d = self.compute_distortion_tensors(ref_bark_spectra, deg_bark_spectra) 144 | 145 | # Per-frame distortion 146 | audible_power_ref = self.compute_audible_power(ref_bark_spectra, 1.0) 147 | wd_frame, wda_frame = self.per_frame_distortion(sym_d, asym_d, audible_power_ref) 148 | # Mean distortions over frames : keep batch dims 149 | dims = [-1, -2] 150 | # print(torch.mean(wd_frame),torch.mean(wda_frame)) 151 | pmsqe_frame = (self.alpha * wd_frame + self.beta * wda_frame) * pad_mask 152 | pmsqe = torch.sum(pmsqe_frame, dim=dims) / pad_mask.sum(dims) 153 | return pmsqe 154 | 155 | def magnitude_at_sll(self, spectra, pad_mask, correction=10000000.0): 156 | # Mapping to Eq. (10) in [1]. 
157 | # Apply padding and SLL masking 158 | masked_spectra = spectra * pad_mask * self.mask_sll 159 | # Compute mean over frequency 160 | freq_mean_masked_spectra = torch.mean(masked_spectra, dim=-1, keepdim=True) 161 | # Compute mean over time (taking into account padding) 162 | sum_spectra = torch.sum(freq_mean_masked_spectra, dim=-2, keepdim=True) 163 | seq_len = torch.sum(pad_mask, dim=-2, keepdim=True) 164 | mean_pow = sum_spectra / seq_len 165 | # Compute final SLL spectra 166 | return correction * spectra / mean_pow 167 | 168 | def bark_computation(self, spectra): 169 | return self.Sp * torch.matmul(spectra, self.bark_matrix) 170 | 171 | def compute_audible_power(self, bark_spectra, factor=1.0): 172 | # Apply absolute hearing threshold to each band 173 | thr_bark = torch.where( 174 | bark_spectra > self.abs_thresh_power * factor, 175 | bark_spectra, 176 | torch.zeros_like(bark_spectra), 177 | ) 178 | # Sum band power over frequency 179 | return torch.sum(thr_bark, dim=-1, keepdim=True) 180 | 181 | def bark_gain_equalization(self, ref_bark_spectra, deg_bark_spectra): 182 | # Compute audible power 183 | audible_power_ref = self.compute_audible_power(ref_bark_spectra, 1.0) 184 | audible_power_deg = self.compute_audible_power(deg_bark_spectra, 1.0) 185 | # Compute gain factor 186 | gain = (audible_power_ref + 5.0e3) / (audible_power_deg + 5.0e3) 187 | # Limit the range of the gain factor 188 | limited_gain = torch.min(gain, 5.0 * torch.ones_like(gain)) 189 | limited_gain = torch.max(limited_gain, 3.0e-4 * torch.ones_like(limited_gain)) 190 | # Apply gain correction on degraded 191 | return limited_gain * deg_bark_spectra 192 | 193 | def bark_freq_equalization(self, ref_bark_spectra, deg_bark_spectra): 194 | """This version is applied in the degraded directly.""" 195 | # Identification of speech active frames 196 | audible_power_x100 = self.compute_audible_power(ref_bark_spectra, 100.0) 197 | not_silent = audible_power_x100 >= 1.0e7 198 | # Threshold for active 
bark bins 199 | cond_thr = ref_bark_spectra >= self.abs_thresh_power * 100.0 200 | ref_thresholded = torch.where( 201 | cond_thr, ref_bark_spectra, torch.zeros_like(ref_bark_spectra) 202 | ) 203 | deg_thresholded = torch.where( 204 | cond_thr, deg_bark_spectra, torch.zeros_like(deg_bark_spectra) 205 | ) 206 | # Total power per bark bin (ppb) 207 | avg_ppb_ref = torch.sum( 208 | torch.where(not_silent, ref_thresholded, torch.zeros_like(ref_thresholded)), 209 | dim=-2, 210 | keepdim=True, 211 | ) 212 | avg_ppb_deg = torch.sum( 213 | torch.where(not_silent, deg_thresholded, torch.zeros_like(deg_thresholded)), 214 | dim=-2, 215 | keepdim=True, 216 | ) 217 | # Compute equalizer 218 | equalizer = (avg_ppb_ref + 1000.0) / (avg_ppb_deg + 1000.0) 219 | equalizer = torch.min(equalizer, 100.0 * torch.ones_like(equalizer)) 220 | equalizer = torch.max(equalizer, 0.01 * torch.ones_like(equalizer)) 221 | # Apply frequency correction on degraded 222 | return equalizer * deg_bark_spectra 223 | 224 | def loudness_computation(self, bark_spectra): 225 | # Bark spectra transformed to a sone loudness scale using Zwicker's law 226 | aterm = torch.pow(self.abs_thresh_power / 0.5, self.modified_zwicker_power) 227 | bterm = ( 228 | torch.pow(0.5 + 0.5 * bark_spectra / self.abs_thresh_power, self.modified_zwicker_power) 229 | - 1.0 230 | ) 231 | loudness_dens = self.Sl * aterm * bterm 232 | cond = bark_spectra < self.abs_thresh_power 233 | return torch.where(cond, torch.zeros_like(loudness_dens), loudness_dens) 234 | 235 | def compute_distortion_tensors(self, ref_bark_spec, deg_bark_spec): 236 | # After bark spectra are compensated, transform to sone loudness 237 | original_loudness = self.loudness_computation(ref_bark_spec) 238 | distorted_loudness = self.loudness_computation(deg_bark_spec) 239 | # Loudness difference 240 | r = torch.abs(distorted_loudness - original_loudness) 241 | # print(torch.mean(r)) 242 | # drawer.plot_mesh(r[0].data.numpy(), "r:" + str(torch.mean(r).data)) 243 | # 
Masking effect computation 244 | m = 0.25 * torch.min(original_loudness, distorted_loudness) 245 | # Center clipping using masking effect 246 | sym_d = torch.max(r - m, torch.ones_like(r) * self.EPS) 247 | # print(torch.mean(sym_d)) 248 | # drawer.plot_mesh(sym_d[0].data.numpy(), "sym_d:" + str(torch.mean(sym_d).data)) 249 | # Asymmetry factor computation 250 | asym = torch.pow((deg_bark_spec + 50.0) / (ref_bark_spec + 50.0), 1.2) 251 | cond = asym < 3.0 * torch.ones_like(asym) 252 | asym_factor = torch.where( 253 | cond, torch.zeros_like(asym), torch.min(asym, 12.0 * torch.ones_like(asym)) 254 | ) 255 | # Asymmetric Disturbance matrix computation 256 | asym_d = asym_factor * sym_d 257 | return sym_d, asym_d 258 | 259 | def per_frame_distortion(self, sym_d, asym_d, total_power_ref): 260 | # Computation of the norms over bark bands for each frame 261 | # 2 and 1 for sym_d and asym_d, respectively 262 | d_frame = torch.sum( 263 | torch.pow(sym_d * self.width_of_band_bark, 2.0) + self.EPS, dim=-1, keepdim=True 264 | ) 265 | # a = torch.pow(sym_d * self.width_of_band_bark, 2.0) 266 | # b = sym_d 267 | # print(a.min(),a.max(),b.min(),b.max(), d_frame.min(), d_frame.max()) 268 | # print(self.width_of_band_bark.requires_grad) 269 | # print(d_frame.requires_grad) 270 | d_frame = torch.sqrt(d_frame) * self.sqrt_total_width 271 | da_frame = torch.sum(asym_d * self.width_of_band_bark, dim=-1, keepdim=True) 272 | # Weighting by the audible power raised to 0.04 273 | weights = torch.pow((total_power_ref + 1e5) / 1e7, 0.04) 274 | # Bounded computation of the per frame distortion metric 275 | wd_frame = torch.min(d_frame / weights, 45.0 * torch.ones_like(d_frame)) 276 | wda_frame = torch.min(da_frame / weights, 45.0 * torch.ones_like(da_frame)) 277 | return wd_frame, wda_frame 278 | 279 | @staticmethod 280 | def get_correction_factor(window_name): 281 | """ Returns the power correction factor depending on the window. 
""" 282 | if window_name == "rect": 283 | return 1.0 284 | elif window_name == "hann": 285 | return 2.666666666666754 286 | elif window_name == "sqrt_hann": 287 | return 2.0 288 | elif window_name == "hamming": 289 | return 2.51635879188799 290 | elif window_name == "flatTop": 291 | return 5.70713295690759 292 | else: 293 | raise ValueError("Unexpected window type {}".format(window_name)) 294 | 295 | def populate_constants(self, sample_rate): 296 | if sample_rate == 8000: 297 | self.register_8k_constants() 298 | elif sample_rate == 16000: 299 | self.register_16k_constants() 300 | # Mask SSL 301 | mask_sll = np.zeros(shape=[self.nbins // 2 + 1], dtype=np.float32) 302 | mask_sll[11] = 0.5 * 25.0 / 31.25 303 | mask_sll[12:104] = 1.0 304 | mask_sll[104] = 0.5 305 | correction = self.pow_correc_factor * (self.nbins + 2.0) / self.nbins ** 2 306 | mask_sll = mask_sll * correction 307 | self.mask_sll = nn.Parameter(tensor(mask_sll), requires_grad=False) 308 | 309 | def register_16k_constants(self): 310 | # Absolute threshold power 311 | abs_thresh_power = [ 312 | 51286152.00, 313 | 2454709.500, 314 | 70794.593750, 315 | 4897.788574, 316 | 1174.897705, 317 | 389.045166, 318 | 104.712860, 319 | 45.708820, 320 | 17.782795, 321 | 9.772372, 322 | 4.897789, 323 | 3.090296, 324 | 1.905461, 325 | 1.258925, 326 | 0.977237, 327 | 0.724436, 328 | 0.562341, 329 | 0.457088, 330 | 0.389045, 331 | 0.331131, 332 | 0.295121, 333 | 0.269153, 334 | 0.257040, 335 | 0.251189, 336 | 0.251189, 337 | 0.251189, 338 | 0.251189, 339 | 0.263027, 340 | 0.288403, 341 | 0.309030, 342 | 0.338844, 343 | 0.371535, 344 | 0.398107, 345 | 0.436516, 346 | 0.467735, 347 | 0.489779, 348 | 0.501187, 349 | 0.501187, 350 | 0.512861, 351 | 0.524807, 352 | 0.524807, 353 | 0.524807, 354 | 0.512861, 355 | 0.478630, 356 | 0.426580, 357 | 0.371535, 358 | 0.363078, 359 | 0.416869, 360 | 0.537032, 361 | ] 362 | self.abs_thresh_power = nn.Parameter(tensor(abs_thresh_power), requires_grad=False) 363 | # Modified zwicker 
power 364 | modif_zwicker_power = [ 365 | 0.25520097857560436, 366 | 0.25520097857560436, 367 | 0.25520097857560436, 368 | 0.25520097857560436, 369 | 0.25168783742879913, 370 | 0.24806665731869609, 371 | 0.244767379124259, 372 | 0.24173800119368227, 373 | 0.23893798876066405, 374 | 0.23633516221479894, 375 | 0.23390360348392067, 376 | 0.23162209128929445, 377 | 0.23, 378 | 0.23, 379 | 0.23, 380 | 0.23, 381 | 0.23, 382 | 0.23, 383 | 0.23, 384 | 0.23, 385 | 0.23, 386 | 0.23, 387 | 0.23, 388 | 0.23, 389 | 0.23, 390 | 0.23, 391 | 0.23, 392 | 0.23, 393 | 0.23, 394 | 0.23, 395 | 0.23, 396 | 0.23, 397 | 0.23, 398 | 0.23, 399 | 0.23, 400 | 0.23, 401 | 0.23, 402 | 0.23, 403 | 0.23, 404 | 0.23, 405 | 0.23, 406 | 0.23, 407 | 0.23, 408 | 0.23, 409 | 0.23, 410 | 0.23, 411 | 0.23, 412 | 0.23, 413 | 0.23, 414 | ] 415 | self.modified_zwicker_power = nn.Parameter(tensor(modif_zwicker_power), requires_grad=False) 416 | # Width of band bark 417 | width_of_band_bark = [ 418 | 0.157344, 419 | 0.317994, 420 | 0.322441, 421 | 0.326934, 422 | 0.331474, 423 | 0.336061, 424 | 0.340697, 425 | 0.345381, 426 | 0.350114, 427 | 0.354897, 428 | 0.359729, 429 | 0.364611, 430 | 0.369544, 431 | 0.374529, 432 | 0.379565, 433 | 0.384653, 434 | 0.389794, 435 | 0.394989, 436 | 0.400236, 437 | 0.405538, 438 | 0.410894, 439 | 0.416306, 440 | 0.421773, 441 | 0.427297, 442 | 0.432877, 443 | 0.438514, 444 | 0.444209, 445 | 0.449962, 446 | 0.455774, 447 | 0.461645, 448 | 0.467577, 449 | 0.473569, 450 | 0.479621, 451 | 0.485736, 452 | 0.491912, 453 | 0.498151, 454 | 0.504454, 455 | 0.510819, 456 | 0.517250, 457 | 0.523745, 458 | 0.530308, 459 | 0.536934, 460 | 0.543629, 461 | 0.550390, 462 | 0.557220, 463 | 0.564119, 464 | 0.571085, 465 | 0.578125, 466 | 0.585232, 467 | ] 468 | self.width_of_band_bark = nn.Parameter(tensor(width_of_band_bark), requires_grad=False) 469 | # Bark matrix 470 | local_path = pathlib.Path(__file__).parent.absolute() 471 | bark_path = os.path.join(local_path, "bark_matrix_16k.mat") 
472 | bark_matrix = self.load_mat(bark_path)["Bark_matrix_16k"].astype("float32") 473 | self.bark_matrix = nn.Parameter(tensor(bark_matrix), requires_grad=False) 474 | 475 | def register_8k_constants(self): 476 | # Absolute threshold power 477 | abs_thresh_power = [ 478 | 51286152, 479 | 2454709.500, 480 | 70794.593750, 481 | 4897.788574, 482 | 1174.897705, 483 | 389.045166, 484 | 104.712860, 485 | 45.708820, 486 | 17.782795, 487 | 9.772372, 488 | 4.897789, 489 | 3.090296, 490 | 1.905461, 491 | 1.258925, 492 | 0.977237, 493 | 0.724436, 494 | 0.562341, 495 | 0.457088, 496 | 0.389045, 497 | 0.331131, 498 | 0.295121, 499 | 0.269153, 500 | 0.257040, 501 | 0.251189, 502 | 0.251189, 503 | 0.251189, 504 | 0.251189, 505 | 0.263027, 506 | 0.288403, 507 | 0.309030, 508 | 0.338844, 509 | 0.371535, 510 | 0.398107, 511 | 0.436516, 512 | 0.467735, 513 | 0.489779, 514 | 0.501187, 515 | 0.501187, 516 | 0.512861, 517 | 0.524807, 518 | 0.524807, 519 | 0.524807, 520 | ] 521 | self.abs_thresh_power = nn.Parameter(tensor(abs_thresh_power), requires_grad=False) 522 | # Modified zwicker power 523 | modif_zwicker_power = [ 524 | 0.25520097857560436, 525 | 0.25520097857560436, 526 | 0.25520097857560436, 527 | 0.25520097857560436, 528 | 0.25168783742879913, 529 | 0.24806665731869609, 530 | 0.244767379124259, 531 | 0.24173800119368227, 532 | 0.23893798876066405, 533 | 0.23633516221479894, 534 | 0.23390360348392067, 535 | 0.23162209128929445, 536 | 0.23, 537 | 0.23, 538 | 0.23, 539 | 0.23, 540 | 0.23, 541 | 0.23, 542 | 0.23, 543 | 0.23, 544 | 0.23, 545 | 0.23, 546 | 0.23, 547 | 0.23, 548 | 0.23, 549 | 0.23, 550 | 0.23, 551 | 0.23, 552 | 0.23, 553 | 0.23, 554 | 0.23, 555 | 0.23, 556 | 0.23, 557 | 0.23, 558 | 0.23, 559 | 0.23, 560 | 0.23, 561 | 0.23, 562 | 0.23, 563 | 0.23, 564 | 0.23, 565 | 0.23, 566 | ] 567 | self.modified_zwicker_power = nn.Parameter(tensor(modif_zwicker_power), requires_grad=False) 568 | # Width of band bark 569 | width_of_band_bark = [ 570 | 0.157344, 571 | 0.317994, 572 | 
0.322441, 573 | 0.326934, 574 | 0.331474, 575 | 0.336061, 576 | 0.340697, 577 | 0.345381, 578 | 0.350114, 579 | 0.354897, 580 | 0.359729, 581 | 0.364611, 582 | 0.369544, 583 | 0.374529, 584 | 0.379565, 585 | 0.384653, 586 | 0.389794, 587 | 0.394989, 588 | 0.400236, 589 | 0.405538, 590 | 0.410894, 591 | 0.416306, 592 | 0.421773, 593 | 0.427297, 594 | 0.432877, 595 | 0.438514, 596 | 0.444209, 597 | 0.449962, 598 | 0.455774, 599 | 0.461645, 600 | 0.467577, 601 | 0.473569, 602 | 0.479621, 603 | 0.485736, 604 | 0.491912, 605 | 0.498151, 606 | 0.504454, 607 | 0.510819, 608 | 0.517250, 609 | 0.523745, 610 | 0.530308, 611 | 0.536934, 612 | ] 613 | self.width_of_band_bark = nn.Parameter(tensor(width_of_band_bark), requires_grad=False) 614 | # Bark matrix 615 | local_path = pathlib.Path(__file__).parent.absolute() 616 | bark_path = os.path.join(local_path, "bark_matrix_8k.mat") 617 | bark_matrix = self.load_mat(bark_path)["Bark_matrix_8k"].astype("float32") 618 | self.bark_matrix = nn.Parameter(tensor(bark_matrix), requires_grad=False) 619 | 620 | def load_mat(self, *args, **kwargs): 621 | from scipy.io import loadmat 622 | 623 | return loadmat(*args, **kwargs) 624 | 625 | 626 | if __name__ == '__main__': 627 | test_est = torch.tensor(np.load(r"F:\python programes\interphone_tensorflow\wav/test_est.npy"), dtype=torch.float32) 628 | test_label = torch.tensor(np.load(r"F:\python programes\interphone_tensorflow\wav/test_label.npy"), 629 | dtype=torch.float32) 630 | pmsqe = SingleSrcPMSQE() 631 | print(pmsqe(test_est ** 2, test_label ** 2)) 632 | print(torch.mean(pmsqe(test_est ** 2, test_label ** 2))) 633 | -------------------------------------------------------------------------------- /results/different_loss_compare/README.md: -------------------------------------------------------------------------------- 1 | The models trained by our loss function (APC-SNR or PMSQE1+APC-SNR) are the best on the comprehensive performance. 
2 | 3 | All results are sorted in ascending order of the CI metric defined in our paper. 4 | 5 | #### Model in the paper (NsNET) trained with different loss functions. 6 | 7 | * (python36) F:\python programes\APC-SNR>python analysis.py model_in_paper 8 | ```txt 9 | ------------------------------ PMSQE1.csv ------------------------------ 10 | pesq_wb apc-snr time_sisnr_ stoi_score pmsqe_score mse \ 11 | std 0.885604 3.210944 3.100253 0.079843 0.462479 0.108243 12 | mean 2.819146 6.966474 6.793499 0.915208 0.609486 0.092088 13 | 14 | apc-mse stft_diff snr log_mse time CI 15 | std 0.009841 0.035318 11.380700 1672.448470 0.191178 -1.067367 16 | mean 0.010735 0.135234 9.627286 6668.395815 0.174532 -1.067367 17 | ------------------------------ STOI.csv ------------------------------ 18 | pesq_wb apc-snr time_sisnr_ stoi_score pmsqe_score mse \ 19 | std 0.916190 6.309723 7.650743 0.073037 0.727168 0.033223 20 | mean 2.421587 10.702118 14.932098 0.942021 1.128424 0.016135 21 | 22 | apc-mse stft_diff snr log_mse time CI 23 | std 0.006644 0.043790 11.380700 1788.278757 0.039769 -0.336705 24 | mean 0.004626 0.102257 9.627286 6154.422389 0.165325 -0.336705 25 | ------------------------------ PMSQE.csv ------------------------------ 26 | pesq_wb apc-snr time_sisnr_ stoi_score pmsqe_score mse \ 27 | std 0.913308 6.797561 8.256733 0.086187 0.638651 0.027477 28 | mean 2.608862 12.282501 15.535787 0.928366 0.913599 0.013817 29 | 30 | apc-mse stft_diff snr log_mse time CI 31 | std 0.004799 0.035834 11.380700 1329.387698 0.096284 -0.310001 32 | mean 0.003443 0.087506 9.627286 4717.880032 0.168999 -0.310001 33 | ------------------------------ MSE.csv ------------------------------ 34 | pesq_wb apc-snr time_sisnr_ stoi_score pmsqe_score mse \ 35 | std 0.889027 6.556404 8.005335 0.078473 0.592897 0.021589 36 | mean 2.593183 12.988769 17.098206 0.934021 0.903568 0.010040 37 | 38 | apc-mse stft_diff snr log_mse time CI 39 | std 0.004075 0.033662 11.380700 1438.310945 0.023291 0.010275 40 |
mean 0.002936 0.089574 9.627286 5136.124110 0.153320 0.010275 41 | ------------------------------ SI-SNR.csv ------------------------------ 42 | pesq_wb apc-snr time_sisnr_ stoi_score pmsqe_score mse \ 43 | std 0.917275 6.697482 8.126247 0.077372 0.631717 0.023851 44 | mean 2.637803 13.079809 17.482367 0.937436 0.918886 0.010669 45 | 46 | apc-mse stft_diff snr log_mse time CI 47 | std 0.004219 0.032972 11.380700 1598.195230 0.034394 0.295149 48 | mean 0.002957 0.088835 9.627286 5421.023171 0.154234 0.295149 49 | ------------------------------ APC-SNR.csv ------------------------------ 50 | pesq_wb apc-snr time_sisnr_ stoi_score pmsqe_score mse \ 51 | std 0.925940 6.735696 8.227054 0.077457 0.576913 0.023204 52 | mean 2.717651 13.568513 17.531977 0.938980 0.813688 0.009959 53 | 54 | apc-mse stft_diff snr log_mse time CI 55 | std 0.003942 0.033922 11.380700 1463.192586 0.043477 0.570436 56 | mean 0.002709 0.083246 9.627286 5113.331380 0.160313 0.570436 57 | ------------------------------ PMSQE1+APC-SNR.csv ------------------------------ 58 | pesq_wb apc-snr time_sisnr_ stoi_score pmsqe_score mse \ 59 | std 0.939416 6.729692 8.198711 0.075890 0.521013 0.021263 60 | mean 2.794378 13.677221 17.637924 0.940408 0.710305 0.009601 61 | 62 | apc-mse stft_diff snr log_mse time CI 63 | std 0.003782 0.033273 11.380700 1476.335639 0.013591 0.838212 64 | mean 0.002642 0.083032 9.627286 5044.413008 0.153787 0.838212 65 | ``` 66 | 67 | 68 | 69 | #### DCRN trained with different loss functions. 70 | 71 | > If a model applies its mask differently, the way the loss function acts on it changes as well. This is a separate study that we will publish soon; it is not covered in this paper.
72 | 73 | * (python36) F:\python programes\APC-SNR>python analysis.py dcrn 74 | 75 | ```txt 76 | ------------------------------ PMSQE.csv ------------------------------ 77 | pesq_wb pesq_nb apc-snr time_sisnr_ stoi_score pmsqe_score \ 78 | std 0.722245 0.743013 1.082325 1.134453 0.084694 0.618090 79 | mean 2.285199 2.805140 6.497373 6.507226 0.918010 1.123879 80 | 81 | time mse log_mse CI 82 | std 0.050024 0.852268 1277.142161 -1.335816 83 | mean 0.374750 0.739536 4684.820424 -1.335816 84 | ------------------------------ STOI.csv ------------------------------ 85 | pesq_wb pesq_nb apc-snr time_sisnr_ stoi_score pmsqe_score \ 86 | std 0.902221 0.865075 2.005889 2.267596 0.089532 0.805570 87 | mean 2.141888 2.606973 13.100688 13.693197 0.917886 1.414646 88 | 89 | time mse log_mse CI 90 | std 0.006431 0.802877 1665.579885 -1.107117 91 | mean 0.285633 0.754603 6540.993792 -1.107117 92 | ------------------------------ PMSQE1.csv ------------------------------ 93 | pesq_wb pesq_nb apc-snr time_sisnr_ stoi_score pmsqe_score \ 94 | std 0.871956 0.819457 0.975210 0.956333 0.075617 0.532536 95 | mean 2.670215 3.156578 4.040768 3.952440 0.933617 0.765608 96 | 97 | time mse log_mse CI 98 | std 0.051112 0.886549 1391.272701 -0.409464 99 | mean 0.396556 0.765121 5003.232417 -0.409464 100 | ------------------------------ MSE.csv ------------------------------ 101 | pesq_wb pesq_nb apc-snr time_sisnr_ stoi_score pmsqe_score \ 102 | std 0.869847 0.790393 6.508225 8.000791 0.077054 0.606798 103 | mean 2.622014 3.057390 13.462180 17.338863 0.937252 0.894019 104 | 105 | time mse log_mse CI 106 | std 0.026291 0.018310 1344.922938 0.462829 107 | mean 0.339722 0.008754 5091.116865 0.462829 108 | ------------------------------ SI-SNR.csv ------------------------------ 109 | pesq_wb pesq_nb apc-snr time_sisnr_ stoi_score pmsqe_score \ 110 | std 0.892635 0.797894 6.650916 8.096310 0.075642 0.607811 111 | mean 2.685761 3.116080 13.634782 17.532762 0.939228 0.859444 112 | 113 | time mse 
log_mse CI 114 | std 0.051400 1.287860 1341.251005 0.633437 115 | mean 0.403344 1.122308 4978.711163 0.633437 116 | ------------------------------ PMSQE1+APC-SNR.csv ------------------------------ 117 | pesq_wb pesq_nb apc-snr time_sisnr_ stoi_score pmsqe_score \ 118 | std 0.917321 0.818154 6.649614 8.109513 0.074625 0.562606 119 | mean 2.777020 3.198914 14.044983 17.706681 0.942018 0.736298 120 | 121 | time mse log_mse CI 122 | std 0.027447 1.283255 1347.922655 0.869978 123 | mean 0.340153 1.114424 4818.440187 0.869978 124 | ------------------------------ APC-SNR.csv ------------------------------ 125 | pesq_wb pesq_nb apc-snr time_sisnr_ stoi_score pmsqe_score \ 126 | std 0.909985 0.809029 6.668091 8.155011 0.074218 0.585604 127 | mean 2.770679 3.189733 14.167096 17.852413 0.942504 0.788317 128 | 129 | time mse log_mse CI 130 | std 0.050741 1.285023 1343.351034 0.886154 131 | mean 0.401119 1.116907 4861.222821 0.886154 132 | ``` 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /results/different_loss_compare/analysis.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author:WangTianRui 3 | # Date :2021/4/9 16:07 4 | import os, sys 5 | import numpy as np 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | 9 | pd.set_option('display.max_columns', None) # 显示完整的列 10 | pd.set_option('display.max_rows', None) # 显示完整的行 11 | 12 | pesqs = [] 13 | stois = [] 14 | sisnrs = [] 15 | 16 | 17 | def get_info(csv_path, name, all_csvs): 18 | df = pd.read_csv(csv_path, encoding="utf_8_sig") 19 | temp_df = df.describe().loc[["std", "mean"]].drop(['Unnamed: 0', "stft_snr"], axis=1) 20 | temp_df.time_sisnr_ = abs(temp_df.time_sisnr_) # to positive number 21 | temp_df["apc-snr"] = abs(temp_df["apc-snr"]) # to positive number 22 | all_csvs[name] = temp_df 23 | pesqs.append(all_csvs[name]["pesq_wb"].loc["mean"]) 24 | 
stois.append(all_csvs[name]["stoi_score"].loc["mean"]) 25 | sisnrs.append(all_csvs[name]["time_sisnr_"].loc["mean"]) 26 | 27 | 28 | def sort_csv(all_csvs, model_flag): 29 | scores = {} 30 | for key in all_csvs.keys(): 31 | mean_measure = ( 32 | (all_csvs[key]["pesq_wb"].loc["mean"] - np.mean(pesqs)) / np.std(pesqs) + 33 | (all_csvs[key]["time_sisnr_"].loc["mean"] - np.mean(sisnrs)) / np.std(sisnrs) + 34 | (all_csvs[key]["stoi_score"].loc["mean"] - np.mean(stois)) / np.std(stois) 35 | ) / 3 36 | all_csvs[key]["CI"] = mean_measure 37 | scores[key] = mean_measure 38 | scores_sort = sorted(scores.items(), key=lambda x: x[1], reverse=False) 39 | 40 | # indexes = None 41 | # scores_plot = {} 42 | methods = [] 43 | for key in scores_sort: 44 | print("--" * 15, key[0], "--" * 15) 45 | methods.append(key[0].split(".")[0]) 46 | print(all_csvs[key[0]]) 47 | # if indexes is None: 48 | # indexes = all_csvs[key[0]].columns.values 49 | # for index in indexes: 50 | # scores_plot[index] = [all_csvs[key[0]][index].loc["mean"]] 51 | # else: 52 | # for index in indexes: 53 | # scores_plot[index].append(all_csvs[key[0]][index].loc["mean"]) 54 | 55 | # plt.figure(figsize=(12, 4), dpi=160) 56 | # 57 | # for i, index in enumerate(scores_plot.keys()): 58 | # # print(500 + 10 + 10 * (i // 5) + 1 + i % 5) 59 | # # plt.subplot(int(500 + 10 + 10 * (i // 5) + 1 + i % 5)) 60 | # print(methods) 61 | # print(scores_plot[index]) 62 | # plt.bar(methods, scores_plot[index]) 63 | # plt.title(index) 64 | # plt.savefig(os.path.join("./pics", "%s" % model_flag, "%s.png" % index)) 65 | # # plt.show() 66 | 67 | 68 | if __name__ == '__main__': 69 | model_flag = str(sys.argv[1]) 70 | all_csvs = {} 71 | info_path = {} 72 | all_csv_name = [] 73 | root = os.path.join("./csvs/", model_flag) 74 | for i in os.walk(root): 75 | all_csv_name = i[2] 76 | break 77 | for name in all_csv_name: 78 | if name.endswith("csv"): 79 | info_path[name] = os.path.join(root, name) 80 | for key in info_path.keys(): 81 | 
get_info(info_path[key], key, all_csvs) 82 | sort_csv(all_csvs, model_flag) 83 | -------------------------------------------------------------------------------- /results/eps_and_theta/README.md: -------------------------------------------------------------------------------- 1 | Results for the eps and theta. 2 | 3 | #### eps 4 | * (python36) F:\python programes\APC-SNR\results\eps_and_theta> python analysis.py eps 5 | ```txt 6 | eps_0.100_theta_0.010.csv 7 | pesq_wb 2.668163 8 | time_sisnr_ 17.448977 9 | stoi_score 0.937472 10 | Name: mean, dtype: float64 11 | eps_0.500_theta_0.010.csv 12 | pesq_wb 2.673503 13 | time_sisnr_ 17.528962 14 | stoi_score 0.938223 15 | Name: mean, dtype: float64 16 | eps_1.000_theta_0.010.csv 17 | pesq_wb 2.717651 18 | time_sisnr_ 17.598689 19 | stoi_score 0.938980 20 | Name: mean, dtype: float64 21 | eps_1.500_theta_0.010.csv 22 | pesq_wb 2.673276 23 | time_sisnr_ 17.598803 24 | stoi_score 0.938291 25 | Name: mean, dtype: float64 26 | eps_2.000_theta_0.010.csv 27 | pesq_wb 2.669381 28 | time_sisnr_ 17.562867 29 | stoi_score 0.938318 30 | Name: mean, dtype: float64 31 | ``` 32 | ![](https://github.com/wangtianrui/APC-SNR/blob/master/results/eps_and_theta/pics/eps.png) 33 | 34 | #### theta 35 | * (python36) F:\python programes\APC-SNR\results\eps_and_theta> python analysis.py theta 36 | 37 | ```txt 38 | eps_1.000_theta_0.0001.csv 39 | pesq_wb 2.659480 40 | time_sisnr_ 17.458036 41 | stoi_score 0.936936 42 | Name: mean, dtype: float64 43 | eps_1.000_theta_0.0005.csv 44 | pesq_wb 2.663100 45 | time_sisnr_ 17.478115 46 | stoi_score 0.937638 47 | Name: mean, dtype: float64 48 | eps_1.000_theta_0.001.csv 49 | pesq_wb 2.688328 50 | time_sisnr_ 17.602062 51 | stoi_score 0.938702 52 | Name: mean, dtype: float64 53 | eps_1.000_theta_0.005.csv 54 | pesq_wb 2.675465 55 | time_sisnr_ 17.588495 56 | stoi_score 0.938699 57 | Name: mean, dtype: float64 58 | eps_1.000_theta_0.010.csv 59 | pesq_wb 2.717651 60 | time_sisnr_ 17.598689 61 | stoi_score 
0.938980 62 | Name: mean, dtype: float64 63 | eps_1.000_theta_0.050.csv 64 | pesq_wb 2.674032 65 | time_sisnr_ 17.570935 66 | stoi_score 0.938317 67 | Name: mean, dtype: float64 68 | eps_1.000_theta_0.100.csv 69 | pesq_wb 2.679813 70 | time_sisnr_ 17.585749 71 | stoi_score 0.938391 72 | Name: mean, dtype: float64 73 | eps_1.000_theta_0.500.csv 74 | pesq_wb 2.652406 75 | time_sisnr_ 17.538733 76 | stoi_score 0.937562 77 | Name: mean, dtype: float64 78 | ``` 79 | ![](https://github.com/wangtianrui/APC-SNR/blob/master/results/eps_and_theta/pics/theta.png) 80 | 81 | 82 | -------------------------------------------------------------------------------- /results/eps_and_theta/analysis.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Author:WangTianRui 3 | # Date :2021-08-24 14:15 4 | import pandas as pd 5 | import numpy as np 6 | import os, sys 7 | import matplotlib.pyplot as plt 8 | 9 | if __name__ == '__main__': 10 | flag = str(sys.argv[1]) 11 | eps_root = r"./csvs/%s" % flag 12 | pesq = [] 13 | sisnr = [] 14 | stoi = [] 15 | xs = [] 16 | for i in os.walk(eps_root): 17 | for index, name in enumerate(i[2]): 18 | if flag == "eps": 19 | eps = str(name.split("_")[1]) 20 | xs.append(eps) 21 | else: 22 | theta = str(name.split("_")[3].split(".csv")[0]) 23 | 24 | xs.append(theta) 25 | df = pd.read_csv(os.path.join(eps_root, name)) 26 | print(name) 27 | df.time_sisnr_ = abs(df.time_sisnr_) 28 | print(df.describe().loc["mean"]) 29 | pesq.append(float(df.describe().loc["mean"]["pesq_wb"])) 30 | sisnr.append(abs(float(df.describe().loc["mean"]["time_sisnr_"]))) 31 | stoi.append(float(df.describe().loc["mean"]["stoi_score"])) 32 | df.to_csv(os.path.join(eps_root, name)) # change SI-SNR to positive number 33 | break 34 | print(pesq) 35 | print(sisnr) 36 | print(stoi) 37 | print(xs) 38 | plt.figure(figsize=(6, 8), dpi=160) 39 | plt.subplot(311) 40 | plt.plot(xs, pesq) 41 | # plt.title("PESQ") 42 | # plt.show() 43 | 
plt.subplot(312) 44 | plt.plot(xs, sisnr) 45 | # plt.title("SI-SNR") 46 | # plt.show() 47 | plt.subplot(313) 48 | plt.plot(xs, stoi) 49 | # plt.title("STOI") 50 | plt.savefig("./pics/%s.png" % flag) 51 | plt.show() 52 | -------------------------------------------------------------------------------- /results/eps_and_theta/pics/eps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangtianrui/APC-SNR/938ff22bc7e10ccc38021defc1bdb9eb96031eba/results/eps_and_theta/pics/eps.png -------------------------------------------------------------------------------- /results/eps_and_theta/pics/theta.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangtianrui/APC-SNR/938ff22bc7e10ccc38021defc1bdb9eb96031eba/results/eps_and_theta/pics/theta.png --------------------------------------------------------------------------------