├── .gitattributes ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── pystoi ├── __init__.py ├── stoi.py └── utils.py ├── setup.cfg ├── setup.py └── tests ├── matlab ├── applyOBM.m ├── estoi.m ├── removeSilentFrames.m ├── stdft.m ├── stoi.m └── thirdoct.m ├── octave ├── applyOBM.m ├── estoi.m ├── ml_hanning.m ├── removeSilentFrames.m ├── stdft.m ├── stoi.m └── thirdoct.m ├── test_matlab_python.py ├── test_overlap_and_add.py ├── test_python_octave.py ├── test_stoi.py ├── test_stoi_octave.py ├── test_utils.py └── test_utils_octave.py /.gitattributes: -------------------------------------------------------------------------------- 1 | tests/* linguist-vendored=true 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | tests/__pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | # Octave workspace file 105 | octave-workspace 106 | 107 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Pariente Manuel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Python implementation of STOI 2 | 3 | Implementation of the classical and extended Short Term Objective Intelligibility measures 4 | 5 | Intelligibility measure which is highly correlated with the intelligibility of degraded speech signals, e.g., due to additive noise, single/multi-channel noise reduction, binary masking and vocoded speech as in CI simulations. The STOI-measure is intrusive, i.e., a function of the clean and degraded speech signals. STOI may be a good alternative to the speech intelligibility index (SII) or the speech transmission index (STI), when you are interested in the effect of nonlinear processing to noisy speech, e.g., noise reduction, binary masking algorithms, on speech intelligibility. 
6 | Description taken from [Cees Taal's website](http://www.ceestaal.nl/code/) 7 | 8 | 9 | ### Install 10 | 11 | `pip install pystoi` or 12 | `pip3 install pystoi` 13 | 14 | ### Usage 15 | ``` 16 | import soundfile as sf 17 | from pystoi import stoi 18 | 19 | clean, fs = sf.read('path/to/clean/audio') 20 | denoised, fs = sf.read('path/to/denoised/audio') 21 | 22 | # Clean and denoised should have the same length, and be 1D 23 | d = stoi(clean, denoised, fs, extended=False) 24 | ``` 25 | 26 | ### Running the Octave tests 27 | 28 | ```bash 29 | sudo apt update 30 | sudo apt install octave octave-signal 31 | pip install oct2py 32 | ``` 33 | 34 | ```bash 35 | python -m pytest tests/test_python_octave.py 36 | python -m pytest tests/test_stoi_octave.py 37 | ``` 38 | 39 | ### Matlab code & Testing 40 | 41 | All the Matlab code in this repo is taken from or adapted from the code available [here](http://www.ceestaal.nl/code/) (STOI – Short-Time Objective Intelligibility Measure – ) written by Cees Taal. 42 | 43 | Thanks to Cees Taal who open-sourced his Matlab implementation and enabled thorough testing of this Python code. 44 | 45 | If you want to run the tests, you will need Matlab, `matlab.engine` (install instructions [here](https://fr.mathworks.com/help/matlab/matlab_external/install-the-matlab-engine-for-python.html)) and `matlab_wrapper` (install with `pip install matlab_wrapper`). 46 | The tests can only be run under Python 2.7 as `matlab.engine` and `matlab_wrapper` are only compatible with Python 2.7. 47 | Tests are passing at relative and absolute tolerance of `1e-3`, which is enough for the considered application (all the variability is coming from the resampling method when signals are not natively sampled at 10kHz). 48 | 49 | Very big thanks to @gauss256 who translated all the matlab scripts to Octave, and wrote all the tests for it! 
50 | 51 | ### Contribute 52 | 53 | Any contributions are welcome~, especially to improve the execution speed of the code~ (thank you Przemek Pobrotyn for a 4x speed-up!) : 54 | 55 | * ~Improve the resampling method to match Matlab's resampling in `tests/`.~ This can be considered a solved issue thanks to @gauss256 ! 56 | * Write tests for Python 3 (with [`transplant`](https://github.com/bastibe/transplant) for example) 57 | 58 | 59 | ### References 60 | * [1] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time 61 | Objective Intelligibility Measure for Time-Frequency Weighted Noisy Speech', 62 | ICASSP 2010, Texas, Dallas. 63 | * [2] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for 64 | Intelligibility Prediction of Time-Frequency Weighted Noisy Speech', 65 | IEEE Transactions on Audio, Speech, and Language Processing, 2011. 66 | * [3] J. Jensen and C. H. Taal, 'An Algorithm for Predicting the 67 | Intelligibility of Speech Masked by Modulated Noise Maskers', 68 | IEEE Transactions on Audio, Speech and Language Processing, 2016. 69 | -------------------------------------------------------------------------------- /pystoi/__init__.py: -------------------------------------------------------------------------------- 1 | from .stoi import stoi 2 | 3 | __version__ = '0.4.1' 4 | -------------------------------------------------------------------------------- /pystoi/stoi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import warnings 3 | from . import utils 4 | 5 | # Constant definition 6 | FS = 10000 # Sampling frequency 7 | N_FRAME = 256 # Window support 8 | NFFT = 512 # FFT Size 9 | NUMBAND = 15 # Number of 1/3 octave bands 10 | MINFREQ = 150 # Center frequency of 1st 1/3 octave band (Hz) 11 | OBM, CF = utils.thirdoct(FS, NFFT, NUMBAND, MINFREQ) # Get 1/3 octave band matrix 12 | N = 30 # N. frames for intermediate intelligibility 13 | BETA = -15. 
# Lower SDR bound 14 | DYN_RANGE = 40 # Speech dynamic range 15 | 16 | 17 | def stoi(x, y, fs_sig, extended=False): 18 | """ Short term objective intelligibility 19 | Computes the STOI (See [1][2]) of a denoised signal compared to a clean 20 | signal. The output is expected to have a monotonic relation with the 21 | subjective speech-intelligibility, where a higher score denotes better 22 | speech intelligibility. 23 | 24 | # Arguments 25 | x (np.ndarray): clean original speech 26 | y (np.ndarray): denoised speech 27 | fs_sig (int): sampling rate of x and y 28 | extended (bool): Boolean, whether to use the extended STOI described in [3] 29 | 30 | # Returns 31 | float: Short time objective intelligibility measure between clean and 32 | denoised speech (1e-5, with a RuntimeWarning, if fewer than N=30 STFT frames remain after silent-frame removal) 33 | 34 | # Raises 35 | Exception : if x and y have different shapes 36 | 37 | # Reference 38 | [1] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time 39 | Objective Intelligibility Measure for Time-Frequency Weighted Noisy 40 | Speech', ICASSP 2010, Texas, Dallas. 41 | [2] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for 42 | Intelligibility Prediction of Time-Frequency Weighted Noisy Speech', 43 | IEEE Transactions on Audio, Speech, and Language Processing, 2011. 44 | [3] Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the 45 | Intelligibility of Speech Masked by Modulated Noise Maskers', 46 | IEEE Transactions on Audio, Speech and Language Processing, 2016. 
47 | """ 48 | if x.shape != y.shape: 49 | raise Exception('x and y should have the same length,' + 50 | 'found {} and {}'.format(x.shape, y.shape)) 51 | 52 | # Resample if fs_sig is different from FS 53 | if fs_sig != FS: 54 | x = utils.resample_oct(x, FS, fs_sig) 55 | y = utils.resample_oct(y, FS, fs_sig) 56 | 57 | # Remove silent frames 58 | x, y = utils.remove_silent_frames(x, y, DYN_RANGE, N_FRAME, int(N_FRAME/2)) 59 | 60 | # Take STFT 61 | x_spec = utils.stft(x, N_FRAME, NFFT, overlap=2).transpose() 62 | y_spec = utils.stft(y, N_FRAME, NFFT, overlap=2).transpose() 63 | 64 | # Ensure at least 30 frames for intermediate intelligibility 65 | if x_spec.shape[-1] < N: 66 | warnings.warn('Not enough STFT frames to compute intermediate ' 67 | 'intelligibility measure after removing silent ' 68 | 'frames. Returning 1e-5. Please check you wav files', 69 | RuntimeWarning) 70 | return 1e-5 71 | 72 | # Apply OB matrix to the spectrograms as in Eq. (1) 73 | x_tob = np.sqrt(np.matmul(OBM, np.square(np.abs(x_spec)))) 74 | y_tob = np.sqrt(np.matmul(OBM, np.square(np.abs(y_spec)))) 75 | 76 | # Take sliding segments of N consecutive frames of x_tob, y_tob 77 | x_segments = np.array( 78 | [x_tob[:, m - N:m] for m in range(N, x_tob.shape[1] + 1)]) 79 | y_segments = np.array( 80 | [y_tob[:, m - N:m] for m in range(N, x_tob.shape[1] + 1)]) 81 | 82 | if extended: 83 | x_n = utils.row_col_normalize(x_segments) 84 | y_n = utils.row_col_normalize(y_segments) 85 | return np.sum(x_n * y_n / N) / x_n.shape[0] 86 | 87 | else: 88 | # Find normalization constants and normalize 89 | normalization_consts = ( 90 | np.linalg.norm(x_segments, axis=2, keepdims=True) / 91 | (np.linalg.norm(y_segments, axis=2, keepdims=True) + utils.EPS)) 92 | y_segments_normalized = y_segments * normalization_consts 93 | 94 | # Clip as described in [1] 95 | clip_value = 10 ** (-BETA / 20) 96 | y_primes = np.minimum( 97 | y_segments_normalized, x_segments * (1 + clip_value)) 98 | 99 | # Subtract mean vectors 100 | 
y_primes = y_primes - 
np.mean(y_primes, axis=2, keepdims=True) 101 | x_segments = x_segments - np.mean(x_segments, axis=2, keepdims=True) 102 | 103 | # Divide by their norms 104 | y_primes /= (np.linalg.norm(y_primes, axis=2, keepdims=True) + utils.EPS) 105 | x_segments /= (np.linalg.norm(x_segments, axis=2, keepdims=True) + utils.EPS) 106 | # Find a matrix with entries summing to sum of correlations of vectors 107 | correlations_components = y_primes * x_segments 108 | 109 | # J, M as in [1], eq.6 110 | J = x_segments.shape[0] 111 | M = x_segments.shape[1] 112 | 113 | # Find the mean of all correlations 114 | d = np.sum(correlations_components) / (J * M) 115 | return d 116 | -------------------------------------------------------------------------------- /pystoi/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import numpy as np 4 | from scipy.signal import resample_poly 5 | 6 | EPS = np.finfo("float").eps 7 | 8 | 9 | def _resample_window_oct(p, q): 10 | """Port of Octave code to Python: Kaiser-windowed sinc antialiasing filter for resampling by p/q""" 11 | 12 | gcd = np.gcd(p, q) 13 | if gcd > 1: 14 | p /= gcd 15 | q /= gcd 16 | 17 | # Properties of the antialiasing filter 18 | log10_rejection = -3.0 19 | stopband_cutoff_f = 1.0 / (2 * max(p, q)) 20 | roll_off_width = stopband_cutoff_f / 10 21 | 22 | # Determine filter length 23 | rejection_dB = -20 * log10_rejection 24 | L = np.ceil((rejection_dB - 8) / (28.714 * roll_off_width)) 25 | 26 | # Ideal sinc filter 27 | t = np.arange(-L, L + 1) 28 | ideal_filter = 2 * p * stopband_cutoff_f \ 29 | * np.sinc(2 * stopband_cutoff_f * t) 30 | 31 | # Determine parameter of Kaiser window 32 | if (rejection_dB >= 21) and (rejection_dB <= 50): 33 | beta = 0.5842 * (rejection_dB - 21)**0.4 \ 34 | + 0.07886 * (rejection_dB - 21) 35 | elif rejection_dB > 50: 36 | beta = 0.1102 * (rejection_dB - 8.7) 37 | else: 38 | beta = 0.0 39 | 40 | # Apodize ideal filter response 41 | h = np.kaiser(2 * 
L + 1, beta) * ideal_filter 42 | 43 | return h 44 | 45 
| 46 | def resample_oct(x, p, q): 47 | """Resampler that is compatible with Octave""" 48 | h = _resample_window_oct(p, q) 49 | window = h / np.sum(h) 50 | return resample_poly(x, p, q, window=window) 51 | 52 | 53 | @functools.lru_cache(maxsize=None) 54 | def thirdoct(fs, nfft, num_bands, min_freq): 55 | """ Returns the 1/3 octave band matrix and its center frequencies 56 | # Arguments : 57 | fs : sampling rate 58 | nfft : FFT size 59 | num_bands : number of 1/3 octave bands 60 | min_freq : center frequency of the lowest 1/3 octave band 61 | # Returns : 62 | obm : Octave Band Matrix 63 | cf : center frequencies 64 | """ 65 | f = np.linspace(0, fs, nfft + 1) 66 | f = f[: int(nfft / 2) + 1] 67 | k = np.array(range(num_bands)).astype(float) 68 | cf = np.power(2.0 ** (1.0 / 3), k) * min_freq 69 | freq_low = min_freq * np.power(2.0, (2 * k - 1) / 6) 70 | freq_high = min_freq * np.power(2.0, (2 * k + 1) / 6) 71 | obm = np.zeros((num_bands, len(f))) # to be verified 72 | 73 | for i in range(len(cf)): 74 | # Match 1/3 oct band freq with fft frequency bin 75 | f_bin = np.argmin(np.square(f - freq_low[i])) 76 | freq_low[i] = f[f_bin] 77 | fl_ii = f_bin 78 | f_bin = np.argmin(np.square(f - freq_high[i])) 79 | freq_high[i] = f[f_bin] 80 | fh_ii = f_bin 81 | # Assign to the octave band matrix 82 | obm[i, fl_ii:fh_ii] = 1 83 | return obm, cf 84 | 85 | 86 | def stft(x, win_size, fft_size, overlap=4): 87 | """ Short-time Fourier transform for real 1-D inputs 88 | # Arguments 89 | x : 1D array, the waveform 90 | win_size : integer, the size of the window and the signal frames 91 | fft_size : integer, the size of the fft in samples (zero-padding or not) 92 | overlap: integer, number of hops per window, i.e. hop = int(win_size / overlap) 93 | # Returns 94 | stft_out : 2D complex array, the STFT of x. 
95 | """ 96 | hop = int(win_size / overlap) 97 | w = np.hanning(win_size + 2)[1: -1] # = matlab.hanning(win_size) 98 | stft_out = np.array([np.fft.rfft(w * x[i:i + win_size], n=fft_size) 99 | for i in range(0, len(x) - win_size, hop)]) 100 | return stft_out 101 | 102 | 103 | def _overlap_and_add(x_frames, hop): 104 | num_frames, framelen = x_frames.shape 105 | # Compute the number of segments, per frame. 106 | segments = -(-framelen // hop) # Divide and round up. 107 | 108 | # Pad the framelen dimension to segments * hop and add n=segments frames 109 | signal = np.pad(x_frames, ((0, segments), (0, segments * hop - framelen))) 110 | 111 | # Reshape to a 3D tensor, splitting the framelen dimension in two 112 | signal = signal.reshape((num_frames + segments, segments, hop)) 113 | # Transpose dimensions so that signal.shape = (segments, frame+segments, hop) 114 | signal = np.transpose(signal, [1, 0, 2]) 115 | # Reshape so that signal.shape = (segments * (frame+segments), hop) 116 | signal = signal.reshape((-1, hop)) 117 | 118 | # Now behold the magic!! 
Remove the last n=segments elements from the first axis 119 | signal = signal[:-segments] 120 | # Reshape to (segments, frame+segments-1, hop) 121 | signal = signal.reshape((segments, num_frames + segments - 1, hop)) 122 | # This has introduced a shift by one in all rows 123 | 124 | # Now, reduce over the columns and flatten the array to achieve the result 125 | signal = np.sum(signal, axis=0) 126 | end = (len(x_frames) - 1) * hop + framelen 127 | signal = signal.reshape(-1)[:end] 128 | return signal 129 | 130 | 131 | def remove_silent_frames(x, y, dyn_range, framelen, hop): 132 | """ Remove silent frames of x and y based on x 133 | A frame is excluded if its energy is lower than max(energy) - dyn_range 134 | The frame exclusion is based solely on x, the clean speech signal 135 | # Arguments : 136 | x : array, original speech wav file 137 | y : array, denoised speech wav file 138 | dyn_range : Energy range to determine which frame is silent 139 | framelen : Window size for energy evaluation 140 | hop : Hop size for energy evaluation 141 | # Returns : 142 | x without the silent frames 143 | y without the silent frames (aligned to x) 144 | """ 145 | # Compute Mask 146 | w = np.hanning(framelen + 2)[1:-1] 147 | 148 | x_frames = np.array( 149 | [w * x[i:i + framelen] for i in range(0, len(x) - framelen, hop)]) 150 | y_frames = np.array( 151 | [w * y[i:i + framelen] for i in range(0, len(x) - framelen, hop)]) 152 | 153 | # Compute energies in dB 154 | x_energies = 20 * np.log10(np.linalg.norm(x_frames, axis=1) + EPS) 155 | 156 | # Boolean mask, True for frames whose energy is less than dyn_range dB 157 | # below the maximum clean-speech frame energy (these frames are kept) 158 | mask = (np.max(x_energies) - dyn_range - x_energies) < 0 159 | 160 | # Remove silent frames by masking 161 | x_frames = x_frames[mask] 162 | y_frames = y_frames[mask] 163 | 164 | x_sil = _overlap_and_add(x_frames, hop) 165 | y_sil = _overlap_and_add(y_frames, hop) 166 | 167 | return x_sil, y_sil 168 | 169 | 
170 | def 
vect_two_norm(x, axis=-1): 171 | """ Returns the sum of squares (squared 2-norm) of x along `axis`, keeping dimensions """ 172 | return np.sum(np.square(x), axis=axis, keepdims=True) 173 | 174 | 175 | def row_col_normalize(x): 176 | """ Row and column mean and variance normalize an array of 2D segments """ 177 | # Row mean and variance normalization 178 | x_normed = x + EPS * np.random.standard_normal(x.shape) 179 | x_normed -= np.mean(x_normed, axis=-1, keepdims=True) 180 | x_inv = 1. / np.sqrt(vect_two_norm(x_normed)) 181 | x_diags = np.array( 182 | [np.diag(x_inv[i].reshape(-1)) for i in range(x_inv.shape[0])]) 183 | x_normed = np.matmul(x_diags, x_normed) 184 | # Column mean and variance normalization 185 | x_normed += + EPS * np.random.standard_normal(x_normed.shape) 186 | x_normed -= np.mean(x_normed, axis=1, keepdims=True) 187 | x_inv = 1. / np.sqrt(vect_two_norm(x_normed, axis=1)) 188 | x_diags = np.array( 189 | [np.diag(x_inv[i].reshape(-1)) for i in range(x_inv.shape[0])]) 190 | x_normed = np.matmul(x_normed, x_diags) 191 | return x_normed 192 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file=README.md 3 | 4 | [bdist_wheel] 5 | universal=1 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from setuptools import find_packages 3 | 4 | with open("README.md", encoding='utf-8') as fh: 5 | long_description = fh.read() 6 | 7 | setup( 8 | name='pystoi', 9 | version='0.4.1', 10 | description='Computes Short Term Objective Intelligibility measure', 11 | author='Manuel Pariente', 12 | author_email='pariente.mnl@gmail.com', 13 | url='https://github.com/mpariente/pystoi', 14 | long_description=long_description, 15 | 
long_description_content_type="text/markdown", 16 | license='MIT', 17 | install_requires=['numpy', 'scipy'], 18 | classifiers=[ 19 | 'Development Status :: 4 - Beta', 20 | 'Intended Audience :: Science/Research', 21 | 'License :: OSI Approved :: MIT License', 22 | 'Programming Language :: Python :: 2', 23 | 'Programming Language :: Python :: 2.7', 24 | 'Programming Language :: Python :: 3' 25 | ], 26 | packages=find_packages() 27 | ) 28 | -------------------------------------------------------------------------------- /tests/matlab/applyOBM.m: -------------------------------------------------------------------------------- 1 | function X = applyOBM(x, OBM, N_frame, NFFT, NUMBAND) 2 | 3 | x_hat = stdft(x, N_frame, N_frame/2, NFFT); 4 | x_hat = x_hat(:, 1:(NFFT/2+1)).'; 5 | 6 | X = zeros(NUMBAND, size(x_hat, 2)); 7 | for i = 1:size(x_hat, 2) 8 | X(:, i) = sqrt(OBM*abs(x_hat(:, i)).^2); 9 | end 10 | -------------------------------------------------------------------------------- /tests/matlab/estoi.m: -------------------------------------------------------------------------------- 1 | function d = estoi(x, y, fs_signal) 2 | % d = estoi(x, y, fs_signal) returns the output of the extended short-time 3 | % objective intelligibility (ESTOI) predictor. 4 | % 5 | % Implementation of the Extended Short-Time Objective 6 | % Intelligibility (ESTOI) predictor, described in Jesper Jensen and 7 | % Cees H. Taal, "An Algorithm for Predicting the Intelligibility of 8 | % Speech Masked by Modulated Noise Maskers," IEEE Transactions on 9 | % Audio, Speech and Language Processing, 2016. 10 | % 11 | % Input: 12 | % x: clean reference time domain signal 13 | % y: noisy/processed time domain signal 14 | % fs_signal: sampling rate [Hz] 15 | % 16 | % Output: 17 | % d: intelligibility index 18 | % 19 | % 20 | % Copyright 2016: Aalborg University, Section for Signal and Information Processing. 21 | % The software is free for non-commercial use. 22 | % The software comes WITHOUT ANY WARRANTY. 
23 | 24 | 25 | if length(x)~=length(y) 26 | error('x and y should have the same length'); 27 | end 28 | 29 | % initialization 30 | x = x(:); % clean speech column vector 31 | y = y(:); % processed speech column vector 32 | 33 | fs = 10000; % sample rate of proposed intelligibility measure 34 | N_frame = 256; % window support 35 | K = 512; % FFT size 36 | J = 15; % Number of 1/3 octave bands 37 | mn = 150; % Center frequency of first 1/3 octave band in Hz. 38 | [H,fc_thirdoct] = thirdoct(fs, K, J, mn); % Get 1/3 octave band matrix 39 | N = 30; % Number of frames for intermediate intelligibility measure 40 | dyn_range = 40; % speech dynamic range 41 | 42 | % resample signals if other samplerate is used than fs 43 | if fs_signal ~= fs 44 | x = resample(x, fs, fs_signal); 45 | y = resample(y, fs, fs_signal); 46 | end 47 | 48 | % remove silent frames 49 | [x y] = removeSilentFrames(x, y, dyn_range, N_frame, N_frame/2); 50 | 51 | % apply 1/3 octave band TF-decomposition 52 | x_hat = stdft(x, N_frame, N_frame/2, K); % apply short-time DFT to clean speech 53 | y_hat = stdft(y, N_frame, N_frame/2, K); % apply short-time DFT to processed speech 54 | 55 | 56 | x_hat = x_hat(:, 1:(K/2+1)).'; % take clean single-sided spectrum 57 | y_hat = y_hat(:, 1:(K/2+1)).'; % take processed single-sided spectrum 58 | 59 | X = zeros(J, size(x_hat, 2)); % init memory for clean speech 1/3 octave band TF-representation 60 | Y = zeros(J, size(y_hat, 2)); % init memory for processed speech 1/3 octave band TF-representation 61 | 62 | for i = 1:size(x_hat, 2) 63 | X(:, i) = sqrt(H*abs(x_hat(:, i)).^2); % apply 1/3 octave band filtering 64 | Y(:, i) = sqrt(H*abs(y_hat(:, i)).^2); 65 | end 66 | 67 | % loop all segments of length N and obtain intermediate intelligibility measure for each 68 | d1 = zeros(length(N:size(X, 2)),1); % init memory for intermediate intelligibility measure 69 | for m=N:size(X,2) 70 | X_seg = X(:, (m-N+1):m); % region of length N with clean TF-units for all j 71 | Y_seg = 
Y(:, (m-N+1):m); % region of length N with processed TF-units for all j 72 | X_seg = X_seg + eps*randn(size(X_seg)); % to avoid divide by zero 73 | Y_seg = Y_seg + eps*randn(size(Y_seg)); % to avoid divide by zero 74 | 75 | %% first normalize rows (to give \bar{S}_m) 76 | XX = X_seg - mean(X_seg.').'*ones(1,N); % normalize rows to zero mean 77 | YY = Y_seg - mean(Y_seg.').'*ones(1,N); % normalize rows to zero mean 78 | 79 | YY = diag(1./sqrt(diag(YY*YY')))*YY; % normalize rows to unit length 80 | XX = diag(1./sqrt(diag(XX*XX')))*XX; % normalize rows to unit length 81 | 82 | XX = XX + eps*randn(size(XX)); % to avoid corr.div.by.0 83 | YY = YY + eps*randn(size(YY)); % to avoid corr.div.by.0 84 | 85 | %% then normalize columns (to give \check{S}_m) 86 | YYY = YY - ones(J,1)*mean(YY); % normalize cols to zero mean 87 | XXX = XX - ones(J,1)*mean(XX); % normalize cols to zero mean 88 | 89 | YYY = YYY*diag(1./sqrt(diag(YYY'*YYY))); % normalize cols to unit length 90 | XXX = XXX*diag(1./sqrt(diag(XXX'*XXX))); % normalize cols to unit length 91 | 92 | %compute average of col.correlations (by stacking cols) 93 | d1(m-N+1) = 1/N*XXX(:).'*YYY(:); 94 | end 95 | d = mean(d1); 96 | 97 | 98 | %% 99 | function [A cf] = thirdoct(fs, N_fft, numBands, mn) 100 | % [A CF] = THIRDOCT(FS, N_FFT, NUMBANDS, MN) returns 1/3 octave band matrix 101 | % inputs: 102 | % FS: samplerate 103 | % N_FFT: FFT size 104 | % NUMBANDS: number of bands 105 | % MN: center frequency of first 1/3 octave band 106 | % outputs: 107 | % A: octave band matrix 108 | % CF: center frequencies 109 | 110 | f = linspace(0, fs, N_fft+1); 111 | f = f(1:(N_fft/2+1)); 112 | k = 0:(numBands-1); 113 | cf = 2.^(k/3)*mn; 114 | fl = sqrt((2.^(k/3)*mn).*2.^((k-1)/3)*mn); 115 | fr = sqrt((2.^(k/3)*mn).*2.^((k+1)/3)*mn); 116 | A = zeros(numBands, length(f)); 117 | 118 | for i = 1:(length(cf)) 119 | [a b] = min((f-fl(i)).^2); 120 | fl(i) = f(b); 121 | fl_ii = b; 122 | 123 | [a b] = min((f-fr(i)).^2); 124 | fr(i) = f(b); 125 | fr_ii 
= b; 126 | A(i,fl_ii:(fr_ii-1)) = 1; 127 | end 128 | 129 | rnk = sum(A, 2); 130 | numBands = find((rnk(2:end)>=rnk(1:(end-1))) & (rnk(2:end)~=0)~=0, 1, 'last' )+1; 131 | A = A(1:numBands, :); 132 | cf = cf(1:numBands); 133 | 134 | %% 135 | function x_stdft = stdft(x, N, K, N_fft) 136 | % X_STDFT = X_STDFT(X, N, K, N_FFT) returns the short-time 137 | % hanning-windowed dft of X with frame-size N, overlap K and DFT size 138 | % N_FFT. The columns and rows of X_STDFT denote the frame-index and 139 | % dft-bin index, respectively. 140 | 141 | frames = 1:K:(length(x)-N); 142 | x_stdft = zeros(length(frames), N_fft); 143 | 144 | w = hanning(N); 145 | x = x(:); 146 | 147 | for i = 1:length(frames) 148 | ii = frames(i):(frames(i)+N-1); 149 | x_stdft(i, :) = fft(x(ii).*w, N_fft); 150 | end 151 | 152 | %% 153 | function [x_sil y_sil] = removeSilentFrames(x, y, range, N, K) 154 | % [X_SIL Y_SIL] = REMOVESILENTFRAMES(X, Y, RANGE, N, K) X and Y 155 | % are segmented with frame-length N and overlap K, where the maximum energy 156 | % of all frames of X is determined, say X_MAX. 
X_SIL and Y_SIL are the 157 | % reconstructed signals, excluding the frames, where the energy of a frame 158 | % of X is smaller than X_MAX-RANGE 159 | 160 | x = x(:); 161 | y = y(:); 162 | 163 | frames = 1:K:(length(x)-N); 164 | w = hanning(N); 165 | msk = zeros(size(frames)); 166 | 167 | for j = 1:length(frames) 168 | jj = frames(j):(frames(j)+N-1); 169 | msk(j) = 20*log10(norm(x(jj).*w)./sqrt(N)); 170 | end 171 | 172 | msk = (msk-max(msk)+range)>0; 173 | count = 1; 174 | 175 | x_sil = zeros(size(x)); 176 | y_sil = zeros(size(y)); 177 | 178 | for j = 1:length(frames) 179 | if msk(j) 180 | jj_i = frames(j):(frames(j)+N-1); 181 | jj_o = frames(count):(frames(count)+N-1); 182 | x_sil(jj_o) = x_sil(jj_o) + x(jj_i).*w; 183 | y_sil(jj_o) = y_sil(jj_o) + y(jj_i).*w; 184 | count = count+1; 185 | end 186 | end 187 | 188 | x_sil = x_sil(1:jj_o(end)); 189 | y_sil = y_sil(1:jj_o(end)); 190 | 191 | 192 | -------------------------------------------------------------------------------- /tests/matlab/removeSilentFrames.m: -------------------------------------------------------------------------------- 1 | function [x_sil y_sil] = removeSilentFrames(x, y, range, N, K) 2 | 3 | x = x(:); 4 | y = y(:); 5 | 6 | frames = 1:K:(length(x)-N); 7 | w = hanning(N); 8 | msk = zeros(size(frames)); 9 | 10 | for j = 1:length(frames) 11 | jj = frames(j):(frames(j)+N-1); 12 | msk(j) = 20*log10(norm(x(jj).*w)./sqrt(N)); 13 | end 14 | 15 | msk = (msk-max(msk)+range)>0; 16 | count = 1; 17 | 18 | x_sil = zeros(size(x)); 19 | y_sil = zeros(size(y)); 20 | 21 | for j = 1:length(frames) 22 | if msk(j) 23 | jj_i = frames(j):(frames(j)+N-1); 24 | jj_o = frames(count):(frames(count)+N-1); 25 | x_sil(jj_o) = x_sil(jj_o) + x(jj_i).*w; 26 | y_sil(jj_o) = y_sil(jj_o) + y(jj_i).*w; 27 | count = count+1; 28 | end 29 | end 30 | 31 | x_sil = x_sil(1:jj_o(end)); 32 | y_sil = y_sil(1:jj_o(end)); 33 | -------------------------------------------------------------------------------- /tests/matlab/stdft.m: 
-------------------------------------------------------------------------------- 1 | function x_stdft = stdft(x, N, K, N_fft) 2 | 3 | frames = 1:K:(length(x)-N); 4 | x_stdft = zeros(length(frames), N_fft); 5 | 6 | w = hanning(N); 7 | x = x(:); 8 | 9 | for i = 1:length(frames) 10 | ii = frames(i):(frames(i)+N-1); 11 | x_stdft(i, :) = fft(x(ii).*w, N_fft); 12 | end 13 | -------------------------------------------------------------------------------- /tests/matlab/stoi.m: -------------------------------------------------------------------------------- 1 | function d = stoi(x, y, fs_signal) 2 | % d = stoi(x, y, fs_signal) returns the output of the short-time 3 | % objective intelligibility (STOI) measure described in [1, 2], where x 4 | % and y denote the clean and processed speech, respectively, with sample 5 | % rate fs_signal in Hz. The output d is expected to have a monotonic 6 | % relation with the subjective speech-intelligibility, where a higher d 7 | % denotes better intelligible speech. See [1, 2] for more details. 8 | % 9 | % References: 10 | % [1] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time 11 | % Objective Intelligibility Measure for Time-Frequency Weighted Noisy 12 | % Speech', ICASSP 2010, Texas, Dallas. 13 | % 14 | % [2] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for 15 | % Intelligibility Prediction of Time-Frequency Weighted Noisy Speech', 16 | % IEEE Transactions on Audio, Speech, and Language Processing, 2011. 17 | % 18 | % 19 | % Copyright 2009: Delft University of Technology, Signal & Information 20 | % Processing Lab. The software is free for non-commercial use. This program 21 | % comes WITHOUT ANY WARRANTY. 
22 | % 23 | % 24 | % 25 | % Updates: 26 | % 2011-04-26 Using the more efficient 'taa_corr' instead of 'corr' 27 | 28 | if length(x)~=length(y) 29 | error('x and y should have the same length'); 30 | end 31 | 32 | % initialization 33 | x = x(:); % clean speech column vector 34 | y = y(:); % processed speech column vector 35 | 36 | fs = 10000; % sample rate of proposed intelligibility measure 37 | N_frame = 256; % window support 38 | K = 512; % FFT size 39 | J = 15; % Number of 1/3 octave bands 40 | mn = 150; % Center frequency of first 1/3 octave band in Hz. 41 | H = thirdoct(fs, K, J, mn); % Get 1/3 octave band matrix 42 | N = 30; % Number of frames for intermediate intelligibility measure (Length analysis window) 43 | Beta = -15; % lower SDR-bound 44 | dyn_range = 40; % speech dynamic range 45 | 46 | % resample signals if other samplerate is used than fs 47 | if fs_signal ~= fs 48 | x = resample(x, fs, fs_signal); 49 | y = resample(y, fs, fs_signal); 50 | end 51 | 52 | % remove silent frames 53 | [x y] = removeSilentFrames(x, y, dyn_range, N_frame, N_frame/2); 54 | 55 | % apply 1/3 octave band TF-decomposition 56 | x_hat = stdft(x, N_frame, N_frame/2, K); % apply short-time DFT to clean speech 57 | y_hat = stdft(y, N_frame, N_frame/2, K); % apply short-time DFT to processed speech 58 | 59 | x_hat = x_hat(:, 1:(K/2+1)).'; % take clean single-sided spectrum 60 | y_hat = y_hat(:, 1:(K/2+1)).'; % take processed single-sided spectrum 61 | 62 | X = zeros(J, size(x_hat, 2)); % init memory for clean speech 1/3 octave band TF-representation 63 | Y = zeros(J, size(y_hat, 2)); % init memory for processed speech 1/3 octave band TF-representation 64 | 65 | for i = 1:size(x_hat, 2) 66 | X(:, i) = sqrt(H*abs(x_hat(:, i)).^2); % apply 1/3 octave bands as described in Eq.(1) [1] 67 | Y(:, i) = sqrt(H*abs(y_hat(:, i)).^2); 68 | end 69 | 70 | % loop al segments of length N and obtain intermediate intelligibility measure for all TF-regions 71 | d_interm = zeros(J, length(N:size(X, 
2))); % init memory for intermediate intelligibility measure 72 | c = 10^(-Beta/20); % constant for clipping procedure 73 | 74 | for m = N:size(X, 2) 75 | X_seg = X(:, (m-N+1):m); % region with length N of clean TF-units for all j 76 | Y_seg = Y(:, (m-N+1):m); % region with length N of processed TF-units for all j 77 | alpha = sqrt(sum(X_seg.^2, 2)./sum(Y_seg.^2, 2)); % obtain scale factor for normalizing processed TF-region for all j 78 | aY_seg = Y_seg.*repmat(alpha, [1 N]); % obtain \alpha*Y_j(n) from Eq.(2) [1] 79 | for j = 1:J 80 | Y_prime = min(aY_seg(j, :), X_seg(j, :)+X_seg(j, :)*c); % apply clipping from Eq.(3) 81 | d_interm(j, m-N+1) = taa_corr(X_seg(j, :).', Y_prime(:)); % obtain correlation coeffecient from Eq.(4) [1] 82 | end 83 | end 84 | 85 | d = mean(d_interm(:)); % combine all intermediate intelligibility measures as in Eq.(4) [1] 86 | 87 | %% 88 | function [A cf] = thirdoct(fs, N_fft, numBands, mn) 89 | % [A CF] = THIRDOCT(FS, N_FFT, NUMBANDS, MN) returns 1/3 octave band matrix 90 | % inputs: 91 | % FS: samplerate 92 | % N_FFT: FFT size 93 | % NUMBANDS: number of bands 94 | % MN: center frequency of first 1/3 octave band 95 | % outputs: 96 | % A: octave band matrix 97 | % CF: center frequencies 98 | 99 | f = linspace(0, fs, N_fft+1); 100 | f = f(1:(N_fft/2+1)); 101 | k = 0:(numBands-1); 102 | cf = 2.^(k/3)*mn; 103 | fl = sqrt((2.^(k/3)*mn).*2.^((k-1)/3)*mn); 104 | fr = sqrt((2.^(k/3)*mn).*2.^((k+1)/3)*mn); 105 | A = zeros(numBands, length(f)); 106 | 107 | for i = 1:(length(cf)) 108 | [a b] = min((f-fl(i)).^2); 109 | fl(i) = f(b); 110 | fl_ii = b; 111 | 112 | [a b] = min((f-fr(i)).^2); 113 | fr(i) = f(b); 114 | fr_ii = b; 115 | A(i,fl_ii:(fr_ii-1)) = 1; 116 | end 117 | 118 | rnk = sum(A, 2); 119 | numBands = find((rnk(2:end)>=rnk(1:(end-1))) & (rnk(2:end)~=0)~=0, 1, 'last' )+1; 120 | A = A(1:numBands, :); 121 | cf = cf(1:numBands); 122 | 123 | %% 124 | function x_stdft = stdft(x, N, K, N_fft) 125 | % X_STDFT = X_STDFT(X, N, K, N_FFT) returns the 
short-time 126 | % hanning-windowed dft of X with frame-size N, overlap K and DFT size 127 | % N_FFT. The rows and columns of X_STDFT denote the frame-index and 128 | % dft-bin index, respectively. 129 | 130 | frames = 1:K:(length(x)-N); 131 | x_stdft = zeros(length(frames), N_fft); 132 | 133 | w = hanning(N); 134 | x = x(:); 135 | 136 | for i = 1:length(frames) 137 | ii = frames(i):(frames(i)+N-1); 138 | x_stdft(i, :) = fft(x(ii).*w, N_fft); 139 | end 140 | 141 | %% 142 | function [x_sil y_sil] = removeSilentFrames(x, y, range, N, K) 143 | % [X_SIL Y_SIL] = REMOVESILENTFRAMES(X, Y, RANGE, N, K) X and Y 144 | % are segmented with frame-length N and overlap K, where the maximum energy 145 | % of all frames of X is determined, say X_MAX. X_SIL and Y_SIL are the 146 | % reconstructed signals, excluding the frames, where the energy of a frame 147 | % of X is smaller than X_MAX-RANGE 148 | 149 | x = x(:); 150 | y = y(:); 151 | 152 | frames = 1:K:(length(x)-N); 153 | w = hanning(N); 154 | msk = zeros(size(frames)); 155 | 156 | for j = 1:length(frames) 157 | jj = frames(j):(frames(j)+N-1); 158 | msk(j) = 20*log10(norm(x(jj).*w)./sqrt(N)); 159 | end 160 | 161 | msk = (msk-max(msk)+range)>0; 162 | count = 1; 163 | 164 | x_sil = zeros(size(x)); 165 | y_sil = zeros(size(y)); 166 | 167 | for j = 1:length(frames) 168 | if msk(j) 169 | jj_i = frames(j):(frames(j)+N-1); 170 | jj_o = frames(count):(frames(count)+N-1); 171 | x_sil(jj_o) = x_sil(jj_o) + x(jj_i).*w; 172 | y_sil(jj_o) = y_sil(jj_o) + y(jj_i).*w; 173 | count = count+1; 174 | end 175 | end 176 | 177 | x_sil = x_sil(1:jj_o(end)); 178 | y_sil = y_sil(1:jj_o(end)); 179 | 180 | %% 181 | function rho = taa_corr(x, y) 182 | % RHO = TAA_CORR(X, Y) Returns correlation coefficient between column 183 | % vectors x and y. Gives same results as 'corr' from statistics toolbox. 
184 | xn = x-mean(x); 185 | xn = xn/sqrt(sum(xn.^2)); 186 | yn = y-mean(y); 187 | yn = yn/sqrt(sum(yn.^2)); 188 | rho = sum(xn.*yn); 189 | -------------------------------------------------------------------------------- /tests/matlab/thirdoct.m: -------------------------------------------------------------------------------- 1 | function [A cf] = thirdoct(fs, N_fft, numBands, mn) 2 | 3 | f = linspace(0, fs, N_fft+1); 4 | f = f(1:(N_fft/2+1)); 5 | k = 0:(numBands-1); 6 | cf = 2.^(k/3)*mn; 7 | fl = sqrt((2.^(k/3)*mn).*2.^((k-1)/3)*mn); 8 | fr = sqrt((2.^(k/3)*mn).*2.^((k+1)/3)*mn); 9 | A = zeros(numBands, length(f)); 10 | 11 | for i = 1:(length(cf)) 12 | [a b] = min((f-fl(i)).^2); 13 | fl(i) = f(b); 14 | fl_ii = b; 15 | 16 | [a b] = min((f-fr(i)).^2); 17 | fr(i) = f(b); 18 | fr_ii = b; 19 | A(i,fl_ii:(fr_ii-1)) = 1; 20 | end 21 | 22 | rnk = sum(A, 2); 23 | numBands = find((rnk(2:end)>=rnk(1:(end-1))) & (rnk(2:end)~=0)~=0, 1, 'last' )+1; 24 | A = A(1:numBands, :); 25 | cf = cf(1:numBands); 26 | -------------------------------------------------------------------------------- /tests/octave/applyOBM.m: -------------------------------------------------------------------------------- 1 | function X = applyOBM(x, OBM, N_frame, NFFT, NUMBAND) 2 | 3 | x_hat = stdft(x, N_frame, N_frame/2, NFFT); 4 | x_hat = x_hat(:, 1:(NFFT/2+1)).'; 5 | 6 | X = zeros(NUMBAND, size(x_hat, 2)); 7 | for i = 1:size(x_hat, 2) 8 | X(:, i) = sqrt(OBM*abs(x_hat(:, i)).^2); 9 | end 10 | -------------------------------------------------------------------------------- /tests/octave/estoi.m: -------------------------------------------------------------------------------- 1 | function d = estoi(x, y, fs_signal) 2 | % d = estoi(x, y, fs_signal) returns the output of the extended short-time 3 | % objective intelligibility (ESTOI) predictor. 4 | % 5 | % Implementation of the Extended Short-Time Objective 6 | % Intelligibility (ESTOI) predictor, described in Jesper Jensen and 7 | % Cees H. 
Taal, "An Algorithm for Predicting the Intelligibility of 8 | % Speech Masked by Modulated Noise Maskers," IEEE Transactions on 9 | % Audio, Speech and Language Processing, 2016. 10 | % 11 | % Input: 12 | % x: clean reference time domain signal 13 | % y: noisy/processed time domain signal 14 | % fs_signal: sampling rate [Hz] 15 | % 16 | % Output: 17 | % d: intelligibility index 18 | % 19 | % 20 | % Copyright 2016: Aalborg University, Section for Signal and Information Processing. 21 | % The software is free for non-commercial use. 22 | % The software comes WITHOUT ANY WARRANTY. 23 | 24 | 25 | if length(x)~=length(y) 26 | error('x and y should have the same length'); 27 | end 28 | 29 | % initialization 30 | x = x(:); % clean speech column vector 31 | y = y(:); % processed speech column vector 32 | 33 | fs = 10000; % sample rate of proposed intelligibility measure 34 | N_frame = 256; % window support 35 | K = 512; % FFT size 36 | J = 15; % Number of 1/3 octave bands 37 | mn = 150; % Center frequency of first 1/3 octave band in Hz. 
38 | [H,fc_thirdoct] = thirdoct(fs, K, J, mn); % Get 1/3 octave band matrix 39 | N = 30; % Number of frames for intermediate intelligibility measure 40 | dyn_range = 40; % speech dynamic range 41 | 42 | % resample signals if other samplerate is used than fs 43 | if fs_signal ~= fs 44 | x = resample(x, fs, fs_signal); 45 | y = resample(y, fs, fs_signal); 46 | end 47 | 48 | % remove silent frames 49 | [x y] = removeSilentFrames(x, y, dyn_range, N_frame, N_frame/2); 50 | 51 | % apply 1/3 octave band TF-decomposition 52 | x_hat = stdft(x, N_frame, N_frame/2, K); % apply short-time DFT to clean speech 53 | y_hat = stdft(y, N_frame, N_frame/2, K); % apply short-time DFT to processed speech 54 | 55 | 56 | x_hat = x_hat(:, 1:(K/2+1)).'; % take clean single-sided spectrum 57 | y_hat = y_hat(:, 1:(K/2+1)).'; % take processed single-sided spectrum 58 | 59 | X = zeros(J, size(x_hat, 2)); % init memory for clean speech 1/3 octave band TF-representation 60 | Y = zeros(J, size(y_hat, 2)); % init memory for processed speech 1/3 octave band TF-representation 61 | 62 | for i = 1:size(x_hat, 2) 63 | X(:, i) = sqrt(H*abs(x_hat(:, i)).^2); % apply 1/3 octave band filtering 64 | Y(:, i) = sqrt(H*abs(y_hat(:, i)).^2); 65 | end 66 | 67 | % loop all segments of length N and obtain intermediate intelligibility measure for each 68 | d1 = zeros(length(N:size(X, 2)),1); % init memory for intermediate intelligibility measure 69 | for m=N:size(X,2) 70 | X_seg = X(:, (m-N+1):m); % region of length N with clean TF-units for all j 71 | Y_seg = Y(:, (m-N+1):m); % region of length N with processed TF-units for all j 72 | X_seg = X_seg + eps*randn(size(X_seg)); % to avoid divide by zero 73 | Y_seg = Y_seg + eps*randn(size(Y_seg)); % to avoid divide by zero 74 | 75 | %% first normalize rows (to give \bar{S}_m) 76 | XX = X_seg - mean(X_seg.').'*ones(1,N); % normalize rows to zero mean 77 | YY = Y_seg - mean(Y_seg.').'*ones(1,N); % normalize rows to zero mean 78 | 79 | YY = 
diag(1./sqrt(diag(YY*YY')))*YY; % normalize rows to unit length 80 | XX = diag(1./sqrt(diag(XX*XX')))*XX; % normalize rows to unit length 81 | 82 | XX = XX + eps*randn(size(XX)); % to avoid corr.div.by.0 83 | YY = YY + eps*randn(size(YY)); % to avoid corr.div.by.0 84 | 85 | %% then normalize columns (to give \check{S}_m) 86 | YYY = YY - ones(J,1)*mean(YY); % normalize cols to zero mean 87 | XXX = XX - ones(J,1)*mean(XX); % normalize cols to zero mean 88 | 89 | YYY = YYY*diag(1./sqrt(diag(YYY'*YYY))); % normalize cols to unit length 90 | XXX = XXX*diag(1./sqrt(diag(XXX'*XXX))); % normalize cols to unit length 91 | 92 | %compute average of col.correlations (by stacking cols) 93 | d1(m-N+1) = 1/N*XXX(:).'*YYY(:); 94 | end 95 | d = mean(d1); 96 | 97 | 98 | %% 99 | function [A cf] = thirdoct(fs, N_fft, numBands, mn) 100 | % [A CF] = THIRDOCT(FS, N_FFT, NUMBANDS, MN) returns 1/3 octave band matrix 101 | % inputs: 102 | % FS: samplerate 103 | % N_FFT: FFT size 104 | % NUMBANDS: number of bands 105 | % MN: center frequency of first 1/3 octave band 106 | % outputs: 107 | % A: octave band matrix 108 | % CF: center frequencies 109 | 110 | f = linspace(0, fs, N_fft+1); 111 | f = f(1:(N_fft/2+1)); 112 | k = 0:(numBands-1); 113 | cf = 2.^(k/3)*mn; 114 | fl = sqrt((2.^(k/3)*mn).*2.^((k-1)/3)*mn); 115 | fr = sqrt((2.^(k/3)*mn).*2.^((k+1)/3)*mn); 116 | A = zeros(numBands, length(f)); 117 | 118 | for i = 1:(length(cf)) 119 | [a b] = min((f-fl(i)).^2); 120 | fl(i) = f(b); 121 | fl_ii = b; 122 | 123 | [a b] = min((f-fr(i)).^2); 124 | fr(i) = f(b); 125 | fr_ii = b; 126 | A(i,fl_ii:(fr_ii-1)) = 1; 127 | end 128 | 129 | rnk = sum(A, 2); 130 | numBands = find((rnk(2:end)>=rnk(1:(end-1))) & (rnk(2:end)~=0)~=0, 1, 'last' )+1; 131 | A = A(1:numBands, :); 132 | cf = cf(1:numBands); 133 | 134 | %% 135 | function x_stdft = stdft(x, N, K, N_fft) 136 | % X_STDFT = X_STDFT(X, N, K, N_FFT) returns the short-time 137 | % hanning-windowed dft of X with frame-size N, overlap K and DFT size 138 | % 
N_FFT. The columns and rows of X_STDFT denote the frame-index and 139 | % dft-bin index, respectively. 140 | 141 | frames = 1:K:(length(x)-N); 142 | x_stdft = zeros(length(frames), N_fft); 143 | 144 | w = ml_hanning(N); 145 | x = x(:); 146 | 147 | for i = 1:length(frames) 148 | ii = frames(i):(frames(i)+N-1); 149 | x_stdft(i, :) = fft(x(ii).*w, N_fft); 150 | end 151 | 152 | %% 153 | function [x_sil y_sil] = removeSilentFrames(x, y, range, N, K) 154 | % [X_SIL Y_SIL] = REMOVESILENTFRAMES(X, Y, RANGE, N, K) X and Y 155 | % are segmented with frame-length N and overlap K, where the maximum energy 156 | % of all frames of X is determined, say X_MAX. X_SIL and Y_SIL are the 157 | % reconstructed signals, excluding the frames, where the energy of a frame 158 | % of X is smaller than X_MAX-RANGE 159 | 160 | x = x(:); 161 | y = y(:); 162 | 163 | frames = 1:K:(length(x)-N); 164 | w = ml_hanning(N); 165 | msk = zeros(size(frames)); 166 | 167 | for j = 1:length(frames) 168 | jj = frames(j):(frames(j)+N-1); 169 | msk(j) = 20*log10(norm(x(jj).*w)./sqrt(N)); 170 | end 171 | 172 | msk = (msk-max(msk)+range)>0; 173 | count = 1; 174 | 175 | x_sil = zeros(size(x)); 176 | y_sil = zeros(size(y)); 177 | 178 | for j = 1:length(frames) 179 | if msk(j) 180 | jj_i = frames(j):(frames(j)+N-1); 181 | jj_o = frames(count):(frames(count)+N-1); 182 | x_sil(jj_o) = x_sil(jj_o) + x(jj_i).*w; 183 | y_sil(jj_o) = y_sil(jj_o) + y(jj_i).*w; 184 | count = count+1; 185 | end 186 | end 187 | 188 | x_sil = x_sil(1:jj_o(end)); 189 | y_sil = y_sil(1:jj_o(end)); 190 | 191 | 192 | -------------------------------------------------------------------------------- /tests/octave/ml_hanning.m: -------------------------------------------------------------------------------- 1 | function w = ml_hanning(M) 2 | % Compute a Hann window compatible with the MATLAB `hanning` function 3 | w = .5 * (1 - cos(2 * pi * (1:M)'/double(M + 1))); 4 | end 5 | 6 | 
-------------------------------------------------------------------------------- /tests/octave/removeSilentFrames.m: -------------------------------------------------------------------------------- 1 | function [x_sil, y_sil] = removeSilentFrames(x, y, range, N, K) 2 | 3 | x = x(:); 4 | y = y(:); 5 | 6 | frames = 1:K:(length(x)-N); 7 | w = ml_hanning(N); 8 | msk = zeros(size(frames)); 9 | 10 | for j = 1:length(frames) 11 | jj = frames(j):(frames(j)+N-1); 12 | msk(j) = 20*log10(norm(x(jj).*w)./sqrt(N)); 13 | end 14 | 15 | msk = (msk-max(msk)+range)>0; 16 | count = 1; 17 | 18 | x_sil = zeros(size(x)); 19 | y_sil = zeros(size(y)); 20 | 21 | for j = 1:length(frames) 22 | if msk(j) 23 | jj_i = frames(j):(frames(j)+N-1); 24 | jj_o = frames(count):(frames(count)+N-1); 25 | x_sil(jj_o) = x_sil(jj_o) + x(jj_i).*w; 26 | y_sil(jj_o) = y_sil(jj_o) + y(jj_i).*w; 27 | count = count+1; 28 | end 29 | end 30 | 31 | x_sil = x_sil(1:jj_o(end)); 32 | y_sil = y_sil(1:jj_o(end)); 33 | -------------------------------------------------------------------------------- /tests/octave/stdft.m: -------------------------------------------------------------------------------- 1 | function x_stdft = stdft(x, N, K, N_fft) 2 | 3 | frames = 1:K:(length(x)-N); 4 | x_stdft = zeros(length(frames), N_fft); 5 | 6 | w = ml_hanning(N); 7 | x = x(:); 8 | 9 | for i = 1:length(frames) 10 | ii = frames(i):(frames(i)+N-1); 11 | x_stdft(i, :) = fft(x(ii).*w, N_fft); 12 | end 13 | -------------------------------------------------------------------------------- /tests/octave/stoi.m: -------------------------------------------------------------------------------- 1 | function d = stoi(x, y, fs_signal) 2 | % d = stoi(x, y, fs_signal) returns the output of the short-time 3 | % objective intelligibility (STOI) measure described in [1, 2], where x 4 | % and y denote the clean and processed speech, respectively, with sample 5 | % rate fs_signal in Hz. 
The output d is expected to have a monotonic 6 | % relation with the subjective speech-intelligibility, where a higher d 7 | % denotes better intelligible speech. See [1, 2] for more details. 8 | % 9 | % References: 10 | % [1] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time 11 | % Objective Intelligibility Measure for Time-Frequency Weighted Noisy 12 | % Speech', ICASSP 2010, Texas, Dallas. 13 | % 14 | % [2] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for 15 | % Intelligibility Prediction of Time-Frequency Weighted Noisy Speech', 16 | % IEEE Transactions on Audio, Speech, and Language Processing, 2011. 17 | % 18 | % 19 | % Copyright 2009: Delft University of Technology, Signal & Information 20 | % Processing Lab. The software is free for non-commercial use. This program 21 | % comes WITHOUT ANY WARRANTY. 22 | % 23 | % 24 | % 25 | % Updates: 26 | % 2011-04-26 Using the more efficient 'taa_corr' instead of 'corr' 27 | 28 | if length(x)~=length(y) 29 | error('x and y should have the same length'); 30 | end 31 | 32 | % initialization 33 | x = x(:); % clean speech column vector 34 | y = y(:); % processed speech column vector 35 | 36 | fs = 10000; % sample rate of proposed intelligibility measure 37 | N_frame = 256; % window support 38 | K = 512; % FFT size 39 | J = 15; % Number of 1/3 octave bands 40 | mn = 150; % Center frequency of first 1/3 octave band in Hz. 
41 | H = thirdoct(fs, K, J, mn); % Get 1/3 octave band matrix 42 | N = 30; % Number of frames for intermediate intelligibility measure (Length analysis window) 43 | Beta = -15; % lower SDR-bound 44 | dyn_range = 40; % speech dynamic range 45 | 46 | % resample signals if other samplerate is used than fs 47 | if fs_signal ~= fs 48 | x = resample(x, fs, fs_signal); 49 | y = resample(y, fs, fs_signal); 50 | end 51 | 52 | % remove silent frames 53 | [x y] = removeSilentFrames(x, y, dyn_range, N_frame, N_frame/2); 54 | 55 | % apply 1/3 octave band TF-decomposition 56 | x_hat = stdft(x, N_frame, N_frame/2, K); % apply short-time DFT to clean speech 57 | y_hat = stdft(y, N_frame, N_frame/2, K); % apply short-time DFT to processed speech 58 | 59 | x_hat = x_hat(:, 1:(K/2+1)).'; % take clean single-sided spectrum 60 | y_hat = y_hat(:, 1:(K/2+1)).'; % take processed single-sided spectrum 61 | 62 | X = zeros(J, size(x_hat, 2)); % init memory for clean speech 1/3 octave band TF-representation 63 | Y = zeros(J, size(y_hat, 2)); % init memory for processed speech 1/3 octave band TF-representation 64 | 65 | for i = 1:size(x_hat, 2) 66 | X(:, i) = sqrt(H*abs(x_hat(:, i)).^2); % apply 1/3 octave bands as described in Eq.(1) [1] 67 | Y(:, i) = sqrt(H*abs(y_hat(:, i)).^2); 68 | end 69 | 70 | % loop al segments of length N and obtain intermediate intelligibility measure for all TF-regions 71 | d_interm = zeros(J, length(N:size(X, 2))); % init memory for intermediate intelligibility measure 72 | c = 10^(-Beta/20); % constant for clipping procedure 73 | 74 | for m = N:size(X, 2) 75 | X_seg = X(:, (m-N+1):m); % region with length N of clean TF-units for all j 76 | Y_seg = Y(:, (m-N+1):m); % region with length N of processed TF-units for all j 77 | alpha = sqrt(sum(X_seg.^2, 2)./sum(Y_seg.^2, 2)); % obtain scale factor for normalizing processed TF-region for all j 78 | aY_seg = Y_seg.*repmat(alpha, [1 N]); % obtain \alpha*Y_j(n) from Eq.(2) [1] 79 | for j = 1:J 80 | Y_prime = min(aY_seg(j, 
:), X_seg(j, :)+X_seg(j, :)*c); % apply clipping from Eq.(3) 81 | d_interm(j, m-N+1) = taa_corr(X_seg(j, :).', Y_prime(:)); % obtain correlation coefficient from Eq.(4) [1] 82 | end 83 | end 84 | 85 | d = mean(d_interm(:)); % combine all intermediate intelligibility measures as in Eq.(4) [1] 86 | 87 | %% 88 | function [A cf] = thirdoct(fs, N_fft, numBands, mn) 89 | % [A CF] = THIRDOCT(FS, N_FFT, NUMBANDS, MN) returns 1/3 octave band matrix 90 | % inputs: 91 | % FS: samplerate 92 | % N_FFT: FFT size 93 | % NUMBANDS: number of bands 94 | % MN: center frequency of first 1/3 octave band 95 | % outputs: 96 | % A: octave band matrix 97 | % CF: center frequencies 98 | 99 | f = linspace(0, fs, N_fft+1); 100 | f = f(1:(N_fft/2+1)); 101 | k = 0:(numBands-1); 102 | cf = 2.^(k/3)*mn; 103 | fl = sqrt((2.^(k/3)*mn).*2.^((k-1)/3)*mn); 104 | fr = sqrt((2.^(k/3)*mn).*2.^((k+1)/3)*mn); 105 | A = zeros(numBands, length(f)); 106 | 107 | for i = 1:(length(cf)) 108 | [a b] = min((f-fl(i)).^2); 109 | fl(i) = f(b); 110 | fl_ii = b; 111 | 112 | [a b] = min((f-fr(i)).^2); 113 | fr(i) = f(b); 114 | fr_ii = b; 115 | A(i,fl_ii:(fr_ii-1)) = 1; 116 | end 117 | 118 | rnk = sum(A, 2); 119 | numBands = find((rnk(2:end)>=rnk(1:(end-1))) & (rnk(2:end)~=0)~=0, 1, 'last' )+1; 120 | A = A(1:numBands, :); 121 | cf = cf(1:numBands); 122 | 123 | %% 124 | function x_stdft = stdft(x, N, K, N_fft) 125 | % X_STDFT = STDFT(X, N, K, N_FFT) returns the short-time 126 | % hanning-windowed dft of X with frame-size N, overlap K and DFT size 127 | % N_FFT. The rows and columns of X_STDFT denote the frame-index and 128 | % dft-bin index, respectively. 
129 | 130 | frames = 1:K:(length(x)-N); 131 | x_stdft = zeros(length(frames), N_fft); 132 | 133 | w = ml_hanning(N); 134 | x = x(:); 135 | 136 | for i = 1:length(frames) 137 | ii = frames(i):(frames(i)+N-1); 138 | x_stdft(i, :) = fft(x(ii).*w, N_fft); 139 | end 140 | 141 | %% 142 | function [x_sil y_sil] = removeSilentFrames(x, y, range, N, K) 143 | % [X_SIL Y_SIL] = REMOVESILENTFRAMES(X, Y, RANGE, N, K) X and Y 144 | % are segmented with frame-length N and overlap K, where the maximum energy 145 | % of all frames of X is determined, say X_MAX. X_SIL and Y_SIL are the 146 | % reconstructed signals, excluding the frames, where the energy of a frame 147 | % of X is smaller than X_MAX-RANGE 148 | 149 | x = x(:); 150 | y = y(:); 151 | 152 | frames = 1:K:(length(x)-N); 153 | w = ml_hanning(N); 154 | msk = zeros(size(frames)); 155 | 156 | for j = 1:length(frames) 157 | jj = frames(j):(frames(j)+N-1); 158 | msk(j) = 20*log10(norm(x(jj).*w)./sqrt(N)); 159 | end 160 | 161 | msk = (msk-max(msk)+range)>0; 162 | count = 1; 163 | 164 | x_sil = zeros(size(x)); 165 | y_sil = zeros(size(y)); 166 | 167 | for j = 1:length(frames) 168 | if msk(j) 169 | jj_i = frames(j):(frames(j)+N-1); 170 | jj_o = frames(count):(frames(count)+N-1); 171 | x_sil(jj_o) = x_sil(jj_o) + x(jj_i).*w; 172 | y_sil(jj_o) = y_sil(jj_o) + y(jj_i).*w; 173 | count = count+1; 174 | end 175 | end 176 | 177 | x_sil = x_sil(1:jj_o(end)); 178 | y_sil = y_sil(1:jj_o(end)); 179 | 180 | %% 181 | function rho = taa_corr(x, y) 182 | % RHO = TAA_CORR(X, Y) Returns correlation coeffecient between column 183 | % vectors x and y. Gives same results as 'corr' from statistics toolbox. 
184 | xn = x-mean(x); 185 | xn = xn/sqrt(sum(xn.^2)); 186 | yn = y-mean(y); 187 | yn = yn/sqrt(sum(yn.^2)); 188 | rho = sum(xn.*yn); 189 | -------------------------------------------------------------------------------- /tests/octave/thirdoct.m: -------------------------------------------------------------------------------- 1 | function [A cf] = thirdoct(fs, N_fft, numBands, mn) 2 | 3 | f = linspace(0, fs, N_fft+1); 4 | f = f(1:(N_fft/2+1)); 5 | k = 0:(numBands-1); 6 | cf = 2.^(k/3)*mn; 7 | fl = sqrt((2.^(k/3)*mn).*2.^((k-1)/3)*mn); 8 | fr = sqrt((2.^(k/3)*mn).*2.^((k+1)/3)*mn); 9 | A = zeros(numBands, length(f)); 10 | 11 | for i = 1:(length(cf)) 12 | [a b] = min((f-fl(i)).^2); 13 | fl(i) = f(b); 14 | fl_ii = b; 15 | 16 | [a b] = min((f-fr(i)).^2); 17 | fr(i) = f(b); 18 | fr_ii = b; 19 | A(i,fl_ii:(fr_ii-1)) = 1; 20 | end 21 | 22 | rnk = sum(A, 2); 23 | numBands = find((rnk(2:end)>=rnk(1:(end-1))) & (rnk(2:end)~=0)~=0, 1, 'last' )+1; 24 | A = A(1:numBands, :); 25 | cf = cf(1:numBands); 26 | -------------------------------------------------------------------------------- /tests/test_matlab_python.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import matlab.engine 3 | import numpy as np 4 | import scipy 5 | from numpy.testing import assert_allclose 6 | from pystoi.stoi import N_FRAME, NFFT, FS 7 | 8 | ATOL = 1e-5 9 | 10 | eng = matlab.engine.start_matlab() 11 | eng.cd('matlab/') 12 | 13 | 14 | def test_hanning(): 15 | """ Compare scipy and Matlab hanning window. 
16 | Matlab returns a N+2 size window without first and last samples""" 17 | hanning = scipy.hanning(N_FRAME+2)[1:-1] 18 | hanning_m = eng.hanning(float(N_FRAME)) 19 | hanning_m = np.array(hanning_m._data) 20 | assert_allclose(hanning, hanning_m, atol=ATOL) 21 | 22 | 23 | def test_fft(): 24 | x = np.random.randn(N_FRAME, ) 25 | x_m = matlab.double(list(x)) 26 | fft_m = eng.fft(x_m, NFFT) 27 | fft_m = np.array(fft_m).transpose() 28 | fft_m = fft_m[0:NFFT//2+1, 0] 29 | fft = np.fft.rfft(x, n=NFFT) 30 | assert_allclose(fft, fft_m, atol=ATOL) 31 | 32 | 33 | def test_resampy(): 34 | """ Compare matlab and librosa resample : FAILING """ 35 | from resampy import resample 36 | from pystoi.stoi import FS 37 | import matlab_wrapper 38 | matlab = matlab_wrapper.MatlabSession() 39 | matlab.put('FS', float(FS)) 40 | RTOL = 1e-4 41 | 42 | for fs in [8000, 11025, 16000, 22050, 32000, 44100, 48000]: 43 | x = np.random.randn(2*fs,) 44 | x_r = resample(x, fs, FS) 45 | matlab.put('x', x) 46 | matlab.put('fs', float(fs)) 47 | matlab.eval('x_r = resample(x, FS, fs)') 48 | assert_allclose(x_r, matlab.get('x_r'), atol=ATOL, rtol=RTOL) 49 | 50 | 51 | def test_nnresample(): 52 | """ Compare matlab and nnresample resample : FAILING """ 53 | from nnresample import resample 54 | from pystoi.stoi import FS 55 | import matlab_wrapper 56 | matlab = matlab_wrapper.MatlabSession() 57 | matlab.put('FS', float(FS)) 58 | RTOL = 1e-4 59 | 60 | for fs in [8000, 11025, 16000, 22050, 32000, 44100, 48000]: 61 | x = np.random.randn(2*fs,) 62 | x_r = resample(x, FS, fs) 63 | matlab.put('x', x) 64 | matlab.put('fs', float(fs)) 65 | matlab.eval('x_r = resample(x, FS, fs)') 66 | assert_allclose(x_r, matlab.get('x_r'), atol=ATOL, rtol=RTOL) 67 | 68 | 69 | if __name__ == '__main__': 70 | pytest.main([__file__]) 71 | -------------------------------------------------------------------------------- /tests/test_overlap_and_add.py: -------------------------------------------------------------------------------- 1 | 
import numpy as np 2 | from numpy.testing import assert_allclose 3 | 4 | from pystoi.stoi import N_FRAME 5 | from pystoi.utils import _overlap_and_add 6 | 7 | 8 | def test_OLA_vectorisation(): 9 | """test the vectorised overlap_and_add comparing to the old one""" 10 | 11 | def old_overlap_and_app(x_frames, hop): 12 | num_frames, framelen = x_frames.shape 13 | x_sil = np.zeros((num_frames - 1) * hop + framelen) 14 | for i in range(num_frames): 15 | x_sil[range(i * hop, i * hop + framelen)] += x_frames[i, :] 16 | return x_sil 17 | 18 | # Initialize 19 | x = np.random.randn(1000 * N_FRAME) 20 | # Add silence segment 21 | silence = np.zeros(10 * N_FRAME) 22 | x = np.concatenate([x[: 500 * N_FRAME], silence, x[500 * N_FRAME :]]) 23 | x = x.reshape([-1, N_FRAME]) 24 | xs = old_overlap_and_app(x, N_FRAME // 2) 25 | xs_vectorise = _overlap_and_add(x, N_FRAME // 2) 26 | assert_allclose(xs, xs_vectorise) 27 | -------------------------------------------------------------------------------- /tests/test_python_octave.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ Unit tests for Octave """ 4 | import numpy as np 5 | from numpy.testing import assert_allclose 6 | from oct2py import octave 7 | import pytest 8 | import scipy 9 | 10 | from pystoi.stoi import FS, N_FRAME, NFFT 11 | from pystoi.utils import resample_oct 12 | 13 | ATOL = 1e-5 14 | 15 | def test_hanning(): 16 | """ Compare scipy and Matlab hanning window. 17 | 18 | Matlab returns a N+2 size window without first and last samples. 19 | A custom Octave function has been written to mimic this 20 | behavior.""" 21 | hanning = scipy.hanning(N_FRAME+2)[1:-1] 22 | hanning_m = np.squeeze(octave.feval('octave/ml_hanning.m', N_FRAME)) 23 | assert_allclose(hanning, hanning_m, atol=ATOL) 24 | 25 | 26 | def test_fft(): 27 | """ Compare FFT to Octave. 
""" 28 | x = np.random.randn(NFFT) 29 | fft_m = np.squeeze(octave.fft(x)) 30 | fft_m = fft_m[:NFFT//2+1] 31 | fft = np.fft.rfft(x, n=NFFT) 32 | assert_allclose(fft, fft_m, atol=ATOL) 33 | 34 | 35 | def test_resample(): 36 | """ Compare Octave and SciPy resampling. 37 | Both packages use polyphase resampling with a Kaiser window. We use 38 | the window designed by Octave in the SciPy resampler.""" 39 | RTOL = 1e-4 40 | for fs in [8000, 11025, 16000, 22050, 32000, 44100, 48000]: 41 | x = np.random.randn(2 * fs) 42 | octave.eval('pkg load signal') 43 | x_m, h = octave.resample(x, float(FS), float(fs), nout=2) 44 | h = np.squeeze(h) 45 | x_m = np.squeeze(x_m) 46 | x_r = resample_oct(x, FS, fs) 47 | assert_allclose(x_r, x_m, atol=ATOL, rtol=RTOL) 48 | 49 | 50 | if __name__ == '__main__': 51 | pytest.main([__file__]) 52 | -------------------------------------------------------------------------------- /tests/test_stoi.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import matlab.engine 3 | import numpy as np 4 | import scipy 5 | from numpy.testing import assert_allclose 6 | from pystoi.stoi import stoi 7 | from pystoi.stoi import FS, N_FRAME, NFFT, NUMBAND, MINFREQ, N, BETA, DYN_RANGE 8 | 9 | RTOL = 1e-4 10 | ATOL = 1e-4 11 | 12 | eng = matlab.engine.start_matlab() 13 | eng.cd('matlab/') 14 | 15 | def test_stoi_good_fs(): 16 | """ Test STOI at sampling frequency of 10kHz. """ 17 | x = np.random.randn(2*FS, ) 18 | y = np.random.randn(2*FS, ) 19 | stoi_out = stoi(x, y, FS) 20 | x_m = matlab.double(list(x)) 21 | y_m = matlab.double(list(y)) 22 | stoi_out_m = eng.stoi(x_m, y_m, float(FS)) 23 | assert_allclose(stoi_out, stoi_out_m, atol=ATOL, rtol=RTOL) 24 | 25 | 26 | def test_estoi_good_fs(): 27 | """ Test extended STOI at sampling frequency of 10kHz. 
""" 28 | x = np.random.randn(2*FS, ) 29 | y = np.random.randn(2*FS, ) 30 | estoi_out = stoi(x, y, FS, extended=True) 31 | x_m = matlab.double(list(x)) 32 | y_m = matlab.double(list(y)) 33 | estoi_out_m = eng.estoi(x_m, y_m, float(FS)) 34 | assert_allclose(estoi_out, estoi_out_m, atol=ATOL, rtol=RTOL) 35 | 36 | 37 | def test_stoi_downsample(): 38 | """ Test STOI at sampling frequencies above 10 kHz (input is downsampled). 39 | PASSES FOR : RTOL = 1e-4 / ATOL = 1e-4. """ 40 | for fs in [11025, 16000, 22050, 32000, 44100, 48000]: 41 | x = np.random.randn(2*fs, ) 42 | y = np.random.randn(2*fs, ) 43 | stoi_out = stoi(x, y, fs) 44 | x_m = matlab.double(list(x)) 45 | y_m = matlab.double(list(y)) 46 | stoi_out_m = eng.stoi(x_m, y_m, float(fs)) 47 | assert_allclose(stoi_out, stoi_out_m, atol=ATOL, rtol=RTOL) 48 | 49 | 50 | def test_stoi_upsample(): 51 | """ Test STOI at sampling frequencies below 10 kHz (input is upsampled). 52 | PASSES FOR : RTOL = 1e-3 / ATOL = 1e-3. """ 53 | for fs in [8000]: 54 | x = np.random.randn(2*fs, ) 55 | y = np.random.randn(2*fs, ) 56 | stoi_out = stoi(x, y, fs) 57 | x_m = matlab.double(list(x)) 58 | y_m = matlab.double(list(y)) 59 | stoi_out_m = eng.stoi(x_m, y_m, float(fs)) 60 | assert_allclose(stoi_out, stoi_out_m, atol=ATOL, rtol=RTOL) 61 | 62 | 63 | def test_stoi_matlab_resample(): 64 | """ Test STOI with any sampling frequency, where Matlab is doing 65 | all the resampling. 
Successful test.""" 66 | from pystoi.stoi import FS 67 | import matlab_wrapper 68 | matlab = matlab_wrapper.MatlabSession() 69 | matlab.workspace.cd('matlab/') 70 | matlab.put('FS', float(FS)) 71 | for fs in [8000, 11025, 16000, 22050, 32000, 44100, 48000]: 72 | matlab.put('fs', float(fs)) 73 | x = np.random.randn(2*fs,) 74 | y = np.random.randn(2*fs, ) 75 | matlab.put('x', x) 76 | matlab.put('y', y) 77 | matlab.eval('x_r = resample(x, FS, fs)') 78 | matlab.eval('y_r = resample(y, FS, fs)') 79 | x_r = matlab.get('x_r') 80 | y_r = matlab.get('y_r') 81 | stoi_out = stoi(x_r, y_r, FS) 82 | stoi_out_m = matlab.eval('stoi_out_m = stoi(x_r, y_r, FS)') 83 | assert_allclose(stoi_out, matlab.get('stoi_out_m'), atol=ATOL, rtol=RTOL) 84 | 85 | 86 | """ 87 | Conclusion : 88 | The difference between the original Matlab and this STOI comes from the 89 | resampling method which uses different filters and interpolations. 90 | For all applications, a 1e-3 relative precision on STOI results is more 91 | than enough. 92 | """ 93 | 94 | 95 | if __name__ == '__main__': 96 | pytest.main([__file__]) 97 | -------------------------------------------------------------------------------- /tests/test_stoi_octave.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import numpy as np 4 | from numpy.testing import assert_allclose 5 | from oct2py import octave 6 | import pytest 7 | 8 | from pystoi.stoi import FS, stoi 9 | 10 | RTOL = 1e-6 11 | ATOL = 1e-6 12 | 13 | def test_stoi_good_fs(): 14 | """ Test STOI at sampling frequency of 10kHz. """ 15 | x = np.random.randn(2 * FS) 16 | y = np.random.randn(2 * FS) 17 | stoi_out = stoi(x, y, FS) 18 | stoi_out_m = octave.feval('octave/stoi.m', x, y, float(FS)) 19 | assert_allclose(stoi_out, stoi_out_m, atol=ATOL, rtol=RTOL) 20 | 21 | 22 | def test_estoi_good_fs(): 23 | """ Test extended STOI at sampling frequency of 10kHz. 
""" 24 | x = np.random.randn(2 * FS) 25 | y = np.random.randn(2 * FS) 26 | estoi_out = stoi(x, y, FS, extended=True) 27 | estoi_out_m = octave.feval('octave/estoi.m', x, y, float(FS)) 28 | assert_allclose(estoi_out, estoi_out_m, atol=ATOL, rtol=RTOL) 29 | 30 | 31 | def test_stoi_downsample(): 32 | """ Test STOI at sampling frequencies above 10 kHz (downsampling path). """ 33 | for fs in [11025, 16000, 22050, 32000, 44100, 48000]: 34 | x = np.random.randn(2 * fs) 35 | y = np.random.randn(2 * fs) 36 | octave.eval('pkg load signal') 37 | stoi_out = stoi(x, y, fs) 38 | stoi_out_m = octave.feval('octave/stoi.m', x, y, float(fs)) 39 | assert_allclose(stoi_out, stoi_out_m, atol=ATOL, rtol=RTOL) 40 | 41 | 42 | def test_stoi_upsample(): 43 | """ Test STOI at a sampling frequency below 10 kHz (upsampling path). """ 44 | for fs in [8000]: 45 | x = np.random.randn(2 * fs) 46 | y = np.random.randn(2 * fs) 47 | octave.eval('pkg load signal') 48 | stoi_out = stoi(x, y, fs) 49 | stoi_out_m = octave.feval('octave/stoi.m', x, y, float(fs)) 50 | assert_allclose(stoi_out, stoi_out_m, atol=ATOL, rtol=RTOL) 51 | 52 | 53 | if __name__ == '__main__': 54 | pytest.main([__file__]) 55 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import matlab.engine 3 | import numpy as np 4 | import scipy 5 | from numpy.testing import assert_allclose 6 | from pystoi.utils import thirdoct, stft, remove_silent_frames 7 | from pystoi.stoi import FS, N_FRAME, NFFT, NUMBAND, MINFREQ, N, BETA, DYN_RANGE, OBM 8 | 9 | ATOL = 1e-5 10 | 11 | eng = matlab.engine.start_matlab() 12 | eng.cd('matlab/') 13 | 14 | 15 | def test_thirdoct(): 16 | obm_m, cf_m = eng.thirdoct(float(FS), float(NFFT), float(NUMBAND), 17 | float(MINFREQ), nargout=2) 18 | obm, cf = thirdoct(FS, NFFT, NUMBAND, MINFREQ) 19 | obm_m = np.array(obm_m) 20 | cf_m = np.array(cf_m).transpose().squeeze() 21 | assert_allclose(obm, obm_m, 
atol=ATOL) 22 | assert_allclose(cf, cf_m, atol=ATOL) 23 | 24 | 25 | def test_stdft(): 26 | x = np.random.randn(2*FS, ) 27 | x_m = matlab.double(list(x)) 28 | spec_m = eng.stdft(x_m, float(N_FRAME), float(N_FRAME/2), float(NFFT)) 29 | spec_m = np.array(spec_m) 30 | spec_m = spec_m[:, 0:(NFFT/2+1)].transpose() 31 | spec = stft(x, N_FRAME, NFFT, overlap=2).transpose() 32 | assert_allclose(spec, spec_m, atol=ATOL) 33 | 34 | 35 | def test_removesf(): 36 | # Initialize 37 | x = np.random.randn(2*FS, ) 38 | y = np.random.randn(2*FS, ) 39 | # Add silence segment 40 | silence = np.zeros(3*NFFT, ) 41 | x = np.concatenate([x[:FS], silence, x[FS:]]) 42 | y = np.concatenate([y[:FS], silence, y[FS:]]) 43 | x_m = matlab.double(list(x)) 44 | y_m = matlab.double(list(y)) 45 | xs, ys = remove_silent_frames(x, y, DYN_RANGE, N_FRAME, N_FRAME/2) 46 | xs_m, ys_m = eng.removeSilentFrames(x_m, y_m, float(DYN_RANGE), 47 | float(N_FRAME), float(N_FRAME/2), 48 | nargout=2) 49 | xs_m, ys_m = np.array(xs_m._data), np.array(ys_m._data) 50 | assert_allclose(xs, xs_m, atol=ATOL) 51 | assert_allclose(ys, ys_m, atol=ATOL) 52 | 53 | 54 | def test_apply_OBM(): 55 | obm_m, cf_m = eng.thirdoct(float(FS), float(NFFT), float(NUMBAND), 56 | float(MINFREQ), nargout=2) 57 | x = np.random.randn(2*FS, ) 58 | x_m = matlab.double(list(x)) 59 | x_tob_m = eng.applyOBM(x_m, obm_m, float(N_FRAME), float(NFFT), float(NUMBAND)) 60 | x_tob_m = np.array(x_tob_m) 61 | x_spec = stft(x, N_FRAME, NFFT, overlap=2).transpose() 62 | x_tob = np.zeros((NUMBAND, x_spec.shape[1])) 63 | for i in range(x_tob.shape[1]): 64 | x_tob[:, i] = np.sqrt(np.matmul(OBM, np.square(np.abs(x_spec[:, i])))) 65 | assert_allclose(x_tob, x_tob_m, atol=ATOL) 66 | 67 | 68 | if __name__ == '__main__': 69 | pytest.main([__file__]) 70 | -------------------------------------------------------------------------------- /tests/test_utils_octave.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 
# -*- coding: utf-8 -*- 3 | """Test utilities based on Octave""" 4 | import numpy as np 5 | from numpy.testing import assert_allclose 6 | from oct2py import octave 7 | import pytest 8 | 9 | from pystoi.stoi import DYN_RANGE, FS, MINFREQ, N_FRAME, NFFT, NUMBAND, OBM 10 | from pystoi.utils import remove_silent_frames, stft, thirdoct 11 | 12 | ATOL = 1e-5 13 | 14 | 15 | def test_thirdoct(): 16 | """Test thirdoct by comparing to Octave""" 17 | obm_m, cf_m = octave.feval('octave/thirdoct.m', 18 | float(FS), float(NFFT), float(NUMBAND), 19 | float(MINFREQ), nout=2) 20 | obm, cf = thirdoct(FS, NFFT, NUMBAND, MINFREQ) 21 | obm_m = np.array(obm_m) 22 | cf_m = np.array(cf_m).transpose().squeeze() 23 | assert_allclose(obm, obm_m, atol=ATOL) 24 | assert_allclose(cf, cf_m, atol=ATOL) 25 | 26 | 27 | def test_stdft(): 28 | """Test stdft by comparing to Octave""" 29 | x = np.random.randn(2 * FS) 30 | spec_m = octave.feval('octave/stdft.m', 31 | x, float(N_FRAME), float(N_FRAME/2), float(NFFT)) 32 | spec_m = spec_m[:, 0:(NFFT // 2 + 1)].transpose() 33 | spec = stft(x, N_FRAME, NFFT, overlap=2).transpose() 34 | assert_allclose(spec, spec_m, atol=ATOL) 35 | 36 | 37 | def test_removesf(): 38 | """Test remove_silent_frames by comparing to Octave""" 39 | # Initialize 40 | x = np.random.randn(2 * FS) 41 | y = np.random.randn(2 * FS) 42 | # Add silence segment 43 | silence = np.zeros(3 * NFFT, ) 44 | x = np.concatenate([x[:FS], silence, x[FS:]]) 45 | y = np.concatenate([y[:FS], silence, y[FS:]]) 46 | xs, ys = remove_silent_frames(x, y, DYN_RANGE, N_FRAME, N_FRAME // 2) 47 | xs_m, ys_m = octave.feval('octave/removeSilentFrames.m', 48 | x, y, float(DYN_RANGE), 49 | float(N_FRAME), 50 | float(N_FRAME / 2), 51 | nout=2) 52 | xs_m = np.squeeze(xs_m) 53 | ys_m = np.squeeze(ys_m) 54 | assert_allclose(xs, xs_m, atol=ATOL) 55 | assert_allclose(ys, ys_m, atol=ATOL) 56 | 57 | 58 | def test_apply_OBM(): 59 | """Test apply_OBM by comparing to Octave""" 60 | obm_m, _ = octave.feval('octave/thirdoct.m', 
61 | float(FS), float(NFFT), float(NUMBAND), 62 | float(MINFREQ), nout=2) 63 | x = np.random.randn(2 * FS) 64 | x_tob_m = octave.feval('octave/applyOBM', 65 | x, obm_m, float(N_FRAME), float(NFFT), 66 | float(NUMBAND)) 67 | x_tob_m = np.array(x_tob_m) 68 | x_spec = stft(x, N_FRAME, NFFT, overlap=2).transpose() 69 | x_tob = np.zeros((NUMBAND, x_spec.shape[1])) 70 | for i in range(x_tob.shape[1]): 71 | x_tob[:, i] = np.sqrt(np.matmul(OBM, np.square(np.abs(x_spec[:, i])))) 72 | assert_allclose(x_tob, x_tob_m, atol=ATOL) 73 | 74 | 75 | if __name__ == '__main__': 76 | pytest.main([__file__]) 77 | --------------------------------------------------------------------------------