├── data_generator.py
├── mir_eval.py
├── mir_util.py
├── params.py
├── readme.md
├── tasnet-architecture.png
├── tf_net.py
├── tf_test.py
└── tf_train.py

/data_generator.py:
--------------------------------------------------------------------------------
1 | '''
2 | Required:
3 |     mixture: [B, K, L]
4 |     source: [B, nspk, K, L]
5 |     get_batch
6 | '''
7 | 
8 | import numpy as np
9 | import librosa
10 | import os
11 | from params import *  # expected to provide L (samples per frame) and sr (sample rate)
12 | import pickle
13 | 
14 | 
15 | class DataGenerator(object):
16 |     def __init__(self, batch_size, max_k,
17 |                  save_dir=None, data_dir=None, name='data_gen'):
18 |         self.name = name
19 |         self.batch_size = batch_size
20 |         self.max_k = max_k
21 |         self.data_dir = data_dir
22 |         self.save_dir = save_dir
23 |         if save_dir is not None and not os.path.exists(save_dir):
24 |             os.mkdir(save_dir)
25 |         self.data_subdir = ['s1', 's2', 'mix']
26 |         self.data_type = ['tr', 'cv']
27 | 
28 |         self.spks = []
29 |         self.init_samples()
30 | 
31 |         self.epoch = 0
32 |         self.idx = 0
33 | 
34 |     def init_samples(self):
35 |         self.samples = {'mix': [], 's': []}
36 |         self.sample_size = 0
37 | 
38 |     def gen_data(self):
39 |         if self.data_dir is None or self.save_dir is None:
40 |             raise ValueError('gen_data() requires both data_dir and save_dir')
41 |         for dt in self.data_type:
42 |             self.init_samples()
43 |             save_cnt = 1
44 |             save_dir = os.path.join(self.save_dir, dt)
45 |             if not os.path.exists(save_dir):
46 |                 os.mkdir(save_dir)
47 |             dt_path = os.path.join(self.data_dir, dt)
48 |             dt_mix_path = os.path.join(dt_path, 'mix')
49 |             dt_s1_path = os.path.join(dt_path, 's1')
50 |             dt_s2_path = os.path.join(dt_path, 's2')
51 | 
52 |             list_mix = os.listdir(dt_mix_path)
53 |             for wav_file in list_mix:
54 |                 if not wav_file.endswith('.wav'):
55 |                     continue
56 |                 # print(wav_file)
57 |                 mix_path = os.path.join(dt_mix_path, wav_file)
58 |                 s1_path = os.path.join(dt_s1_path, wav_file)
59 |                 s2_path = os.path.join(dt_s2_path, wav_file)
60 | 
61 |                 spk1 = wav_file[:3]
62 |                 spk2 = wav_file.split(sep='_')[2][:3]
63 | 
64 |                 if spk1 not in self.spks:
65 |                     self.spks.append(spk1)
66 |                 if spk2 not in self.spks:
67 |                     self.spks.append(spk2)
68 | 
69 |                 mix, _ = librosa.load(mix_path, sr=sr)
70 |                 s1, _ = librosa.load(s1_path, sr=sr)
71 |                 s2, _ = librosa.load(s2_path, sr=sr)
72 |                 self.get_sample(mix, s1, s2, [spk1, spk2])
73 | 
74 |                 self.sample_size = len(self.samples['mix'])
75 |                 if self.sample_size % 50 == 0:
76 |                     print(self.sample_size)
77 |                 if self.sample_size >= save_cnt * 50000:
78 |                     pickle.dump(self.samples,
79 |                                 open(save_dir + '/raw_' + str(self.max_k) + '-' + str(self.sample_size) + '.pkl',
80 |                                      'wb'))
81 |                     save_cnt += 1
82 |             pickle.dump(self.samples,
83 |                         open(save_dir + '/raw_' + str(self.max_k) + '-' + str(self.sample_size) + '.pkl',
84 |                              'wb'))
85 | 
86 |     def get_sample(self, mix, s1, s2, spks):
87 |         spk_num = len(spks)
88 | 
89 |         mix_len = len(mix)
90 |         sample_num = int(np.ceil(mix_len / L))
91 |         if sample_num < self.max_k:
92 |             sample_num = self.max_k
93 |         max_len = sample_num * L
94 |         pad_s1 = np.concatenate([s1, np.zeros([max_len - len(s1)])])
95 |         pad_s2 = np.concatenate([s2, np.zeros([max_len - len(s2)])])
96 |         pad_mix = np.concatenate([mix, np.zeros([max_len - len(mix)])])
97 | 
98 |         k_ = 0
99 |         while k_ + self.max_k <= sample_num:
100 |             begin = k_ * L
101 |             end = (k_ + self.max_k) * L
102 |             sample_mix = pad_mix[begin:end]
103 |             sample_s1 = pad_s1[begin:end]
104 |             sample_s2 = pad_s2[begin:end]
105 | 
106 |             sample_mix = np.reshape(sample_mix, [self.max_k, L])
107 |             sample_s1 = np.reshape(sample_s1, [self.max_k, L])
108 |             sample_s2 = np.reshape(sample_s2, [self.max_k, L])
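            # Stack the per-speaker segments and move the speaker axis first:
            # np.dstack gives [max_k, L, nspk]; transpose(2, 0, 1) yields
            # [nspk, max_k, L], matching the required source layout [B, nspk, K, L]
            # once samples are batched.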
109 |             sample_s = np.dstack((sample_s1, sample_s2))
110 |             sample_s = np.transpose(sample_s, (2, 0, 1))
111 | 
112 |             self.samples['mix'].append(sample_mix)
113 |             self.samples['s'].append(sample_s)
114 |             k_ += self.max_k
115 | 
116 |     def load_data(self, data_path):
117 |         self.samples = pickle.load(open(data_path, 'rb'))
118 |         self.sample_size = len(self.samples['mix'])
119 |         print('>> {0}: Loading samples from pkl: {1}...'.format(self.name, data_path))
120 | 
121 |     def shuffle_dict(self):
122 |         rand_per = np.random.permutation(self.sample_size)
123 |         self.samples['mix'] = np.array(self.samples['mix'])[rand_per]
124 |         self.samples['s'] = np.array(self.samples['s'])[rand_per]
125 | 
126 |     def get_a_sample(self, mix, s1, s2, spks, max_k):
127 |         spk_num = len(spks)
128 |         mix_len = len(mix)
129 |         sample_num = int(np.ceil(mix_len / L / max_k)) * max_k
130 |         max_len = sample_num * L
131 |         pad_s1 = np.concatenate([s1, np.zeros([max_len - len(s1)])])
132 |         pad_s2 = np.concatenate([s2, np.zeros([max_len - len(s2)])])
133 |         pad_mix = np.concatenate([mix, np.zeros([max_len - len(mix)])])
134 | 
135 |         test_sample = {
136 |             'mix': [],
137 |             's': [],
138 |         }
139 |         k_ = 0
140 |         while k_ + max_k <= sample_num:
141 |             begin = k_ * L
142 |             end = (k_ + max_k) * L
143 |             sample_mix = pad_mix[begin:end]
144 |             sample_s1 = pad_s1[begin:end]
145 |             sample_s2 = pad_s2[begin:end]
146 | 
147 |             sample_mix = np.reshape(sample_mix, [max_k, L])
148 |             sample_s1 = np.reshape(sample_s1, [max_k, L])
149 |             sample_s2 = np.reshape(sample_s2, [max_k, L])
150 |             sample_s = np.dstack((sample_s1, sample_s2))
151 |             sample_s = np.transpose(sample_s, (2, 0, 1))
152 | 
153 |             test_sample['mix'].append(sample_mix)
154 |             test_sample['s'].append(sample_s)
155 | 
156 |             k_ += max_k
157 | 
158 |         return test_sample
159 | 
160 |     def gen_batch(self, batch_size=None):
161 |         if batch_size is None:
162 |             batch_size = self.batch_size
163 |         n_begin = self.idx
164 |         n_end = self.idx + batch_size
165 |         if n_end >= self.sample_size:
166 |             # rewind the index and reshuffle for the next epoch
167 |             self.idx = 0
168 |             n_begin = self.idx
169 |             n_end = self.idx + batch_size
170 |             self.epoch += 1
171 |             self.shuffle_dict()
172 |         self.idx += batch_size
173 |         samples = {
174 |             'mix': self.samples['mix'][n_begin: n_end],
175 |             's': self.samples['s'][n_begin: n_end]
176 |         }
177 |         return samples
178 | 
179 | 
180 | if __name__ == '__main__':
181 |     # data_gen = DataGenerator(batch_size=1, max_k=int(0.5/0.005), save_dir='/home/grz/data/SSSR/wsj0_tasnet/',
182 |     #                          data_dir='/home/grz/data/SSSR/wsj0/min/',
183 |     #                          name='gen_data')
184 |     # data_gen.gen_data()
185 |     # data_gen = DataGenerator(batch_size=1, max_k=int(4/0.005), save_dir='/home/grz/data/SSSR/wsj0_tasnet/',
186 |     #                          data_dir='/home/grz/data/SSSR/wsj0/min/',
187 |     #                          name='gen_data')
188 |     # data_gen.gen_data()
189 | 
190 |     data_gen = DataGenerator(batch_size=1, max_k=int(2/0.005), save_dir='/home/grz/data/SSSR/wsj0_tasnet/',
191 |                              data_dir='/home/grz/data/SSSR/wsj0/min/',
192 |                              name='gen_data')
193 |     data_gen.gen_data()
--------------------------------------------------------------------------------
/mir_eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | Source separation algorithms attempt to extract recordings of individual
4 | sources from a recording of a mixture of sources. Evaluation methods for
5 | source separation compare the extracted sources against reference sources and
6 | attempt to measure the perceptual quality of the separation.
7 | 8 | See also the bss_eval MATLAB toolbox: 9 | http://bass-db.gforge.inria.fr/bss_eval/ 10 | 11 | Conventions 12 | ----------- 13 | 14 | An audio signal is expected to be in the format of a 1-dimensional array where 15 | the entries are the samples of the audio signal. When providing a group of 16 | estimated or reference sources, they should be provided in a 2-dimensional 17 | array, where the first dimension corresponds to the source number and the 18 | second corresponds to the samples. 19 | 20 | Metrics 21 | ------- 22 | 23 | * :func:`mir_eval.separation.bss_eval_sources`: Computes the bss_eval_sources 24 | metrics from bss_eval, which optionally optimally match the estimated sources 25 | to the reference sources and measure the distortion and artifacts present in 26 | the estimated sources as well as the interference between them. 27 | 28 | * :func:`mir_eval.separation.bss_eval_sources_framewise`: Computes the 29 | bss_eval_sources metrics on a frame-by-frame basis. 30 | 31 | * :func:`mir_eval.separation.bss_eval_images`: Computes the bss_eval_images 32 | metrics from bss_eval, which includes the metrics in 33 | :func:`mir_eval.separation.bss_eval_sources` plus the image to spatial 34 | distortion ratio. 35 | 36 | * :func:`mir_eval.separation.bss_eval_images_framewise`: Computes the 37 | bss_eval_images metrics on a frame-by-frame basis. 38 | 39 | References 40 | ---------- 41 | .. [#vincent2006performance] Emmanuel Vincent, Rémi Gribonval, and Cédric 42 | Févotte, "Performance measurement in blind audio source separation," IEEE 43 | Trans. on Audio, Speech and Language Processing, 14(4):1462-1469, 2006. 44 | 45 | 46 | ''' 47 | 48 | import numpy as np 49 | import scipy.fftpack 50 | from scipy.linalg import toeplitz 51 | from scipy.signal import fftconvolve 52 | import collections 53 | import itertools 54 | import warnings 55 | from tasnet import mir_util as util 56 | 57 | 58 | # The maximum allowable number of sources (prevents insane computational load) 59 | MAX_SOURCES = 100 60 | 61 | 62 | def validate(reference_sources, estimated_sources): 63 | """Checks that the input data to a metric are valid, and throws helpful 64 | errors if not. 65 | 66 | Parameters 67 | ---------- 68 | reference_sources : np.ndarray, shape=(nsrc, nsampl) 69 | matrix containing true sources 70 | estimated_sources : np.ndarray, shape=(nsrc, nsampl) 71 | matrix containing estimated sources 72 | 73 | """ 74 | 75 | if reference_sources.shape != estimated_sources.shape: 76 | raise ValueError('The shape of estimated sources and the true ' 77 | 'sources should match. reference_sources.shape ' 78 | '= {}, estimated_sources.shape ' 79 | '= {}'.format(reference_sources.shape, 80 | estimated_sources.shape)) 81 | 82 | if reference_sources.ndim > 3 or estimated_sources.ndim > 3: 83 | raise ValueError('The number of dimensions is too high (must be less ' 84 | 'than 3). reference_sources.ndim = {}, ' 85 | 'estimated_sources.ndim ' 86 | '= {}'.format(reference_sources.ndim, 87 | estimated_sources.ndim)) 88 | 89 | if reference_sources.size == 0: 90 | warnings.warn("reference_sources is empty, should be of size " 91 | "(nsrc, nsample). sdr, sir, sar, and perm will all " 92 | "be empty np.ndarrays") 93 | elif _any_source_silent(reference_sources): 94 | raise ValueError('All the reference sources should be non-silent (not ' 95 | 'all-zeros), but at least one of the reference ' 96 | 'sources is all 0s, which introduces ambiguity to the' 97 | ' evaluation. 
(Otherwise we can add infinitely many ' 98 | 'all-zero sources.)') 99 | 100 | if estimated_sources.size == 0: 101 | warnings.warn("estimated_sources is empty, should be of size " 102 | "(nsrc, nsample). sdr, sir, sar, and perm will all " 103 | "be empty np.ndarrays") 104 | elif _any_source_silent(estimated_sources): 105 | raise ValueError('All the estimated sources should be non-silent (not ' 106 | 'all-zeros), but at least one of the estimated ' 107 | 'sources is all 0s. Since we require each reference ' 108 | 'source to be non-silent, having a silent estimated ' 109 | 'source will result in an underdetermined system.') 110 | 111 | if (estimated_sources.shape[0] > MAX_SOURCES or 112 | reference_sources.shape[0] > MAX_SOURCES): 113 | raise ValueError('The supplied matrices should be of shape (nsrc,' 114 | ' nsampl) but reference_sources.shape[0] = {} and ' 115 | 'estimated_sources.shape[0] = {} which is greater ' 116 | 'than mir_eval.separation.MAX_SOURCES = {}. To ' 117 | 'override this check, set ' 118 | 'mir_eval.separation.MAX_SOURCES to a ' 119 | 'larger value.'.format(reference_sources.shape[0], 120 | estimated_sources.shape[0], 121 | MAX_SOURCES)) 122 | 123 | 124 | def _any_source_silent(sources): 125 | """Returns true if the parameter sources has any silent first dimensions""" 126 | return np.any(np.all(np.sum( 127 | sources, axis=tuple(range(2, sources.ndim))) == 0, axis=1)) 128 | 129 | 130 | def bss_eval_sources(reference_sources, estimated_sources, 131 | compute_permutation=True): 132 | """ 133 | Ordering and measurement of the separation quality for estimated source 134 | signals in terms of filtered true source, interference and artifacts. 135 | 136 | The decomposition allows a time-invariant filter distortion of length 137 | 512, as described in Section III.B of [#vincent2006performance]_. 138 | 139 | Passing ``False`` for ``compute_permutation`` will improve the computation 140 | performance of the evaluation; however, it is not always appropriate and 141 | is not the way that the BSS_EVAL Matlab toolbox computes bss_eval_sources. 142 | 143 | Examples 144 | -------- 145 | >>> # reference_sources[n] should be an ndarray of samples of the 146 | >>> # n'th reference source 147 | >>> # estimated_sources[n] should be the same for the n'th estimated 148 | >>> # source 149 | >>> (sdr, sir, sar, 150 | ... perm) = mir_eval.separation.bss_eval_sources(reference_sources, 151 | ... estimated_sources) 152 | 153 | Parameters 154 | ---------- 155 | reference_sources : np.ndarray, shape=(nsrc, nsampl) 156 | matrix containing true sources (must have same shape as 157 | estimated_sources) 158 | estimated_sources : np.ndarray, shape=(nsrc, nsampl) 159 | matrix containing estimated sources (must have same shape as 160 | reference_sources) 161 | compute_permutation : bool, optional 162 | compute permutation of estimate/source combinations (True by default) 163 | 164 | Returns 165 | ------- 166 | sdr : np.ndarray, shape=(nsrc,) 167 | vector of Signal to Distortion Ratios (SDR) 168 | sir : np.ndarray, shape=(nsrc,) 169 | vector of Source to Interference Ratios (SIR) 170 | sar : np.ndarray, shape=(nsrc,) 171 | vector of Sources to Artifacts Ratios (SAR) 172 | perm : np.ndarray, shape=(nsrc,) 173 | vector containing the best ordering of estimated sources in 174 | the mean SIR sense (estimated source number ``perm[j]`` corresponds to 175 | true source number ``j``). Note: ``perm`` will be ``[0, 1, ..., 176 | nsrc-1]`` if ``compute_permutation`` is ``False``. 
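        (To reorder the estimates to match the references, one can simply
        take ``estimated_sources[perm]``.)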
177 | 178 | References 179 | ---------- 180 | .. [#] Emmanuel Vincent, Shoko Araki, Fabian J. Theis, Guido Nolte, Pau 181 | Bofill, Hiroshi Sawada, Alexey Ozerov, B. Vikrham Gowreesunker, Dominik 182 | Lutter and Ngoc Q.K. Duong, "The Signal Separation Evaluation Campaign 183 | (2007-2010): Achievements and remaining challenges", Signal Processing, 184 | 92, pp. 1928-1936, 2012. 185 | 186 | """ 187 | 188 | # make sure the input is of shape (nsrc, nsampl) 189 | if estimated_sources.ndim == 1: 190 | estimated_sources = estimated_sources[np.newaxis, :] 191 | if reference_sources.ndim == 1: 192 | reference_sources = reference_sources[np.newaxis, :] 193 | 194 | validate(reference_sources, estimated_sources) 195 | # If empty matrices were supplied, return empty lists (special case) 196 | if reference_sources.size == 0 or estimated_sources.size == 0: 197 | return np.array([]), np.array([]), np.array([]), np.array([]) 198 | 199 | nsrc = estimated_sources.shape[0] 200 | 201 | # does user desire permutations? 202 | if compute_permutation: 203 | # compute criteria for all possible pair matches 204 | sdr = np.empty((nsrc, nsrc)) 205 | sir = np.empty((nsrc, nsrc)) 206 | sar = np.empty((nsrc, nsrc)) 207 | for jest in range(nsrc): 208 | for jtrue in range(nsrc): 209 | s_true, e_spat, e_interf, e_artif = \ 210 | _bss_decomp_mtifilt(reference_sources, 211 | estimated_sources[jest], 212 | jtrue, 512) 213 | sdr[jest, jtrue], sir[jest, jtrue], sar[jest, jtrue] = \ 214 | _bss_source_crit(s_true, e_spat, e_interf, e_artif) 215 | 216 | # select the best ordering 217 | perms = list(itertools.permutations(list(range(nsrc)))) 218 | mean_sir = np.empty(len(perms)) 219 | dum = np.arange(nsrc) 220 | for (i, perm) in enumerate(perms): 221 | mean_sir[i] = np.mean(sir[perm, dum]) 222 | popt = perms[np.argmax(mean_sir)] 223 | idx = (popt, dum) 224 | return (sdr[idx], sir[idx], sar[idx], np.asarray(popt)) 225 | else: 226 | # compute criteria for only the simple correspondence 227 | # (estimate 1 is estimate corresponding to reference source 1, etc.) 228 | sdr = np.empty(nsrc) 229 | sir = np.empty(nsrc) 230 | sar = np.empty(nsrc) 231 | for j in range(nsrc): 232 | s_true, e_spat, e_interf, e_artif = \ 233 | _bss_decomp_mtifilt(reference_sources, 234 | estimated_sources[j], 235 | j, 512) 236 | sdr[j], sir[j], sar[j] = \ 237 | _bss_source_crit(s_true, e_spat, e_interf, e_artif) 238 | 239 | # return the default permutation for compatibility 240 | popt = np.arange(nsrc) 241 | return (sdr, sir, sar, popt) 242 | 243 | 244 | def bss_eval_sources_framewise(reference_sources, estimated_sources, 245 | window=30*44100, hop=15*44100, 246 | compute_permutation=False): 247 | """Framewise computation of bss_eval_sources 248 | 249 | Please be aware that this function does not compute permutations (by 250 | default) on the possible relations between reference_sources and 251 | estimated_sources due to the dangers of a changing permutation. Therefore 252 | (by default), it assumes that ``reference_sources[i]`` corresponds to 253 | ``estimated_sources[i]``. To enable computing permutations please set 254 | ``compute_permutation`` to be ``True`` and check that the returned ``perm`` 255 | is identical for all windows. 
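    (For instance, ``np.all(perm == perm[:, :1])`` is a quick way to verify
    that the permutation did not change across windows.)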
256 | 257 | NOTE: if ``reference_sources`` and ``estimated_sources`` would be evaluated 258 | using only a single window or are shorter than the window length, the 259 | result of :func:`mir_eval.separation.bss_eval_sources` called on 260 | ``reference_sources`` and ``estimated_sources`` (with the 261 | ``compute_permutation`` parameter passed to 262 | :func:`mir_eval.separation.bss_eval_sources`) is returned. 263 | 264 | Examples 265 | -------- 266 | >>> # reference_sources[n] should be an ndarray of samples of the 267 | >>> # n'th reference source 268 | >>> # estimated_sources[n] should be the same for the n'th estimated 269 | >>> # source 270 | >>> (sdr, sir, sar, 271 | ... perm) = mir_eval.separation.bss_eval_sources_framewise( 272 | reference_sources, 273 | ... estimated_sources) 274 | 275 | Parameters 276 | ---------- 277 | reference_sources : np.ndarray, shape=(nsrc, nsampl) 278 | matrix containing true sources (must have the same shape as 279 | ``estimated_sources``) 280 | estimated_sources : np.ndarray, shape=(nsrc, nsampl) 281 | matrix containing estimated sources (must have the same shape as 282 | ``reference_sources``) 283 | window : int, optional 284 | Window length for framewise evaluation (default value is 30s at a 285 | sample rate of 44.1kHz) 286 | hop : int, optional 287 | Hop size for framewise evaluation (default value is 15s at a 288 | sample rate of 44.1kHz) 289 | compute_permutation : bool, optional 290 | compute permutation of estimate/source combinations for all windows 291 | (False by default) 292 | 293 | Returns 294 | ------- 295 | sdr : np.ndarray, shape=(nsrc, nframes) 296 | vector of Signal to Distortion Ratios (SDR) 297 | sir : np.ndarray, shape=(nsrc, nframes) 298 | vector of Source to Interference Ratios (SIR) 299 | sar : np.ndarray, shape=(nsrc, nframes) 300 | vector of Sources to Artifacts Ratios (SAR) 301 | perm : np.ndarray, shape=(nsrc, nframes) 302 | vector containing the best ordering of estimated sources in 303 | the mean SIR sense (estimated source number ``perm[j]`` corresponds to 304 | true source number ``j``). 
Note: ``perm`` will be ``range(nsrc)`` for 305 | all windows if ``compute_permutation`` is ``False`` 306 | 307 | """ 308 | 309 | # make sure the input is of shape (nsrc, nsampl) 310 | if estimated_sources.ndim == 1: 311 | estimated_sources = estimated_sources[np.newaxis, :] 312 | if reference_sources.ndim == 1: 313 | reference_sources = reference_sources[np.newaxis, :] 314 | 315 | validate(reference_sources, estimated_sources) 316 | # If empty matrices were supplied, return empty lists (special case) 317 | if reference_sources.size == 0 or estimated_sources.size == 0: 318 | return np.array([]), np.array([]), np.array([]), np.array([]) 319 | 320 | nsrc = reference_sources.shape[0] 321 | 322 | nwin = int( 323 | np.floor((reference_sources.shape[1] - window + hop) / hop) 324 | ) 325 | # if fewer than 2 windows would be evaluated, return the sources result 326 | if nwin < 2: 327 | result = bss_eval_sources(reference_sources, 328 | estimated_sources, 329 | compute_permutation) 330 | return [np.expand_dims(score, -1) for score in result] 331 | 332 | # compute the criteria across all windows 333 | sdr = np.empty((nsrc, nwin)) 334 | sir = np.empty((nsrc, nwin)) 335 | sar = np.empty((nsrc, nwin)) 336 | perm = np.empty((nsrc, nwin)) 337 | 338 | # k iterates across all the windows 339 | for k in range(nwin): 340 | win_slice = slice(k * hop, k * hop + window) 341 | ref_slice = reference_sources[:, win_slice] 342 | est_slice = estimated_sources[:, win_slice] 343 | # check for a silent frame 344 | if (not _any_source_silent(ref_slice) and 345 | not _any_source_silent(est_slice)): 346 | sdr[:, k], sir[:, k], sar[:, k], perm[:, k] = bss_eval_sources( 347 | ref_slice, est_slice, compute_permutation 348 | ) 349 | else: 350 | # if we have a silent frame set results as np.nan 351 | sdr[:, k] = sir[:, k] = sar[:, k] = perm[:, k] = np.nan 352 | 353 | return sdr, sir, sar, perm 354 | 355 | 356 | def bss_eval_images(reference_sources, estimated_sources, 357 | compute_permutation=True): 358 | """Implementation of the bss_eval_images function from the 359 | BSS_EVAL Matlab toolbox. 360 | 361 | Ordering and measurement of the separation quality for estimated source 362 | signals in terms of filtered true source, interference and artifacts. 363 | This method also provides the ISR measure. 364 | 365 | The decomposition allows a time-invariant filter distortion of length 366 | 512, as described in Section III.B of [#vincent2006performance]_. 367 | 368 | Passing ``False`` for ``compute_permutation`` will improve the computation 369 | performance of the evaluation; however, it is not always appropriate and 370 | is not the way that the BSS_EVAL Matlab toolbox computes bss_eval_images. 371 | 372 | Examples 373 | -------- 374 | >>> # reference_sources[n] should be an ndarray of samples of the 375 | >>> # n'th reference source 376 | >>> # estimated_sources[n] should be the same for the n'th estimated 377 | >>> # source 378 | >>> (sdr, isr, sir, sar, 379 | ... perm) = mir_eval.separation.bss_eval_images(reference_sources, 380 | ... 
estimated_sources) 381 | 382 | Parameters 383 | ---------- 384 | reference_sources : np.ndarray, shape=(nsrc, nsampl, nchan) 385 | matrix containing true sources 386 | estimated_sources : np.ndarray, shape=(nsrc, nsampl, nchan) 387 | matrix containing estimated sources 388 | compute_permutation : bool, optional 389 | compute permutation of estimate/source combinations (True by default) 390 | 391 | Returns 392 | ------- 393 | sdr : np.ndarray, shape=(nsrc,) 394 | vector of Signal to Distortion Ratios (SDR) 395 | isr : np.ndarray, shape=(nsrc,) 396 | vector of source Image to Spatial distortion Ratios (ISR) 397 | sir : np.ndarray, shape=(nsrc,) 398 | vector of Source to Interference Ratios (SIR) 399 | sar : np.ndarray, shape=(nsrc,) 400 | vector of Sources to Artifacts Ratios (SAR) 401 | perm : np.ndarray, shape=(nsrc,) 402 | vector containing the best ordering of estimated sources in 403 | the mean SIR sense (estimated source number ``perm[j]`` corresponds to 404 | true source number ``j``). Note: ``perm`` will be ``(1,2,...,nsrc)`` 405 | if ``compute_permutation`` is ``False``. 406 | 407 | References 408 | ---------- 409 | .. [#] Emmanuel Vincent, Shoko Araki, Fabian J. Theis, Guido Nolte, Pau 410 | Bofill, Hiroshi Sawada, Alexey Ozerov, B. Vikrham Gowreesunker, Dominik 411 | Lutter and Ngoc Q.K. Duong, "The Signal Separation Evaluation Campaign 412 | (2007-2010): Achievements and remaining challenges", Signal Processing, 413 | 92, pp. 1928-1936, 2012. 414 | 415 | """ 416 | 417 | # make sure the input has 3 dimensions 418 | # assuming input is in shape (nsampl) or (nsrc, nsampl) 419 | estimated_sources = np.atleast_3d(estimated_sources) 420 | reference_sources = np.atleast_3d(reference_sources) 421 | # we will ensure input doesn't have more than 3 dimensions in validate 422 | 423 | validate(reference_sources, estimated_sources) 424 | # If empty matrices were supplied, return empty lists (special case) 425 | if reference_sources.size == 0 or estimated_sources.size == 0: 426 | return np.array([]), np.array([]), np.array([]), \ 427 | np.array([]), np.array([]) 428 | 429 | # determine size parameters 430 | nsrc = estimated_sources.shape[0] 431 | nsampl = estimated_sources.shape[1] 432 | nchan = estimated_sources.shape[2] 433 | 434 | # does the user desire permutation? 435 | if compute_permutation: 436 | # compute criteria for all possible pair matches 437 | sdr = np.empty((nsrc, nsrc)) 438 | isr = np.empty((nsrc, nsrc)) 439 | sir = np.empty((nsrc, nsrc)) 440 | sar = np.empty((nsrc, nsrc)) 441 | for jest in range(nsrc): 442 | for jtrue in range(nsrc): 443 | s_true, e_spat, e_interf, e_artif = \ 444 | _bss_decomp_mtifilt_images( 445 | reference_sources, 446 | np.reshape( 447 | estimated_sources[jest], 448 | (nsampl, nchan), 449 | order='F' 450 | ), 451 | jtrue, 452 | 512 453 | ) 454 | sdr[jest, jtrue], isr[jest, jtrue], \ 455 | sir[jest, jtrue], sar[jest, jtrue] = \ 456 | _bss_image_crit(s_true, e_spat, e_interf, e_artif) 457 | 458 | # select the best ordering 459 | perms = list(itertools.permutations(range(nsrc))) 460 | mean_sir = np.empty(len(perms)) 461 | dum = np.arange(nsrc) 462 | for (i, perm) in enumerate(perms): 463 | mean_sir[i] = np.mean(sir[perm, dum]) 464 | popt = perms[np.argmax(mean_sir)] 465 | idx = (popt, dum) 466 | return (sdr[idx], isr[idx], sir[idx], sar[idx], np.asarray(popt)) 467 | else: 468 | # compute criteria for only the simple correspondence 469 | # (estimate 1 is estimate corresponding to reference source 1, etc.) 
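        # G is the Gram matrix of all delayed reference sources and is the
        # same for every j, so it is built once and reused across iterations;
        # Gj[j] caches the per-source version. This shortcut is only valid
        # because no permutation search is performed in this branch.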
470 | sdr = np.empty(nsrc) 471 | isr = np.empty(nsrc) 472 | sir = np.empty(nsrc) 473 | sar = np.empty(nsrc) 474 | Gj = [0] * nsrc # prepare G matrics with zeroes 475 | G = np.zeros(1) 476 | for j in range(nsrc): 477 | # save G matrix to avoid recomputing it every call 478 | s_true, e_spat, e_interf, e_artif, Gj_temp, G = \ 479 | _bss_decomp_mtifilt_images(reference_sources, 480 | np.reshape(estimated_sources[j], 481 | (nsampl, nchan), 482 | order='F'), 483 | j, 512, Gj[j], G) 484 | Gj[j] = Gj_temp 485 | sdr[j], isr[j], sir[j], sar[j] = \ 486 | _bss_image_crit(s_true, e_spat, e_interf, e_artif) 487 | 488 | # return the default permutation for compatibility 489 | popt = np.arange(nsrc) 490 | return (sdr, isr, sir, sar, popt) 491 | 492 | 493 | def bss_eval_images_framewise(reference_sources, estimated_sources, 494 | window=30*44100, hop=15*44100, 495 | compute_permutation=False): 496 | """Framewise computation of bss_eval_images 497 | 498 | Please be aware that this function does not compute permutations (by 499 | default) on the possible relations between ``reference_sources`` and 500 | ``estimated_sources`` due to the dangers of a changing permutation. 501 | Therefore (by default), it assumes that ``reference_sources[i]`` 502 | corresponds to ``estimated_sources[i]``. To enable computing permutations 503 | please set ``compute_permutation`` to be ``True`` and check that the 504 | returned ``perm`` is identical for all windows. 505 | 506 | NOTE: if ``reference_sources`` and ``estimated_sources`` would be evaluated 507 | using only a single window or are shorter than the window length, the 508 | result of ``bss_eval_images`` called on ``reference_sources`` and 509 | ``estimated_sources`` (with the ``compute_permutation`` parameter passed to 510 | ``bss_eval_images``) is returned 511 | 512 | Examples 513 | -------- 514 | >>> # reference_sources[n] should be an ndarray of samples of the 515 | >>> # n'th reference source 516 | >>> # estimated_sources[n] should be the same for the n'th estimated 517 | >>> # source 518 | >>> (sdr, isr, sir, sar, 519 | ... perm) = mir_eval.separation.bss_eval_images_framewise( 520 | reference_sources, 521 | ... estimated_sources, 522 | window, 523 | .... 
hop)
524 | 
525 |     Parameters
526 |     ----------
527 |     reference_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
528 |         matrix containing true sources (must have the same shape as
529 |         ``estimated_sources``)
530 |     estimated_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
531 |         matrix containing estimated sources (must have the same shape as
532 |         ``reference_sources``)
533 |     window : int
534 |         Window length for framewise evaluation
535 |     hop : int
536 |         Hop size for framewise evaluation
537 |     compute_permutation : bool, optional
538 |         compute permutation of estimate/source combinations for all windows
539 |         (False by default)
540 | 
541 |     Returns
542 |     -------
543 |     sdr : np.ndarray, shape=(nsrc, nframes)
544 |         vector of Signal to Distortion Ratios (SDR)
545 |     isr : np.ndarray, shape=(nsrc, nframes)
546 |         vector of source Image to Spatial distortion Ratios (ISR)
547 |     sir : np.ndarray, shape=(nsrc, nframes)
548 |         vector of Source to Interference Ratios (SIR)
549 |     sar : np.ndarray, shape=(nsrc, nframes)
550 |         vector of Sources to Artifacts Ratios (SAR)
551 |     perm : np.ndarray, shape=(nsrc, nframes)
552 |         vector containing the best ordering of estimated sources in
553 |         the mean SIR sense (estimated source number perm[j] corresponds to
554 |         true source number j)
555 |         Note: perm will be range(nsrc) for all windows if compute_permutation
556 |         is False
557 | 
558 |     """
559 | 
560 |     # make sure the input has 3 dimensions
561 |     # assuming input is in shape (nsampl) or (nsrc, nsampl)
562 |     estimated_sources = np.atleast_3d(estimated_sources)
563 |     reference_sources = np.atleast_3d(reference_sources)
564 |     # we will ensure input doesn't have more than 3 dimensions in validate
565 | 
566 |     validate(reference_sources, estimated_sources)
567 |     # If empty matrices were supplied, return empty lists (special case)
568 |     if reference_sources.size == 0 or estimated_sources.size == 0:
569 |         return np.array([]), np.array([]), np.array([]), np.array([]), np.array([])
570 | 
571 |     nsrc = reference_sources.shape[0]
572 | 
573 |     nwin = int(
574 |         np.floor((reference_sources.shape[1] - window + hop) / hop)
575 |     )
576 |     # if fewer than 2 windows would be evaluated, return the images result
577 |     if nwin < 2:
578 |         result = bss_eval_images(reference_sources,
579 |                                  estimated_sources,
580 |                                  compute_permutation)
581 |         return [np.expand_dims(score, -1) for score in result]
582 | 
583 |     # compute the criteria across all windows
584 |     sdr = np.empty((nsrc, nwin))
585 |     isr = np.empty((nsrc, nwin))
586 |     sir = np.empty((nsrc, nwin))
587 |     sar = np.empty((nsrc, nwin))
588 |     perm = np.empty((nsrc, nwin))
589 | 
590 |     # k iterates across all the windows
591 |     for k in range(nwin):
592 |         win_slice = slice(k * hop, k * hop + window)
593 |         ref_slice = reference_sources[:, win_slice, :]
594 |         est_slice = estimated_sources[:, win_slice, :]
595 |         # check for a silent frame
596 |         if (not _any_source_silent(ref_slice) and
597 |                 not _any_source_silent(est_slice)):
598 |             sdr[:, k], isr[:, k], sir[:, k], sar[:, k], perm[:, k] = \
599 |                 bss_eval_images(
600 |                     ref_slice, est_slice, compute_permutation
601 |                 )
602 |         else:
603 |             # if we have a silent frame set results as np.nan
604 |             sdr[:, k] = isr[:, k] = sir[:, k] = sar[:, k] = perm[:, k] = np.nan
605 | 
606 |     return sdr, isr, sir, sar, perm
607 | 
608 | 
609 | def _bss_decomp_mtifilt(reference_sources, estimated_source, j, flen):
610 |     """Decomposition of an estimated source image into four components
611 |     representing respectively the true source image, spatial (or filtering)
612 |     distortion, interference and
artifacts, derived from the true source 613 | images using multichannel time-invariant filters. 614 | """ 615 | nsampl = estimated_source.size 616 | # decomposition 617 | # true source image 618 | s_true = np.hstack((reference_sources[j], np.zeros(flen - 1))) 619 | # spatial (or filtering) distortion 620 | e_spat = _project(reference_sources[j, np.newaxis, :], estimated_source, 621 | flen) - s_true 622 | # interference 623 | e_interf = _project(reference_sources, 624 | estimated_source, flen) - s_true - e_spat 625 | # artifacts 626 | e_artif = -s_true - e_spat - e_interf 627 | e_artif[:nsampl] += estimated_source 628 | return (s_true, e_spat, e_interf, e_artif) 629 | 630 | 631 | def _bss_decomp_mtifilt_images(reference_sources, estimated_source, j, flen, 632 | Gj=None, G=None): 633 | """Decomposition of an estimated source image into four components 634 | representing respectively the true source image, spatial (or filtering) 635 | distortion, interference and artifacts, derived from the true source 636 | images using multichannel time-invariant filters. 637 | Adapted version to work with multichannel sources. 638 | Improved performance can be gained by passing Gj and G parameters initially 639 | as all zeros. These parameters store the results from the computation of 640 | the G matrix in _project_images and then return them for subsequent calls 641 | to this function. This only works when not computing permuations. 642 | """ 643 | nsampl = np.shape(estimated_source)[0] 644 | nchan = np.shape(estimated_source)[1] 645 | # are we saving the Gj and G parameters? 646 | saveg = Gj is not None and G is not None 647 | # decomposition 648 | # true source image 649 | s_true = np.hstack((np.reshape(reference_sources[j], 650 | (nsampl, nchan), 651 | order="F").transpose(), 652 | np.zeros((nchan, flen - 1)))) 653 | # spatial (or filtering) distortion 654 | if saveg: 655 | e_spat, Gj = _project_images(reference_sources[j, np.newaxis, :], 656 | estimated_source, flen, Gj) 657 | else: 658 | e_spat = _project_images(reference_sources[j, np.newaxis, :], 659 | estimated_source, flen) 660 | e_spat = e_spat - s_true 661 | # interference 662 | if saveg: 663 | e_interf, G = _project_images(reference_sources, 664 | estimated_source, flen, G) 665 | else: 666 | e_interf = _project_images(reference_sources, 667 | estimated_source, flen) 668 | e_interf = e_interf - s_true - e_spat 669 | # artifacts 670 | e_artif = -s_true - e_spat - e_interf 671 | e_artif[:, :nsampl] += estimated_source.transpose() 672 | # return Gj and G only if they were passed in 673 | if saveg: 674 | return (s_true, e_spat, e_interf, e_artif, Gj, G) 675 | else: 676 | return (s_true, e_spat, e_interf, e_artif) 677 | 678 | 679 | def _project(reference_sources, estimated_source, flen): 680 | """Least-squares projection of estimated source on the subspace spanned by 681 | delayed versions of reference sources, with delays between 0 and flen-1 682 | """ 683 | nsrc = reference_sources.shape[0] 684 | nsampl = reference_sources.shape[1] 685 | 686 | # computing coefficients of least squares problem via FFT ## 687 | # zero padding and FFT of input data 688 | reference_sources = np.hstack((reference_sources, 689 | np.zeros((nsrc, flen - 1)))) 690 | estimated_source = np.hstack((estimated_source, np.zeros(flen - 1))) 691 | n_fft = int(2**np.ceil(np.log2(nsampl + flen - 1.))) 692 | sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1) 693 | sef = scipy.fftpack.fft(estimated_source, n=n_fft) 694 | # inner products between delayed versions of 
reference_sources 695 | G = np.zeros((nsrc * flen, nsrc * flen)) 696 | for i in range(nsrc): 697 | for j in range(nsrc): 698 | ssf = sf[i] * np.conj(sf[j]) 699 | ssf = np.real(scipy.fftpack.ifft(ssf)) 700 | ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), 701 | r=ssf[:flen]) 702 | G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss 703 | G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T 704 | # inner products between estimated_source and delayed versions of 705 | # reference_sources 706 | D = np.zeros(nsrc * flen) 707 | for i in range(nsrc): 708 | ssef = sf[i] * np.conj(sef) 709 | ssef = np.real(scipy.fftpack.ifft(ssef)) 710 | D[i * flen: (i+1) * flen] = np.hstack((ssef[0], ssef[-1:-flen:-1])) 711 | 712 | # Computing projection 713 | # Distortion filters 714 | try: 715 | C = np.linalg.solve(G, D).reshape(flen, nsrc, order='F') 716 | except np.linalg.linalg.LinAlgError: 717 | C = np.linalg.lstsq(G, D)[0].reshape(flen, nsrc, order='F') 718 | # Filtering 719 | sproj = np.zeros(nsampl + flen - 1) 720 | for i in range(nsrc): 721 | sproj += fftconvolve(C[:, i], reference_sources[i])[:nsampl + flen - 1] 722 | return sproj 723 | 724 | 725 | def _project_images(reference_sources, estimated_source, flen, G=None): 726 | """Least-squares projection of estimated source on the subspace spanned by 727 | delayed versions of reference sources, with delays between 0 and flen-1. 728 | Passing G as all zeros will populate the G matrix and return it so it can 729 | be passed into the next call to avoid recomputing G (this will only works 730 | if not computing permutations). 731 | """ 732 | nsrc = reference_sources.shape[0] 733 | nsampl = reference_sources.shape[1] 734 | nchan = reference_sources.shape[2] 735 | reference_sources = np.reshape(np.transpose(reference_sources, (2, 0, 1)), 736 | (nchan*nsrc, nsampl), order='F') 737 | 738 | # computing coefficients of least squares problem via FFT ## 739 | # zero padding and FFT of input data 740 | reference_sources = np.hstack((reference_sources, 741 | np.zeros((nchan*nsrc, flen - 1)))) 742 | estimated_source = \ 743 | np.hstack((estimated_source.transpose(), np.zeros((nchan, flen - 1)))) 744 | n_fft = int(2**np.ceil(np.log2(nsampl + flen - 1.))) 745 | sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1) 746 | sef = scipy.fftpack.fft(estimated_source, n=n_fft) 747 | 748 | # inner products between delayed versions of reference_sources 749 | if G is None: 750 | saveg = False 751 | G = np.zeros((nchan * nsrc * flen, nchan * nsrc * flen)) 752 | for i in range(nchan * nsrc): 753 | for j in range(i+1): 754 | ssf = sf[i] * np.conj(sf[j]) 755 | ssf = np.real(scipy.fftpack.ifft(ssf)) 756 | ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), 757 | r=ssf[:flen]) 758 | G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss 759 | G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T 760 | else: # avoid recomputing G (only works if no permutation is desired) 761 | saveg = True # return G 762 | if np.all(G == 0): # only compute G if passed as 0 763 | G = np.zeros((nchan * nsrc * flen, nchan * nsrc * flen)) 764 | for i in range(nchan * nsrc): 765 | for j in range(i+1): 766 | ssf = sf[i] * np.conj(sf[j]) 767 | ssf = np.real(scipy.fftpack.ifft(ssf)) 768 | ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), 769 | r=ssf[:flen]) 770 | G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss 771 | G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T 772 | 773 | # inner products between estimated_source and delayed versions of 774 | # reference_sources 775 | D 
= np.zeros((nchan * nsrc * flen, nchan)) 776 | for k in range(nchan * nsrc): 777 | for i in range(nchan): 778 | ssef = sf[k] * np.conj(sef[i]) 779 | ssef = np.real(scipy.fftpack.ifft(ssef)) 780 | D[k * flen: (k+1) * flen, i] = \ 781 | np.hstack((ssef[0], ssef[-1:-flen:-1])).transpose() 782 | 783 | # Computing projection 784 | # Distortion filters 785 | try: 786 | C = np.linalg.solve(G, D).reshape(flen, nchan*nsrc, nchan, order='F') 787 | except np.linalg.linalg.LinAlgError: 788 | C = np.linalg.lstsq(G, D)[0].reshape(flen, nchan*nsrc, nchan, 789 | order='F') 790 | # Filtering 791 | sproj = np.zeros((nchan, nsampl + flen - 1)) 792 | for k in range(nchan * nsrc): 793 | for i in range(nchan): 794 | sproj[i] += fftconvolve(C[:, k, i].transpose(), 795 | reference_sources[k])[:nsampl + flen - 1] 796 | # return G only if it was passed in 797 | if saveg: 798 | return sproj, G 799 | else: 800 | return sproj 801 | 802 | 803 | def _bss_source_crit(s_true, e_spat, e_interf, e_artif): 804 | """Measurement of the separation quality for a given source in terms of 805 | filtered true source, interference and artifacts. 806 | """ 807 | # energy ratios 808 | s_filt = s_true + e_spat 809 | sdr = _safe_db(np.sum(s_filt**2), np.sum((e_interf + e_artif)**2)) 810 | sir = _safe_db(np.sum(s_filt**2), np.sum(e_interf**2)) 811 | sar = _safe_db(np.sum((s_filt + e_interf)**2), np.sum(e_artif**2)) 812 | return (sdr, sir, sar) 813 | 814 | 815 | def _bss_image_crit(s_true, e_spat, e_interf, e_artif): 816 | """Measurement of the separation quality for a given image in terms of 817 | filtered true source, spatial error, interference and artifacts. 818 | """ 819 | # energy ratios 820 | sdr = _safe_db(np.sum(s_true**2), np.sum((e_spat+e_interf+e_artif)**2)) 821 | isr = _safe_db(np.sum(s_true**2), np.sum(e_spat**2)) 822 | sir = _safe_db(np.sum((s_true+e_spat)**2), np.sum(e_interf**2)) 823 | sar = _safe_db(np.sum((s_true+e_spat+e_interf)**2), np.sum(e_artif**2)) 824 | return (sdr, isr, sir, sar) 825 | 826 | 827 | def _safe_db(num, den): 828 | """Properly handle the potential +Inf db SIR, instead of raising a 829 | RuntimeWarning. Only denominator is checked because the numerator can never 830 | be 0. 831 | """ 832 | if den == 0: 833 | return np.Inf 834 | return 10 * np.log10(num / den) 835 | 836 | 837 | def evaluate(reference_sources, estimated_sources, **kwargs): 838 | """Compute all metrics for the given reference and estimated signals. 839 | 840 | NOTE: This will always compute :func:`mir_eval.separation.bss_eval_images` 841 | for any valid input and will additionally compute 842 | :func:`mir_eval.separation.bss_eval_sources` for valid input with fewer 843 | than 3 dimensions. 844 | 845 | Examples 846 | -------- 847 | >>> # reference_sources[n] should be an ndarray of samples of the 848 | >>> # n'th reference source 849 | >>> # estimated_sources[n] should be the same for the n'th estimated source 850 | >>> scores = mir_eval.separation.evaluate(reference_sources, 851 | ... estimated_sources) 852 | 853 | Parameters 854 | ---------- 855 | reference_sources : np.ndarray, shape=(nsrc, nsampl[, nchan]) 856 | matrix containing true sources 857 | estimated_sources : np.ndarray, shape=(nsrc, nsampl[, nchan]) 858 | matrix containing estimated sources 859 | kwargs 860 | Additional keyword arguments which will be passed to the 861 | appropriate metric or preprocessing functions. 
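        For example, ``window`` and ``hop`` would be forwarded only to the
        framewise metrics, assuming ``filter_kwargs`` drops any keyword
        argument that a given metric's signature does not accept.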
862 | 863 | Returns 864 | ------- 865 | scores : dict 866 | Dictionary of scores, where the key is the metric name (str) and 867 | the value is the (float) score achieved. 868 | 869 | """ 870 | # Compute all the metrics 871 | scores = collections.OrderedDict() 872 | 873 | sdr, isr, sir, sar, perm = util.filter_kwargs( 874 | bss_eval_images, 875 | reference_sources, 876 | estimated_sources, 877 | **kwargs 878 | ) 879 | scores['Images - Source to Distortion'] = sdr.tolist() 880 | scores['Images - Image to Spatial'] = isr.tolist() 881 | scores['Images - Source to Interference'] = sir.tolist() 882 | scores['Images - Source to Artifact'] = sar.tolist() 883 | scores['Images - Source permutation'] = perm.tolist() 884 | 885 | sdr, isr, sir, sar, perm = util.filter_kwargs( 886 | bss_eval_images_framewise, 887 | reference_sources, 888 | estimated_sources, 889 | **kwargs 890 | ) 891 | scores['Images Frames - Source to Distortion'] = sdr.tolist() 892 | scores['Images Frames - Image to Spatial'] = isr.tolist() 893 | scores['Images Frames - Source to Interference'] = sir.tolist() 894 | scores['Images Frames - Source to Artifact'] = sar.tolist() 895 | scores['Images Frames - Source permutation'] = perm.tolist() 896 | 897 | # Verify we can compute sources on this input 898 | if reference_sources.ndim < 3 and estimated_sources.ndim < 3: 899 | sdr, sir, sar, perm = util.filter_kwargs( 900 | bss_eval_sources_framewise, 901 | reference_sources, 902 | estimated_sources, 903 | **kwargs 904 | ) 905 | scores['Sources Frames - Source to Distortion'] = sdr.tolist() 906 | scores['Sources Frames - Source to Interference'] = sir.tolist() 907 | scores['Sources Frames - Source to Artifact'] = sar.tolist() 908 | scores['Sources Frames - Source permutation'] = perm.tolist() 909 | 910 | sdr, sir, sar, perm = util.filter_kwargs( 911 | bss_eval_sources, 912 | reference_sources, 913 | estimated_sources, 914 | **kwargs 915 | ) 916 | scores['Sources - Source to Distortion'] = sdr.tolist() 917 | scores['Sources - Source to Interference'] = sir.tolist() 918 | scores['Sources - Source to Artifact'] = sar.tolist() 919 | scores['Sources - Source permutation'] = perm.tolist() 920 | 921 | return scores 922 | -------------------------------------------------------------------------------- /mir_util.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This submodule collects useful functionality required across the task 3 | submodules, such as preprocessing, validation, and common computations. 4 | ''' 5 | 6 | import os 7 | import inspect 8 | import six 9 | 10 | import numpy as np 11 | 12 | 13 | def index_labels(labels, case_sensitive=False): 14 | """Convert a list of string identifiers into numerical indices. 15 | 16 | Parameters 17 | ---------- 18 | labels : list of strings, shape=(n,) 19 | A list of annotations, e.g., segment or chord labels from an 20 | annotation file. 21 | 22 | case_sensitive : bool 23 | Set to True to enable case-sensitive label indexing 24 | (Default value = False) 25 | 26 | Returns 27 | ------- 28 | indices : list, shape=(n,) 29 | Numerical representation of ``labels`` 30 | index_to_label : dict 31 | Mapping to convert numerical indices back to labels. 
32 | ``labels[i] == index_to_label[indices[i]]`` 33 | 34 | """ 35 | 36 | label_to_index = {} 37 | index_to_label = {} 38 | 39 | # If we're not case-sensitive, 40 | if not case_sensitive: 41 | labels = [str(s).lower() for s in labels] 42 | 43 | # First, build the unique label mapping 44 | for index, s in enumerate(sorted(set(labels))): 45 | label_to_index[s] = index 46 | index_to_label[index] = s 47 | 48 | # Remap the labels to indices 49 | indices = [label_to_index[s] for s in labels] 50 | 51 | # Return the converted labels, and the inverse mapping 52 | return indices, index_to_label 53 | 54 | 55 | def generate_labels(items, prefix='__'): 56 | """Given an array of items (e.g. events, intervals), create a synthetic label 57 | for each event of the form '(label prefix)(item number)' 58 | 59 | Parameters 60 | ---------- 61 | items : list-like 62 | A list or array of events or intervals 63 | prefix : str 64 | This prefix will be prepended to all synthetically generated labels 65 | (Default value = '__') 66 | 67 | Returns 68 | ------- 69 | labels : list of str 70 | Synthetically generated labels 71 | 72 | """ 73 | return ['{}{}'.format(prefix, n) for n in range(len(items))] 74 | 75 | 76 | def intervals_to_samples(intervals, labels, offset=0, sample_size=0.1, 77 | fill_value=None): 78 | """Convert an array of labeled time intervals to annotated samples. 79 | 80 | Parameters 81 | ---------- 82 | intervals : np.ndarray, shape=(n, d) 83 | An array of time intervals, as returned by 84 | :func:`mir_eval.io.load_intervals()` or 85 | :func:`mir_eval.io.load_labeled_intervals()`. 86 | The ``i`` th interval spans time ``intervals[i, 0]`` to 87 | ``intervals[i, 1]``. 88 | 89 | labels : list, shape=(n,) 90 | The annotation for each interval 91 | 92 | offset : float > 0 93 | Phase offset of the sampled time grid (in seconds) 94 | (Default value = 0) 95 | 96 | sample_size : float > 0 97 | duration of each sample to be generated (in seconds) 98 | (Default value = 0.1) 99 | 100 | fill_value : type(labels[0]) 101 | Object to use for the label with out-of-range time points. 102 | (Default value = None) 103 | 104 | Returns 105 | ------- 106 | sample_times : list 107 | list of sample times 108 | 109 | sample_labels : list 110 | array of labels for each generated sample 111 | 112 | Notes 113 | ----- 114 | Intervals will be rounded down to the nearest multiple 115 | of ``sample_size``. 116 | 117 | """ 118 | 119 | # Round intervals to the sample size 120 | num_samples = int(np.floor(intervals.max() / sample_size)) 121 | sample_indices = np.arange(num_samples, dtype=np.float32) 122 | sample_times = (sample_indices*sample_size + offset).tolist() 123 | sampled_labels = interpolate_intervals( 124 | intervals, labels, sample_times, fill_value) 125 | 126 | return sample_times, sampled_labels 127 | 128 | 129 | def interpolate_intervals(intervals, labels, time_points, fill_value=None): 130 | """Assign labels to a set of points in time given a set of intervals. 131 | 132 | Time points that do not lie within an interval are mapped to `fill_value`. 133 | 134 | Parameters 135 | ---------- 136 | intervals : np.ndarray, shape=(n, 2) 137 | An array of time intervals, as returned by 138 | :func:`mir_eval.io.load_intervals()`. 139 | The ``i`` th interval spans time ``intervals[i, 0]`` to 140 | ``intervals[i, 1]``. 141 | 142 | Intervals are assumed to be disjoint. 143 | 144 | labels : list, shape=(n,) 145 | The annotation for each interval 146 | 147 | time_points : array_like, shape=(m,) 148 | Points in time to assign labels. 
These must be in 149 | non-decreasing order. 150 | 151 | fill_value : type(labels[0]) 152 | Object to use for the label with out-of-range time points. 153 | (Default value = None) 154 | 155 | Returns 156 | ------- 157 | aligned_labels : list 158 | Labels corresponding to the given time points. 159 | 160 | Raises 161 | ------ 162 | ValueError 163 | If `time_points` is not in non-decreasing order. 164 | """ 165 | 166 | # Verify that time_points is sorted 167 | time_points = np.asarray(time_points) 168 | 169 | if np.any(time_points[1:] < time_points[:-1]): 170 | raise ValueError('time_points must be in non-decreasing order') 171 | 172 | aligned_labels = [fill_value] * len(time_points) 173 | 174 | starts = np.searchsorted(time_points, intervals[:, 0], side='left') 175 | ends = np.searchsorted(time_points, intervals[:, 1], side='right') 176 | 177 | for (start, end, lab) in zip(starts, ends, labels): 178 | aligned_labels[start:end] = [lab] * (end - start) 179 | 180 | return aligned_labels 181 | 182 | 183 | def sort_labeled_intervals(intervals, labels=None): 184 | '''Sort intervals, and optionally, their corresponding labels 185 | according to start time. 186 | 187 | Parameters 188 | ---------- 189 | intervals : np.ndarray, shape=(n, 2) 190 | The input intervals 191 | 192 | labels : list, optional 193 | Labels for each interval 194 | 195 | Returns 196 | ------- 197 | intervals_sorted or (intervals_sorted, labels_sorted) 198 | Labels are only returned if provided as input 199 | ''' 200 | 201 | idx = np.argsort(intervals[:, 0]) 202 | 203 | intervals_sorted = intervals[idx] 204 | 205 | if labels is None: 206 | return intervals_sorted 207 | else: 208 | return intervals_sorted, [labels[_] for _ in idx] 209 | 210 | 211 | def f_measure(precision, recall, beta=1.0): 212 | """Compute the f-measure from precision and recall scores. 213 | 214 | Parameters 215 | ---------- 216 | precision : float in (0, 1] 217 | Precision 218 | recall : float in (0, 1] 219 | Recall 220 | beta : float > 0 221 | Weighting factor for f-measure 222 | (Default value = 1.0) 223 | 224 | Returns 225 | ------- 226 | f_measure : float 227 | The weighted f-measure 228 | 229 | """ 230 | 231 | if precision == 0 and recall == 0: 232 | return 0.0 233 | 234 | return (1 + beta**2)*precision*recall/((beta**2)*precision + recall) 235 | 236 | 237 | def intervals_to_boundaries(intervals, q=5): 238 | """Convert interval times into boundaries. 239 | 240 | Parameters 241 | ---------- 242 | intervals : np.ndarray, shape=(n_events, 2) 243 | Array of interval start and end-times 244 | q : int 245 | Number of decimals to round to. (Default value = 5) 246 | 247 | Returns 248 | ------- 249 | boundaries : np.ndarray 250 | Interval boundary times, including the end of the final interval 251 | 252 | """ 253 | 254 | return np.unique(np.ravel(np.round(intervals, decimals=q))) 255 | 256 | 257 | def boundaries_to_intervals(boundaries): 258 | """Convert an array of event times into intervals 259 | 260 | Parameters 261 | ---------- 262 | boundaries : list-like 263 | List-like of event times. These are assumed to be unique 264 | timestamps in ascending order. 
265 | 266 | Returns 267 | ------- 268 | intervals : np.ndarray, shape=(n_intervals, 2) 269 | Start and end time for each interval 270 | """ 271 | 272 | if not np.allclose(boundaries, np.unique(boundaries)): 273 | raise ValueError('Boundary times are not unique or not ascending.') 274 | 275 | intervals = np.asarray(list(zip(boundaries[:-1], boundaries[1:]))) 276 | 277 | return intervals 278 | 279 | 280 | def adjust_intervals(intervals, 281 | labels=None, 282 | t_min=0.0, 283 | t_max=None, 284 | start_label='__T_MIN', 285 | end_label='__T_MAX'): 286 | """Adjust a list of time intervals to span the range ``[t_min, t_max]``. 287 | 288 | Any intervals lying completely outside the specified range will be removed. 289 | 290 | Any intervals lying partially outside the specified range will be cropped. 291 | 292 | If the specified range exceeds the span of the provided data in either 293 | direction, additional intervals will be appended. If an interval is 294 | appended at the beginning, it will be given the label ``start_label``; if 295 | an interval is appended at the end, it will be given the label 296 | ``end_label``. 297 | 298 | Parameters 299 | ---------- 300 | intervals : np.ndarray, shape=(n_events, 2) 301 | Array of interval start and end-times 302 | labels : list, len=n_events or None 303 | List of labels 304 | (Default value = None) 305 | t_min : float or None 306 | Minimum interval start time. 307 | (Default value = 0.0) 308 | t_max : float or None 309 | Maximum interval end time. 310 | (Default value = None) 311 | start_label : str or float or int 312 | Label to give any intervals appended at the beginning 313 | (Default value = '__T_MIN') 314 | end_label : str or float or int 315 | Label to give any intervals appended at the end 316 | (Default value = '__T_MAX') 317 | 318 | Returns 319 | ------- 320 | new_intervals : np.ndarray 321 | Intervals spanning ``[t_min, t_max]`` 322 | new_labels : list 323 | List of labels for ``new_labels`` 324 | 325 | """ 326 | 327 | # When supplied intervals are empty and t_max and t_min are supplied, 328 | # create one interval from t_min to t_max with the label start_label 329 | if t_min is not None and t_max is not None and intervals.size == 0: 330 | return np.array([[t_min, t_max]]), [start_label] 331 | # When intervals are empty and either t_min or t_max are not supplied, 332 | # we can't append new intervals 333 | elif (t_min is None or t_max is None) and intervals.size == 0: 334 | raise ValueError("Supplied intervals are empty, can't append new" 335 | " intervals") 336 | 337 | if t_min is not None: 338 | # Find the intervals that end at or after t_min 339 | first_idx = np.argwhere(intervals[:, 1] >= t_min) 340 | 341 | if len(first_idx) > 0: 342 | # If we have events below t_min, crop them out 343 | if labels is not None: 344 | labels = labels[int(first_idx[0]):] 345 | # Clip to the range (t_min, +inf) 346 | intervals = intervals[int(first_idx[0]):] 347 | intervals = np.maximum(t_min, intervals) 348 | 349 | if intervals.min() > t_min: 350 | # Lowest boundary is higher than t_min: 351 | # add a new boundary and label 352 | intervals = np.vstack(([t_min, intervals.min()], intervals)) 353 | if labels is not None: 354 | labels.insert(0, start_label) 355 | 356 | if t_max is not None: 357 | # Find the intervals that begin after t_max 358 | last_idx = np.argwhere(intervals[:, 0] > t_max) 359 | 360 | if len(last_idx) > 0: 361 | # We have boundaries above t_max. 
362 | # Trim to only boundaries <= t_max 363 | if labels is not None: 364 | labels = labels[:int(last_idx[0])] 365 | # Clip to the range (-inf, t_max) 366 | intervals = intervals[:int(last_idx[0])] 367 | 368 | intervals = np.minimum(t_max, intervals) 369 | 370 | if intervals.max() < t_max: 371 | # Last boundary is below t_max: add a new boundary and label 372 | intervals = np.vstack((intervals, [intervals.max(), t_max])) 373 | if labels is not None: 374 | labels.append(end_label) 375 | 376 | return intervals, labels 377 | 378 | 379 | def adjust_events(events, labels=None, t_min=0.0, 380 | t_max=None, label_prefix='__'): 381 | """Adjust the given list of event times to span the range 382 | ``[t_min, t_max]``. 383 | 384 | Any event times outside of the specified range will be removed. 385 | 386 | If the times do not span ``[t_min, t_max]``, additional events will be 387 | added with the prefix ``label_prefix``. 388 | 389 | Parameters 390 | ---------- 391 | events : np.ndarray 392 | Array of event times (seconds) 393 | labels : list or None 394 | List of labels 395 | (Default value = None) 396 | t_min : float or None 397 | Minimum valid event time. 398 | (Default value = 0.0) 399 | t_max : float or None 400 | Maximum valid event time. 401 | (Default value = None) 402 | label_prefix : str 403 | Prefix string to use for synthetic labels 404 | (Default value = '__') 405 | 406 | Returns 407 | ------- 408 | new_times : np.ndarray 409 | Event times corrected to the given range. 410 | 411 | """ 412 | if t_min is not None: 413 | first_idx = np.argwhere(events >= t_min) 414 | 415 | if len(first_idx) > 0: 416 | # We have events below t_min 417 | # Crop them out 418 | if labels is not None: 419 | labels = labels[int(first_idx[0]):] 420 | events = events[int(first_idx[0]):] 421 | 422 | if events[0] > t_min: 423 | # Lowest boundary is higher than t_min: 424 | # add a new boundary and label 425 | events = np.concatenate(([t_min], events)) 426 | if labels is not None: 427 | labels.insert(0, '%sT_MIN' % label_prefix) 428 | 429 | if t_max is not None: 430 | last_idx = np.argwhere(events > t_max) 431 | 432 | if len(last_idx) > 0: 433 | # We have boundaries above t_max. 434 | # Trim to only boundaries <= t_max 435 | if labels is not None: 436 | labels = labels[:int(last_idx[0])] 437 | events = events[:int(last_idx[0])] 438 | 439 | if events[-1] < t_max: 440 | # Last boundary is below t_max: add a new boundary and label 441 | events = np.concatenate((events, [t_max])) 442 | if labels is not None: 443 | labels.append('%sT_MAX' % label_prefix) 444 | 445 | return events, labels 446 | 447 | 448 | def intersect_files(flist1, flist2): 449 | """Return the intersection of two sets of filepaths, based on the file name 450 | (after the final '/') and ignoring the file extension. 
451 | 
452 |     Examples
453 |     --------
454 |     >>> flist1 = ['/a/b/abc.lab', '/c/d/123.lab', '/e/f/xyz.lab']
455 |     >>> flist2 = ['/g/h/xyz.npy', '/i/j/123.txt', '/k/l/456.lab']
456 |     >>> sublist1, sublist2 = mir_eval.util.intersect_files(flist1, flist2)
457 |     >>> print(sublist1)
458 |     ['/e/f/xyz.lab', '/c/d/123.lab']
459 |     >>> print(sublist2)
460 |     ['/g/h/xyz.npy', '/i/j/123.txt']
461 | 
462 |     Parameters
463 |     ----------
464 |     flist1 : list
465 |         first list of filepaths
466 |     flist2 : list
467 |         second list of filepaths
468 | 
469 |     Returns
470 |     -------
471 |     sublist1 : list
472 |         subset of filepaths with matching stems from ``flist1``
473 |     sublist2 : list
474 |         corresponding filepaths from ``flist2``
475 | 
476 |     """
477 |     def fname(abs_path):
478 |         """Return the base filename, without extension, for a path.
479 | 
480 |         Parameters
481 |         ----------
482 |         abs_path : str
483 |             Path to a file.
484 | 
485 |         Returns
486 |         -------
487 |         str
488 |         """
489 |         return os.path.splitext(os.path.split(abs_path)[-1])[0]
490 | 
491 |     fmap = {fname(f): f for f in flist1}
492 |     pairs = [list(), list()]
493 |     for f in flist2:
494 |         if fname(f) in fmap:
495 |             pairs[0].append(fmap[fname(f)])
496 |             pairs[1].append(f)
497 | 
498 |     return pairs
499 | 
500 | 
501 | def merge_labeled_intervals(x_intervals, x_labels, y_intervals, y_labels):
502 |     r"""Merge the time intervals of two sequences.
503 | 
504 |     Parameters
505 |     ----------
506 |     x_intervals : np.ndarray
507 |         Array of interval times (seconds)
508 |     x_labels : list or None
509 |         List of labels
510 |     y_intervals : np.ndarray
511 |         Array of interval times (seconds)
512 |     y_labels : list or None
513 |         List of labels
514 | 
515 |     Returns
516 |     -------
517 |     new_intervals : np.ndarray
518 |         New interval times of the merged sequences.
519 |     new_x_labels : list
520 |         New labels for the sequence ``x``
521 |     new_y_labels : list
522 |         New labels for the sequence ``y``
523 | 
524 |     """
525 |     align_check = [x_intervals[0, 0] == y_intervals[0, 0],
526 |                    x_intervals[-1, 1] == y_intervals[-1, 1]]
527 |     if not all(align_check):
528 |         raise ValueError(
529 |             "Time intervals do not align; did you mean to call "
530 |             "'adjust_intervals()' first?")
531 |     time_boundaries = np.unique(
532 |         np.concatenate([x_intervals, y_intervals], axis=0))
533 |     output_intervals = np.array(
534 |         [time_boundaries[:-1], time_boundaries[1:]]).T
535 | 
536 |     x_labels_out, y_labels_out = [], []
537 |     x_label_range = np.arange(len(x_labels))
538 |     y_label_range = np.arange(len(y_labels))
539 |     for t0, _ in output_intervals:
540 |         x_idx = x_label_range[(t0 >= x_intervals[:, 0])]
541 |         x_labels_out.append(x_labels[x_idx[-1]])
542 |         y_idx = y_label_range[(t0 >= y_intervals[:, 0])]
543 |         y_labels_out.append(y_labels[y_idx[-1]])
544 |     return output_intervals, x_labels_out, y_labels_out
545 | 
546 | 
547 | def _bipartite_match(graph):
548 |     """Find maximum cardinality matching of a bipartite graph (U,V,E).
549 |     The input format is a dictionary mapping members of U to a list
550 |     of their neighbors in V.
551 | 
552 |     The output is a dict M mapping members of V to their matches in U.
553 | 
554 |     Parameters
555 |     ----------
556 |     graph : dictionary : left-vertex -> list of right vertices
557 |         The input bipartite graph.  Each edge need only be specified once.
558 | 
559 |     Returns
560 |     -------
561 |     matching : dictionary : right-vertex -> left vertex
562 |         A maximum-cardinality bipartite matching.
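
    For example, both members of V can be matched below, and the maximum
    matching is unique:

    >>> graph = {'a': [1], 'b': [1, 2]}
    >>> sorted(_bipartite_match(graph).items())
    [(1, 'a'), (2, 'b')]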
563 | 564 | """ 565 | # Adapted from: 566 | # 567 | # Hopcroft-Karp bipartite max-cardinality matching and max independent set 568 | # David Eppstein, UC Irvine, 27 Apr 2002 569 | 570 | # initialize greedy matching (redundant, but faster than full search) 571 | matching = {} 572 | for u in graph: 573 | for v in graph[u]: 574 | if v not in matching: 575 | matching[v] = u 576 | break 577 | 578 | while True: 579 | # structure residual graph into layers 580 | # pred[u] gives the neighbor in the previous layer for u in U 581 | # preds[v] gives a list of neighbors in the previous layer for v in V 582 | # unmatched gives a list of unmatched vertices in final layer of V, 583 | # and is also used as a flag value for pred[u] when u is in the first 584 | # layer 585 | preds = {} 586 | unmatched = [] 587 | pred = dict([(u, unmatched) for u in graph]) 588 | for v in matching: 589 | del pred[matching[v]] 590 | layer = list(pred) 591 | 592 | # repeatedly extend layering structure by another pair of layers 593 | while layer and not unmatched: 594 | new_layer = {} 595 | for u in layer: 596 | for v in graph[u]: 597 | if v not in preds: 598 | new_layer.setdefault(v, []).append(u) 599 | layer = [] 600 | for v in new_layer: 601 | preds[v] = new_layer[v] 602 | if v in matching: 603 | layer.append(matching[v]) 604 | pred[matching[v]] = v 605 | else: 606 | unmatched.append(v) 607 | 608 | # did we finish layering without finding any alternating paths? 609 | if not unmatched: 610 | unlayered = {} 611 | for u in graph: 612 | for v in graph[u]: 613 | if v not in preds: 614 | unlayered[v] = None 615 | return matching 616 | 617 | def recurse(v): 618 | """Recursively search backward through layers to find alternating 619 | paths. recursion returns true if found path, false otherwise 620 | """ 621 | if v in preds: 622 | L = preds[v] 623 | del preds[v] 624 | for u in L: 625 | if u in pred: 626 | pu = pred[u] 627 | del pred[u] 628 | if pu is unmatched or recurse(pu): 629 | matching[v] = u 630 | return True 631 | return False 632 | 633 | for v in unmatched: 634 | recurse(v) 635 | 636 | 637 | def _outer_distance_mod_n(ref, est, modulus=12): 638 | """Compute the absolute outer distance modulo n. 639 | Using this distance, d(11, 0) = 1 (modulo 12) 640 | 641 | Parameters 642 | ---------- 643 | ref : np.ndarray, shape=(n,) 644 | Array of reference values. 645 | est : np.ndarray, shape=(m,) 646 | Array of estimated values. 647 | modulus : int 648 | The modulus. 649 | 12 by default for octave equivalence. 650 | 651 | Returns 652 | ------- 653 | outer_distance : np.ndarray, shape=(n, m) 654 | The outer circular distance modulo n. 655 | 656 | """ 657 | ref_mod_n = np.mod(ref, modulus) 658 | est_mod_n = np.mod(est, modulus) 659 | abs_diff = np.abs(np.subtract.outer(ref_mod_n, est_mod_n)) 660 | return np.minimum(abs_diff, modulus - abs_diff) 661 | 662 | 663 | def match_events(ref, est, window, distance=None): 664 | """Compute a maximum matching between reference and estimated event times, 665 | subject to a window constraint. 666 | 667 | Given two lists of event times ``ref`` and ``est``, we seek the largest set 668 | of correspondences ``(ref[i], est[j])`` such that 669 | ``distance(ref[i], est[j]) <= window``, and each 670 | ``ref[i]`` and ``est[j]`` is matched at most once. 671 | 672 | This is useful for computing precision/recall metrics in beat tracking, 673 | onset detection, and segmentation. 
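
    For example, with a 0.2 second tolerance window every estimated event
    below is matched to a distinct reference event (indices are returned,
    not times):

    >>> ref = np.array([0.0, 1.0, 2.0])
    >>> est = np.array([0.05, 1.15, 1.9])
    >>> match_events(ref, est, 0.2)
    [(0, 0), (1, 1), (2, 2)]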
674 | 
675 |     Parameters
676 |     ----------
677 |     ref : np.ndarray, shape=(n,)
678 |         Array of reference values
679 |     est : np.ndarray, shape=(m,)
680 |         Array of estimated values
681 |     window : float > 0
682 |         Size of the window.
683 |     distance : function
684 |         function that computes the outer distance of ref and est.
685 |         By default uses ``|ref[i] - est[j]|``
686 | 
687 |     Returns
688 |     -------
689 |     matching : list of tuples
690 |         A list of matched reference and estimated event indices.
691 |         ``matching[i] == (i, j)`` where ``ref[i]`` matches ``est[j]``.
692 | 
693 |     """
694 |     if distance is not None:
695 |         # Compute the indices of feasible pairings
696 |         hits = np.where(distance(ref, est) <= window)
697 |     else:
698 |         hits = _fast_hit_windows(ref, est, window)
699 | 
700 |     # Construct the graph input
701 |     G = {}
702 |     for ref_i, est_i in zip(*hits):
703 |         if est_i not in G:
704 |             G[est_i] = []
705 |         G[est_i].append(ref_i)
706 | 
707 |     # Compute the maximum matching
708 |     matching = sorted(_bipartite_match(G).items())
709 | 
710 |     return matching
711 | 
712 | 
713 | def _fast_hit_windows(ref, est, window):
714 |     '''Fast calculation of windowed hits for time events.
715 | 
716 |     Given two lists of event times ``ref`` and ``est``, and a
717 |     tolerance window, computes a list of pairings
718 |     ``(i, j)`` where ``|ref[i] - est[j]| <= window``.
719 | 
720 |     This is equivalent to, but more efficient than the following:
721 | 
722 |     >>> hit_ref, hit_est = np.where(np.abs(np.subtract.outer(ref, est))
723 |     ...                             <= window)
724 | 
725 |     Parameters
726 |     ----------
727 |     ref : np.ndarray, shape=(n,)
728 |         Array of reference values
729 |     est : np.ndarray, shape=(m,)
730 |         Array of estimated values
731 |     window : float >= 0
732 |         Size of the tolerance window
733 | 
734 |     Returns
735 |     -------
736 |     hit_ref : list
737 |     hit_est : list
738 |         indices such that ``|ref[hit_ref[i]] - est[hit_est[i]]| <= window``
739 |     '''
740 | 
741 |     ref = np.asarray(ref)
742 |     est = np.asarray(est)
743 |     ref_idx = np.argsort(ref)
744 |     ref_sorted = ref[ref_idx]
745 | 
746 |     left_idx = np.searchsorted(ref_sorted, est - window, side='left')
747 |     right_idx = np.searchsorted(ref_sorted, est + window, side='right')
748 | 
749 |     hit_ref, hit_est = [], []
750 | 
751 |     for j, (start, end) in enumerate(zip(left_idx, right_idx)):
752 |         hit_ref.extend(ref_idx[start:end])
753 |         hit_est.extend([j] * (end - start))
754 | 
755 |     return hit_ref, hit_est
756 | 
757 | 
758 | def validate_intervals(intervals):
759 |     """Checks that an (n, 2) interval ndarray is well-formed, and raises errors
760 |     if not.
761 | 
762 |     Parameters
763 |     ----------
764 |     intervals : np.ndarray, shape=(n, 2)
765 |         Array of interval start/end locations.
766 | 
767 |     """
768 | 
769 |     # Validate interval shape
770 |     if intervals.ndim != 2 or intervals.shape[1] != 2:
771 |         raise ValueError('Intervals should be n-by-2 numpy ndarray, '
772 |                          'but shape={}'.format(intervals.shape))
773 | 
774 |     # Make sure no times are negative
775 |     if (intervals < 0).any():
776 |         raise ValueError('Negative interval times found')
777 | 
778 |     # Make sure all intervals have strictly positive duration
779 |     if (intervals[:, 1] <= intervals[:, 0]).any():
780 |         raise ValueError('All interval durations must be strictly positive')
781 | 
782 | 
783 | def validate_events(events, max_time=30000.):
784 |     """Checks that a 1-d event location ndarray is well-formed, and raises
785 |     errors if not.
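
    For example, a sorted, non-negative event array passes silently:

    >>> validate_events(np.array([0.5, 1.0, 2.0]))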
786 | 
787 |     Parameters
788 |     ----------
789 |     events : np.ndarray, shape=(n,)
790 |         Array of event times
791 |     max_time : float
792 |         If an event is found above this time, a ValueError will be raised.
793 |         (Default value = 30000.)
794 | 
795 |     """
796 |     # Make sure no event times are huge
797 |     if (events > max_time).any():
798 |         raise ValueError('An event at time {} was found which is greater than '
799 |                          'the maximum allowable time of max_time = {} (did you'
800 |                          ' supply event times in '
801 |                          'seconds?)'.format(events.max(), max_time))
802 |     # Make sure event locations are 1-d np ndarrays
803 |     if events.ndim != 1:
804 |         raise ValueError('Event times should be 1-d numpy ndarray, '
805 |                          'but shape={}'.format(events.shape))
806 |     # Make sure event times are increasing
807 |     if (np.diff(events) < 0).any():
808 |         raise ValueError('Events should be in increasing order.')
809 | 
810 | 
811 | def validate_frequencies(frequencies, max_freq, min_freq,
812 |                          allow_negatives=False):
813 |     """Checks that a 1-d frequency ndarray is well-formed, and raises
814 |     errors if not.
815 | 
816 |     Parameters
817 |     ----------
818 |     frequencies : np.ndarray, shape=(n,)
819 |         Array of frequency values
820 |     max_freq : float
821 |         If a frequency is found above this pitch, a ValueError will be raised.
822 |         (A typical value is 5000.)
823 |     min_freq : float
824 |         If a frequency is found below this pitch, a ValueError will be raised.
825 |         (A typical value is 20.)
826 |     allow_negatives : bool
827 |         Whether or not to allow negative frequency values.
828 |     """
829 |     # If flag is true, map frequencies to their absolute value.
830 |     if allow_negatives:
831 |         frequencies = np.abs(frequencies)
832 |     # Make sure no frequency values are huge
833 |     if (np.abs(frequencies) > max_freq).any():
834 |         raise ValueError('A frequency of {} was found which is greater than '
835 |                          'the maximum allowable value of max_freq = {} (did '
836 |                          'you supply frequency values in '
837 |                          'Hz?)'.format(frequencies.max(), max_freq))
838 |     # Make sure no frequency values are tiny
839 |     if (np.abs(frequencies) < min_freq).any():
840 |         raise ValueError('A frequency of {} was found which is less than the '
841 |                          'minimum allowable value of min_freq = {} (did you '
842 |                          'supply frequency values in '
843 |                          'Hz?)'.format(frequencies.min(), min_freq))
844 |     # Make sure frequency values are 1-d np ndarrays
845 |     if frequencies.ndim != 1:
846 |         raise ValueError('Frequencies should be 1-d numpy ndarray, '
847 |                          'but shape={}'.format(frequencies.shape))
848 | 
849 | 
850 | def has_kwargs(function):
851 |     r'''Determine whether a function has \*\*kwargs.
852 | 
853 |     Parameters
854 |     ----------
855 |     function : callable
856 |         The function to test
857 | 
858 |     Returns
859 |     -------
860 |     True if function accepts arbitrary keyword arguments.
861 |     False otherwise.
862 |     '''
863 | 
864 |     if six.PY2:
865 |         return inspect.getargspec(function).keywords is not None
866 |     else:
867 |         sig = inspect.signature(function)
868 | 
869 |         for param in sig.parameters.values():
870 |             if param.kind == param.VAR_KEYWORD:
871 |                 return True
872 | 
873 |         return False
874 | 
875 | 
876 | def filter_kwargs(_function, *args, **kwargs):
877 |     """Given a function plus args and keyword args to pass to it, call the
878 |     function using only the keyword arguments which it actually accepts.  This
879 |     is equivalent to redefining the function with an additional \*\*kwargs to
880 |     accept slop keyword args.
881 | 
882 |     If the target function already accepts \*\*kwargs parameters, no filtering
883 |     is performed.
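
    For example, an unsupported keyword argument is silently dropped:

    >>> def add(x, y=0):
    ...     return x + y
    >>> filter_kwargs(add, 1, y=2, z=3)
    3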
884 | 
885 |     Parameters
886 |     ----------
887 |     _function : callable
888 |         Function to call.  Can take in any number of args or kwargs.
889 |         Keyword args that ``_function`` does not accept are dropped.
890 |     """
891 | 
892 |     if has_kwargs(_function):
893 |         return _function(*args, **kwargs)
894 | 
895 |     # Get the list of function arguments
896 |     func_code = six.get_function_code(_function)
897 |     function_args = func_code.co_varnames[:func_code.co_argcount]
898 |     # Construct a dict of those kwargs which appear in the function
899 |     filtered_kwargs = {}
900 |     for kwarg, value in list(kwargs.items()):
901 |         if kwarg in function_args:
902 |             filtered_kwargs[kwarg] = value
903 |     # Call the function with the supplied args and the filtered kwarg dict
904 |     return _function(*args, **filtered_kwargs)
905 | 
906 | 
907 | def intervals_to_durations(intervals):
908 |     """Converts an array of n intervals to their n durations.
909 | 
910 |     Parameters
911 |     ----------
912 |     intervals : np.ndarray, shape=(n, 2)
913 |         An array of time intervals, as returned by
914 |         :func:`mir_eval.io.load_intervals()`.
915 |         The ``i`` th interval spans time ``intervals[i, 0]`` to
916 |         ``intervals[i, 1]``.
917 | 
918 |     Returns
919 |     -------
920 |     durations : np.ndarray, shape=(n,)
921 |         Array of the duration of each interval.
922 | 
923 |     """
924 |     validate_intervals(intervals)
925 |     return np.abs(np.diff(intervals, axis=-1)).flatten()
926 | 
927 | 
928 | def hz_to_midi(freqs):
929 |     '''Convert Hz to MIDI numbers
930 | 
931 |     Parameters
932 |     ----------
933 |     freqs : number or ndarray
934 |         Frequency/frequencies in Hz
935 | 
936 |     Returns
937 |     -------
938 |     midi : number or ndarray
939 |         MIDI note numbers corresponding to input frequencies.
940 |         Note that these may be fractional.
941 |     '''
942 |     return 12.0 * (np.log2(freqs) - np.log2(440.0)) + 69.0
943 | 
944 | 
945 | def midi_to_hz(midi):
946 |     '''Convert MIDI numbers to Hz
947 | 
948 |     Parameters
949 |     ----------
950 |     midi : number or ndarray
951 |         MIDI notes
952 | 
953 |     Returns
954 |     -------
955 |     freqs : number or ndarray
956 |         Frequency/frequencies in Hz corresponding to `midi`
957 |     '''
958 |     return 440.0 * (2.0 ** ((midi - 69.0) / 12.0))
--------------------------------------------------------------------------------
/params.py:
--------------------------------------------------------------------------------
1 | sr = 8000
2 | sample_duration = 0.005  # s, length of one segment
3 | L = int(sr * sample_duration)  # samples per segment (= 40)
4 | N = 500  # num_basis
5 | nspk = 2
6 | batch_size = 128
7 | epochs = 100
8 | seq_duration = 0.5  # s, length of one training sequence
9 | seq_len = int(seq_duration / sample_duration)  # segments (K) per sequence
10 | 
11 | cuda = True
12 | seed = 20181117
13 | log_step = 100
14 | lr = 3e-4
15 | 
16 | rnn_type = 'LSTM'
17 | rnn_hidden_size = 500
18 | num_layers = 4
19 | bidirectional = True
20 | 
21 | display_freq = 10
22 | val_save = 'model_181117.pt'
23 | data_dir = '/home/grz/data/SSSR/wsj0/min/'
24 | sum_dir = './tasnet/summary/'
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # TasNet: Time-domain Audio Separation Network
2 | This is a TensorFlow implementation of "TasNet:
3 | Time-domain Audio Separation
4 | Network for Real-time, single-channel speech separation" (Yi Luo and
5 | Nima Mesgarani, ICASSP 2018).
6 | 
7 | ![avatar](./tasnet-architecture.png)
8 | 
9 | ## Special
10 | This implementation takes [ododoyo's](https://github.com/ododoyo/TASNET) as
11 | a reference, especially for the SI-SNR and PIT training parts. An extra MSE training objective and PIT
12 | training policy were implemented by me. Also, this implementation does not yet support
13 | variable-length segments in training.
14 | Discussion, (friendly) criticism, and suggestions are always welcome!
15 | 
16 | ## Requirements
17 | * tensorflow 1.8.0
18 | * python 3.5
19 | * librosa
20 | 
21 | ## Contents
22 | 
23 | * `params.py` defines all global parameters.
24 | * `data_generator.py` builds the WSJ0 2-mix dataset (following the ICASSP 2016 Deep
25 | Clustering paper) and generates batched data for training. Run this script first to
26 | generate the datasets, then update the paths in `tf_train.py`.
27 | * `tf_net.py` defines the TasNet structure, loss, training optimizer, etc.
28 | * `tf_train.py` trains the model. Replace the dataset paths with your own.
29 | * `tf_test.py` evaluates the model performance. This script is still rough and under
30 | repair.
31 | * `mir_eval.py` and `mir_util.py` are forked from [ododoyo's](https://github.com/ododoyo/TASNET),
32 | implementing the bss_eval metrics in Python rather than MATLAB.
33 | 
--------------------------------------------------------------------------------
/tasnet-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Moplast/TasNet-tensorflow/ddd29b41d24378c0f3dfdb752b02d4e8f48ee4a4/tasnet-architecture.png
--------------------------------------------------------------------------------
/tf_net.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from params import *
3 | from tensorflow.contrib.layers import layer_norm, fully_connected
4 | from itertools import product, permutations
5 | 
6 | 
7 | class TasNet(object):
8 |     def __init__(self, batch_size, seq_len):
9 |         self.rnn_hidden = rnn_hidden_size
10 |         self.K = int(seq_len)
11 |         self.context = 3  # segments per context window
12 |         self.context_window = self.context // 2
13 |         self.nspk = 2
14 |         self.batch_size = batch_size
15 | 
16 |         self.eps = 1e-8
17 | 
18 |         self.var_U = tf.Variable(tf.truncated_normal(shape=[L, 1, N], dtype=tf.float32), name='var_U')
19 |         self.var_V = tf.Variable(tf.truncated_normal(shape=[L, 1, N], dtype=tf.float32), name='var_V')
20 |         self.var_B = tf.Variable(tf.truncated_normal(shape=[N, L], dtype=tf.float32), name='var_B')
21 | 
22 |     def BLSTM_layernorm(self, input, index):
23 |         var_scope = 'BLSTM' + str(index)
24 |         with tf.variable_scope(var_scope) as scope:
25 |             lstm_fw_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
26 |                 self.rnn_hidden, layer_norm=True, )
27 |             lstm_bw_cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
28 |                 self.rnn_hidden, layer_norm=False, )
29 |             outputs, _ = tf.nn.bidirectional_dynamic_rnn(
30 |                 lstm_fw_cell, lstm_bw_cell, input,
31 |                 sequence_length=[self.context * N] * self.batch_size,
32 |                 dtype=tf.float32)
33 |             output = tf.concat(outputs, 2)
34 |         return output
35 | 
36 |     def BLSTM(self, input, index):
37 |         var_scope = 'BLSTM' + str(index)
38 |         with tf.variable_scope(var_scope) as scope:
39 |             lstm_fw_cell = tf.contrib.rnn.LSTMCell(
40 |                 self.rnn_hidden, use_peepholes=True, cell_clip=25, state_is_tuple=True)
41 |             lstm_bw_cell = tf.contrib.rnn.LSTMCell(
42 |                 self.rnn_hidden, use_peepholes=True, cell_clip=25, state_is_tuple=True)
43 |             initial_fw = lstm_fw_cell.zero_state(tf.shape(input)[0], dtype=tf.float32)
44 |             initial_bw = lstm_bw_cell.zero_state(tf.shape(input)[0], dtype=tf.float32)
45 |             output, _ = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, input,
46 |                                                         sequence_length=[self.K] * self.batch_size,
47 |                                                         initial_state_fw=initial_fw,
48 |                                                         initial_state_bw=initial_bw,
49 |                                                         dtype=tf.float32,
50 |                                                         time_major=False)
51 |             output = tf.concat(output, 2)
52 |         return output
53 | 
54 |     def encoder(self, mixture):
55 |         '''
56 |         mixture
57 |         :param mixture: [B, K, L]
58 |         :return: mixture_w: [B, K, N], norm_coef: [B, K, 1]
59 |         '''
60 |         with tf.variable_scope("encoder"):
61 |             # normalize inputs at axis [L]
62 |             norm_coef = tf.sqrt(tf.reduce_sum(mixture ** 2, axis=2, keepdims=True) + 1e-8)
63 |             norm_mixture = mixture / norm_coef
64 |             norm_mixture = tf.expand_dims(tf.reshape(norm_mixture, [-1, L]), axis=2)  # [B*K, L, 1]
65 |             # [B*K, L, 1] conv [L, 1, N] -> [B*K, N]
66 |             conv = tf.nn.relu(tf.nn.conv1d(norm_mixture, self.var_U,
67 |                                            stride=1, padding='VALID'))
68 |             # [B*K, L, 1] conv [L, 1, N] -> [B*K, N]
69 |             gate = tf.nn.sigmoid(tf.nn.conv1d(norm_mixture, self.var_V,
70 |                                               stride=1, padding='VALID'))
71 |             # gated 1D CNN to encode segment inputs into mixture weights
72 |             mixture_w = conv * gate  # [B*K, N]
73 |             mixture_w = tf.reshape(mixture_w, [self.batch_size, -1, N])  # [B, K, N]
74 | 
75 |             self.summary_conv = tf.summary.histogram('encoder_conv', conv)
76 |             self.summary_gate = tf.summary.histogram('encoder_gate', gate)
77 |         return mixture_w, norm_coef
78 | 
79 |     def separate(self, mixture_w):
80 |         '''
81 |         Separation Network
82 |         :param mixture_w: [B, K, N]
83 |         :return: mask_fc: [B, K, nspk, N]
84 |         '''
85 | 
86 |         # 1> layer normalization [B, K, N]
87 |         norm_mixture_w = layer_norm(mixture_w, begin_norm_axis=2)
88 |         norm_mixture_w = tf.reshape(norm_mixture_w, (self.batch_size, self.K, N))
89 | 
90 |         self.summary_layer_norm_mix = tf.summary.histogram('separator_layer_norm_mix_w', norm_mixture_w)
91 | 
92 |         # 2> 1-segment context window -> [B, K, context * N]
93 |         blank_ = tf.zeros([self.batch_size, self.context_window, N], dtype=tf.float32)
94 |         # [B, context_window + K + context_window, N]
95 |         padded_w_ = tf.concat([blank_, norm_mixture_w, blank_], axis=1)
96 |         idx = 0  # gather a sliding window of `context` segments for every time step
97 |         new_w_ = padded_w_[:, idx: idx + self.context, :]
98 |         for idx in range(1, self.K):
99 |             new_w_ = tf.concat([new_w_,
100 |                                 padded_w_[:, idx: idx + self.context, :]],
101 |                                axis=1)
102 |         contexted_w = tf.reshape(new_w_, [self.batch_size, self.K * self.context, N])
103 |         contexted_w = tf.reshape(contexted_w, [self.batch_size, self.K, self.context * N])
104 | 
105 |         # 3> BLSTM layers
106 |         lstm1 = self.BLSTM(contexted_w, 1)
107 |         lstm2 = self.BLSTM(lstm1, 2)
108 |         lstm3 = self.BLSTM(lstm2, 3)
109 |         lstm4 = self.BLSTM(lstm3 + lstm2, 4)
110 |         output = lstm4  # [B, K, 2 * rnn_hidden]
111 |         lstm_out = tf.reshape(output, [-1, 2 * self.rnn_hidden])  # [B*K, 2 * rnn_hidden]
112 |         self.summary_lstm_out = tf.summary.histogram('separator_lstm_out', lstm_out)
113 | 
114 |         # 4> FC layer [B, K, nspk, N]
115 |         fc = fully_connected(inputs=lstm_out, num_outputs=self.nspk * N, activation_fn=None)
116 |         mask_fc = tf.reshape(fc, [self.batch_size, self.K, self.nspk, N])
117 |         mask_fc = tf.nn.softmax(mask_fc, axis=2)
118 |         self.summary_mask = tf.summary.histogram('separator_mask', mask_fc)  # histogram of the softmax masks
119 | 
120 |         return mask_fc
121 | 
122 |     def decoder(self, mixture_w, est_mask):
123 |         '''
124 |         decode network
125 |         :param mixture_w: [B, K, N]
126 |         :param est_mask: [B, K, nspk, N]
127 |         :return: est_source: [B, K, nspk, L]
128 |         '''
129 | 
130 |         with tf.variable_scope("decoder"):
131 |             source_w = est_mask * tf.expand_dims(mixture_w, axis=2)  # [B, K, nspk, N]
132 |             # another form of matmul.
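            # the einsum below is equivalent to reshaping source_w to
            # [B*K*nspk, N] and right-multiplying by var_B: every masked
            # mixture weight vector is decoded as a weighted sum of the N
            # learned basis signals of length L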
133 | # source_w [B, K, nspk, N], var_B [N, L] -> [B, K, nspk, L] 134 | est_source = tf.einsum('bkcn,nl->bkcl', source_w, self.var_B) 135 | 136 | self.summary_B = tf.summary.histogram('decoder_basis_signals', self.var_B) 137 | return est_source 138 | 139 | def build_network(self, mixture): 140 | mixture_w, norm_coef = self.encoder(mixture) 141 | est_mask = self.separate(mixture_w) 142 | est_source = self.decoder(mixture_w, est_mask) # [B, K, nspk, L] 143 | 144 | norm_coef_ = tf.expand_dims(norm_coef, axis=2) # [B, K, 1, 1] 145 | est_source = tf.transpose(est_source * norm_coef_, [0, 2, 1, 3]) # [B, nspk, K, L] 146 | 147 | return est_source 148 | 149 | def objective(self, est_source, source): 150 | ''' 151 | :param est_source: [B, C, K, L] 152 | :param source: [B, C, K, L] 153 | :return: 154 | ''' 155 | max_snr, v_perms, max_snr_idx = self.get_si_snr(source, est_source) 156 | loss = 20 - tf.reduce_mean(max_snr) 157 | tar_perm = tf.gather(v_perms, max_snr_idx) 158 | tar_perm = tf.transpose(tf.one_hot(tar_perm, self.nspk), [0, 2, 1]) 159 | tar_perm = tf.cast(tf.argmax(tar_perm, axis=2), tf.int32) 160 | outer_axis = tf.tile(tf.reshape(tf.range(self.batch_size), [-1, 1]), [1, self.nspk]) 161 | gather_idx = tf.stack([outer_axis, tar_perm], axis=2) 162 | gather_idx = tf.reshape(gather_idx, [-1, 2]) 163 | reorder_recon = tf.reshape(tf.gather_nd(est_source, gather_idx), 164 | [self.batch_size, self.nspk, -1, L]) 165 | 166 | self.loss_summary = tf.summary.scalar('tasnet_loss', loss) 167 | self.snr_summary = tf.summary.scalar('snr', tf.reduce_mean(max_snr)) 168 | return loss, max_snr, est_source, reorder_recon 169 | 170 | 171 | def get_si_snr(self, source, est_source, name='pit_snr'): 172 | ''' 173 | :param source: [B, nspk, K, L] 174 | :param est_source: [B, nspk, K, L] 175 | :param name: 176 | :return: 177 | ''' 178 | max_len = tf.shape(source)[2] # 179 | # mask the padding part and flat the segmentation 180 | # zero-mean source and recon in the real length 181 | # seq_mask = self.get_seq_mask(max_len, self.K) 182 | # seq_mask = tf.reshape(seq_mask, [self.batch_size, 1, -1, 1]) 183 | # mask_targets = source * seq_mask 184 | # mask_recon = est_source * seq_mask 185 | sample_count = tf.cast(tf.reshape(self.batch_size * [self.K * L], [self.batch_size, 1, 1, 1]), tf.float32) 186 | mean_targets = tf.reduce_sum(source, axis=[2, 3], keepdims=True) / sample_count 187 | mean_recon = tf.reduce_sum(est_source, axis=[2, 3], keepdims=True) / sample_count 188 | zero_mean_targets = source - mean_targets 189 | zero_mean_recon = est_source - mean_recon 190 | # shape is [B, nspk, s] 191 | flat_targets = tf.reshape(zero_mean_targets, [self.batch_size, self.nspk, -1]) 192 | flat_recon = tf.reshape(zero_mean_recon, [self.batch_size, self.nspk, -1]) 193 | 194 | # calculate the SI-SNR, PIT is necessary 195 | with tf.variable_scope(name): 196 | v_perms = tf.constant( 197 | list(permutations(range(self.nspk))), 198 | dtype=tf.int32) 199 | perms_one_hot = tf.one_hot(v_perms, depth=self.nspk, dtype=tf.float32) 200 | 201 | # shape is [B, 1, nspk, s] 202 | s_truth = tf.expand_dims(flat_targets, axis=1) 203 | # shape is [B, nspk, 1, s] 204 | s_estimate = tf.expand_dims(flat_recon, axis=2) 205 | pair_wise_dot = tf.reduce_sum(s_estimate * s_truth, axis=3, keepdims=True) 206 | s_truth_energy = tf.reduce_sum(s_truth ** 2, axis=3, keepdims=True) + self.eps 207 | pair_wise_proj = pair_wise_dot * s_truth / s_truth_energy 208 | e_noise = s_estimate - pair_wise_proj 209 | # shape is [B, nspk, nspk] 210 | pair_wise_snr = 
tf.div(tf.reduce_sum(pair_wise_proj ** 2, axis=3),
211 |                                    tf.reduce_sum(e_noise ** 2, axis=3) + self.eps)
212 |             pair_wise_snr = 10 * tf.log(pair_wise_snr + self.eps) / tf.log(10.0)  # convert the power ratio to dB (base-10 log)
213 |             snr_set = tf.einsum('bij,pij->bp', pair_wise_snr, perms_one_hot)
214 |             max_snr_idx = tf.cast(tf.argmax(snr_set, axis=1), dtype=tf.int32)
215 |             max_snr = tf.gather_nd(snr_set,
216 |                                    tf.stack([tf.range(self.batch_size, dtype=tf.int32), max_snr_idx], axis=1))
217 |             max_snr = max_snr / self.nspk  # average the SI-SNR over the nspk speakers
218 | 
219 |         return max_snr, v_perms, max_snr_idx
220 | 
221 |     def MSE_objective(self, source, est_source, name='pit_mse'):
222 |         '''
223 |         :param source: [B, nspk, K, L]
224 |         :param est_source: [B, nspk, K, L]
225 |         :param name:
226 |         :return:
227 |         '''
228 |         sample_count = tf.cast(tf.reshape(self.batch_size * [self.K * L], [self.batch_size, 1, 1, 1]), tf.float32)
229 |         mean_targets = tf.reduce_sum(source, axis=[2, 3], keepdims=True) / sample_count
230 |         mean_recon = tf.reduce_sum(est_source, axis=[2, 3], keepdims=True) / sample_count
231 |         zero_mean_targets = source - mean_targets
232 |         zero_mean_recon = est_source - mean_recon
233 |         # shape is [B, nspk, s]
234 |         flat_targets = tf.reshape(zero_mean_targets, [self.batch_size, self.nspk, -1])
235 |         flat_recon = tf.reshape(zero_mean_recon, [self.batch_size, self.nspk, -1])
236 |         norm_targets = tf.nn.l2_normalize(flat_targets, axis=2)  # [B, spk, s]
237 |         norm_recon = tf.nn.l2_normalize(flat_recon, axis=2)  # [B, spk, s]
238 | 
239 |         # calculate the MSE; PIT is necessary
240 |         with tf.variable_scope(name):
241 |             v_perms = tf.constant(
242 |                 list(permutations(range(self.nspk))),
243 |                 dtype=tf.int32)
244 |             perms_one_hot = tf.one_hot(v_perms, depth=self.nspk, dtype=tf.float32)
245 | 
246 |             # compute pairwise costs
247 |             pairwise_mse = []
248 |             for src_id, out_id in product(range(self.nspk), range(self.nspk)):
249 |                 loss = tf.squared_difference(norm_targets[:, src_id, :],
250 |                                              norm_recon[:, out_id, :])
251 |                 pairwise_mse.append(tf.reduce_sum(loss, axis=1, keepdims=True))
252 |             # stack the nspk * nspk costs, then reshape into a [B, nspk, nspk] cost matrix
253 |             pairwise_mse = tf.concat(pairwise_mse, axis=1)
254 | 
255 |             pairwise_mse = tf.reshape(pairwise_mse, [self.batch_size, self.nspk, self.nspk])
256 | 
257 |             # decide assignment
258 |             mse_set = tf.einsum('bij,pij->bp', pairwise_mse, perms_one_hot)
259 |             min_mse_idx = tf.cast(tf.argmin(mse_set, axis=1), dtype=tf.int32)
260 |             min_mse = tf.gather_nd(mse_set,
261 |                                    tf.stack([tf.range(self.batch_size, dtype=tf.int32), min_mse_idx], axis=1))
262 |             min_mse = min_mse / self.nspk
263 | 
264 |         loss = tf.reduce_mean(min_mse)
265 |         tar_perm = tf.gather(v_perms, min_mse_idx)
266 |         tar_perm = tf.transpose(tf.one_hot(tar_perm, self.nspk), [0, 2, 1])
267 |         tar_perm = tf.cast(tf.argmax(tar_perm, axis=2), tf.int32)
268 |         outer_axis = tf.tile(tf.reshape(tf.range(self.batch_size), [-1, 1]), [1, self.nspk])
269 |         gather_idx = tf.stack([outer_axis, tar_perm], axis=2)
270 |         gather_idx = tf.reshape(gather_idx, [-1, 2])
271 |         reorder_recon = tf.reshape(tf.gather_nd(est_source, gather_idx),
272 |                                    [self.batch_size, self.nspk, -1, L])
273 | 
274 |         self.loss_summary = tf.summary.scalar('tasnet_mse_loss', loss)
275 | 
276 |         return loss, min_mse, est_source, reorder_recon
277 | 
278 |     def train(self, loss, lr):
279 |         optimizer = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.9, beta2=0.999, epsilon=1e-8)
280 |         # optimizer = tf.train.MomentumOptimizer(lr, 0.9)
281 |         update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
282 |         with tf.control_dependencies(update_ops):
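            # ops registered in UPDATE_OPS run before each training step;
            # gradients are clipped to a global norm of 200 before the Adam
            # update to guard against exploding BLSTM gradients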
283 |             gradients, v = zip(*optimizer.compute_gradients(loss))
284 |             gradients, _ = tf.clip_by_global_norm(gradients, 200)
285 |             train_op = optimizer.apply_gradients(zip(gradients, v))
286 |         return train_op
287 | 
--------------------------------------------------------------------------------
/tf_test.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from params import *
4 | from tf_net import TasNet
5 | from mir_eval import bss_eval_sources
6 | from data_generator import DataGenerator
7 | import os
8 | from datetime import datetime
9 | import time
10 | import librosa
11 | import matlab.engine
12 | import pickle
13 | 
14 | from itertools import product, permutations
15 | 
16 | os.environ["CUDA_VISIBLE_DEVICES"] = "3"
17 | seq_duration = 0.5  # evaluate on 0.5 s (100-segment) chunks
18 | seq_len = int(seq_duration / sample_duration)
19 | 
20 | model_subpath = '100-e19.ckpt-32641'
21 | model_path = os.getcwd() + '/' + sum_dir + '/model/' + model_subpath
22 | 
23 | 
24 | def evaluate(data_type):
25 |     wav_dir = os.path.join(data_dir, data_type)
26 |     spk_gender = pickle.load(open('/home/grz/SS/MESID/dataset/wsj0_spk_gender.pkl', 'rb'))
27 |     mix_dir = os.path.join(wav_dir, 'mix')
28 |     s1_dir = os.path.join(wav_dir, 's1')
29 |     s2_dir = os.path.join(wav_dir, 's2')
30 | 
31 |     list_mix = os.listdir(mix_dir)
32 |     list_mix.sort()
33 |     if data_type == 'tt':  # the tt set holds 3000 mixtures; subsample for speed
34 |         factor = 50
35 |     else:
36 |         factor = 2
37 |     list_mix = list_mix[::factor]
38 |     np.random.shuffle(list_mix)
39 | 
40 |     segment_num = 0
41 |     MATLAB = matlab.engine.start_matlab()  # only needed by the (commented-out) measure_wsj0 call
42 | 
43 |     # =============== PLACEHOLDER & MODEL DEFINITION ========================
44 |     with tf.Graph().as_default():
45 |         mixture = tf.placeholder(shape=[None, None, L], dtype=tf.float32, name='mixture')
46 |         source = tf.placeholder(shape=[None, nspk, None, L], dtype=tf.float32, name='source')
47 | 
48 |         print('>> Initializing model...')
49 |         model = TasNet(batch_size=1, seq_len=seq_len)
50 |         est_source = model.build_network(mixture)
51 |         loss, max_snr, reest_source, reorder_recon = model.objective(est_source=est_source, source=source)
52 | 
53 |         saver = tf.train.Saver(tf.global_variables())
54 |         sess = tf.Session()
55 |         init = tf.global_variables_initializer()
56 |         sess.run(init)
57 |         saver.restore(sess, model_path)
58 |         print('load model from: ', model_path)
59 | 
60 |         data_gen = DataGenerator(batch_size=1, max_k=seq_len, name='eval-generator')
61 | 
62 |         print('>> Evaluating... %s start' % datetime.now())
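        # per-utterance loop: chunk each mixture into seq_len-segment samples,
        # run the separator on each chunk, then stitch the chunks back into
        # full-length signals before computing BSS-eval and SI-SNR scores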
63 | 
64 |         for wav_file in list_mix:
65 |             gender_mix = 'dg'
66 |             segment_num += 1
67 |             if not wav_file.endswith('.wav'):
68 |                 continue
69 | 
70 |             gender1 = spk_gender[wav_file.split(sep='_')[0][:3]]
71 |             gender2 = spk_gender[wav_file.split(sep='_')[2][:3]]
72 |             if gender1 == gender2:
73 |                 print('>> same gender')
74 |                 gender_mix = 'sg'
75 |             else:
76 |                 print('>> diff gender')
77 | 
78 |             print("Segment {0}: sentence {1}, gender_mix: {2}".format(segment_num, wav_file, gender_mix))
79 |             mix_path = os.path.join(mix_dir, wav_file)
80 |             s1_path = os.path.join(s1_dir, wav_file)
81 |             s2_path = os.path.join(s2_dir, wav_file)
82 |             spk1 = wav_file[:3]
83 |             spk2 = wav_file.split(sep='_')[2][:3]
84 | 
85 |             mix, _ = librosa.load(mix_path, sr=sr)
86 |             s1, _ = librosa.load(s1_path, sr=sr)
87 |             s2, _ = librosa.load(s2_path, sr=sr)
88 |             mix_len = len(mix)
89 |             test_sample = data_gen.get_a_sample(mix, s1, s2, spks=[spk1, spk2], max_k=seq_len)
90 |             sample_num = len(test_sample['mix'])
91 |             # utterance-level info
92 |             est_s_u = []
93 |             snr_u = []
94 | 
95 |             start_time = datetime.now()
96 |             for i in range(sample_num):
97 |                 est_source_np, ordered_est_source_np, max_snr_np = sess.run(
98 |                     [est_source, reorder_recon, max_snr],
99 |                     feed_dict={mixture: [np.array(test_sample['mix'][i]).astype(np.float32)],
100 |                                source: [np.array(test_sample['s'][i]).astype(np.float32)]})
101 |                 # (ordered) est_source_np: [1, nspk, K, L]; max_snr_np: [1]
102 |                 est_s_u.append(ordered_est_source_np[0])
103 |                 snr_u.append(max_snr_np[0])
104 | 
105 |             duration = (datetime.now() - start_time).seconds
106 |             print('>> elapsed time (s):', duration)
107 |             recon_s1_sig, recon_s2_sig = recover_sig(est_s_u)
108 |             recon_s1_sig = recon_s1_sig[:mix_len]
109 |             recon_s2_sig = recon_s2_sig[:mix_len]
110 | 
111 |             # sdri = measure_wsj0(MATLAB, mix, recon_s1_sig, recon_s2_sig, s1, s2,
112 |             #                     mixture_name='tasnet-sdri')
113 |             src_ref = np.stack([s1, s2], axis=0)
114 |             src_est = np.stack([recon_s1_sig, recon_s2_sig], axis=0)
115 |             src_anchor = np.stack([mix, mix], axis=0)
116 |             sdr, sir, sar, popt = bss_eval_sources(src_ref, src_est)
117 |             sdr0, sir0, sar0, popt0 = bss_eval_sources(src_ref, src_anchor)
118 |             snr1 = get_SISNR(s1, recon_s1_sig)
119 |             snr2 = get_SISNR(s2, recon_s2_sig)
120 | 
121 |             print("snr1: {}, snr2: {}".format(snr1, snr2))
122 |             print("sdr1: {}, sdr2: {}".format(sdr[0] - sdr0[0], sdr[1] - sdr0[1]))
123 | 
124 |         sess.close()
125 | 
126 | 
127 | def recover_sig(sig_list):
128 |     '''sig_list: [n] - [nspk, K, L]'''
129 |     sig_np = np.concatenate([n for n in sig_list], axis=1)  # [nspk, n*K, L]
130 |     sig = np.reshape(sig_np, [nspk, -1])
131 |     return sig[0], sig[1]
132 | 
133 | 
134 | def get_time_str():
135 |     time = datetime.now()
136 |     timestr = '%02d%02d%02d-%02d%02d%02d' % (time.year, time.month, time.day,
137 |                                              time.hour, time.minute, time.second)
138 |     return timestr
139 | 
140 | 
141 | def get_SISNR(ref_sig, out_sig, eps=1e-8):
142 |     assert len(ref_sig) == len(out_sig)
143 |     ref_sig = ref_sig - np.mean(ref_sig)
144 |     out_sig = out_sig - np.mean(out_sig)
145 |     ref_energy = np.sum(ref_sig ** 2) + eps
146 |     proj = np.sum(ref_sig * out_sig) * ref_sig / ref_energy
147 |     noise = out_sig - proj
148 |     ratio = np.sum(proj ** 2) / (np.sum(noise ** 2) + eps)
149 |     sisnr = 10 * np.log(ratio + eps) / np.log(10.0)
150 |     return sisnr
151 | 
152 | 
153 | def measure_wsj0(MATLAB, mix_wav, est_speech1, est_speech2, ori_speech1, ori_speech2,
154 |                  mixture_name=''):
155 | 
156 |     min_len = min(len(mix_wav), np.array(est_speech1).shape[0], np.array(est_speech2).shape[0],
157 |                   np.array(ori_speech1).shape[0], len(ori_speech2))
158 |     mix_wav = mix_wav[:min_len]
159 |     est_speech2 = est_speech2[:min_len]
160 |     est_speech1 = est_speech1[:min_len]
161 |     ori_speech1 = ori_speech1[:min_len]
162 |     ori_speech2 = ori_speech2[:min_len]
163 | 
164 |     mix_wav = matlab.double(mix_wav.tolist())
165 |     ori_speech1 = matlab.double(ori_speech1.tolist())
166 |     ori_speech2 = matlab.double(ori_speech2.tolist())
167 |     est_speech1 = matlab.double(est_speech1.tolist())
168 |     est_speech2 = matlab.double(est_speech2.tolist())
169 |     # BSS_EVAL (true_signal, true_noise, pred_signal, mix)
170 | 
171 |     bss_eval_results11 = MATLAB.BSS_EVAL(ori_speech1, ori_speech2, est_speech1, mix_wav)  # ori_speech
172 |     bss_eval_results21 = MATLAB.BSS_EVAL(ori_speech2, ori_speech1, est_speech2, mix_wav)
173 | 
174 |     writeline = mixture_name + '\n- speech_1\tSDR: %.2f dB\tSIR: %.2f dB\tSAR: %.2f dB\tNSDR: %.2f dB\n' \
175 |                 % (bss_eval_results11['SDR'], bss_eval_results11['SIR'], bss_eval_results11['SAR'],
176 |                    bss_eval_results11['NSDR']) \
177 |                 + '- speech_2\tSDR: %.2f dB\tSIR: %.2f dB\tSAR: %.2f dB\tNSDR: %.2f dB' \
178 |                 % (bss_eval_results21['SDR'], bss_eval_results21['SIR'], bss_eval_results21['SAR'],
179 |                    bss_eval_results21['NSDR'])
180 |     print(writeline)
181 | 
182 |     nsdr1 = bss_eval_results11['NSDR']
183 |     if nsdr1 < 0:
184 |         nsdr1 = 0
185 |     nsdr2 = bss_eval_results21['NSDR']
186 |     if nsdr2 < 0:
187 |         nsdr2 = 0
188 |     # for debug
189 |     # if nsdr1 < 4 or nsdr2 < 4:
190 |     #     return None
191 |     # else:
192 |     return nsdr1 + nsdr2
193 | 
194 | if __name__ == '__main__':
195 |     evaluate('cv')
--------------------------------------------------------------------------------
/tf_train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from params import *
4 | from tf_net import TasNet
5 | from data_generator import DataGenerator
6 | import os
7 | from datetime import datetime
8 | import time
9 | 
10 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
11 | scratch_or_resume = False  # True: train from scratch; False: resume from model_path below
12 | seq_duration = seq_duration if scratch_or_resume else 2.0  # resumed runs train on longer segments
13 | if scratch_or_resume:
14 |     # keep the 0.5 s seq_duration configured in params.py
15 |     val_list = '/home/grz/data/SSSR/wsj0_tasnet/cv/raw_100-52825.pkl'
16 |     trn_list = '/home/grz/data/SSSR/wsj0_tasnet/tr/raw_100-208965.pkl'
17 | else:
18 |     if seq_duration == 2.0:
19 |         val_list = '/home/grz/data/SSSR/wsj0_tasnet/cv/raw_400-11338.pkl'
20 |         trn_list = '/home/grz/data/SSSR/wsj0_tasnet/tr/raw_400-44956.pkl'
21 |         batch_size = 30
22 |     else:
23 |         val_list = '/home/grz/data/SSSR/wsj0_tasnet/cv/raw_800-5486.pkl'
24 |         trn_list = '/home/grz/data/SSSR/wsj0_tasnet/tr/raw_800-21789.pkl'
25 |         batch_size = 16
26 | model_subpath = '100-e19.ckpt-32641'
27 | model_path = os.getcwd() + '/' + sum_dir + '/model/' + model_subpath
28 | 
29 | seq_len = int(seq_duration / sample_duration)
30 | 
31 | 
32 | def train(max_epoch):
33 | 
34 |     # =============== PLACEHOLDER & MODEL DEFINITION ========================
35 |     with tf.Graph().as_default():
36 |         mixture = tf.placeholder(shape=[None, None, L], dtype=tf.float32, name='mixture')
37 |         source = tf.placeholder(shape=[None, nspk, None, L], dtype=tf.float32, name='source')
38 |         train_lr = tf.placeholder(shape=None, dtype=tf.float32, name='lr')
39 | 
40 |         print('>> Initializing model...')
41 |         model = TasNet(batch_size=batch_size, seq_len=seq_len)
42 |         est_source = model.build_network(mixture)
43 |         # loss, min_mse, est_source, reorder_recon = model.MSE_objective(source=source,
44 |         #                                                                est_source=est_source)
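        # SI-SNR with PIT (below) is the default training objective; swap in
        # the commented MSE_objective call above to train with the
        # normalized-MSE variant instead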
45 |         loss, max_snr, est_source, _ = model.objective(est_source=est_source, source=source)
46 |         train_op = model.train(loss, lr=train_lr)
47 | 
48 |         saver = tf.train.Saver(tf.global_variables())
49 |         # summary_op = tf.summary.merge_all()
50 |         sess = tf.Session()
51 |         init = tf.global_variables_initializer()
52 |         sess.run(init)
53 | 
54 |         val_loss = []
55 |         val_loss_min = np.inf  # track validation losses in both fresh and resumed runs
56 |         if scratch_or_resume:
57 |             step = 0
58 |             last_epoch = 0
59 |         else:
60 |             saver.restore(sess, model_path)
61 |             step = int(model_subpath.split(sep='-')[-1])
62 |             last_epoch = int(model_subpath.split(sep='.')[0].split(sep='-')[-1][1:])
63 |             print('load model from: ', model_path)
64 | 
65 |         # =============== SUMMARY & DATA ========================
66 |         summary_dir = sum_dir + '/' + 'duration' + str(seq_duration)
67 |         if not os.path.exists(summary_dir):
68 |             os.mkdir(summary_dir)
69 |             os.mkdir(summary_dir + '/train')
70 |             os.mkdir(summary_dir + '/val')
71 |         summary_writer_train = tf.summary.FileWriter(summary_dir + '/train', sess.graph)
72 |         summary_writer_val = tf.summary.FileWriter(summary_dir + '/val')
73 | 
74 |         # ----------------------------------- DATA ------------------------------------------
75 |         # 1> generator for training set and validation set
76 |         print('>> Loading data...')
77 |         val_generator = DataGenerator(batch_size=batch_size, max_k=seq_len,
78 |                                       name='val-generator')
79 |         data_generator = DataGenerator(batch_size=batch_size, max_k=seq_len,
80 |                                        name='train-generator')
81 |         val_generator.load_data(val_list)
82 |         data_generator.load_data(trn_list)
83 | 
84 |         train_lr_value = lr
85 |         data_generator.epoch = last_epoch
86 |         train_loss_ = []
87 |         val_no_best = 0
88 |         print('>> Training... %s start' % datetime.now())
89 |         while last_epoch <= max_epoch:
90 |             step += 1
91 |             start_time = time.time()
92 | 
93 |             data_batch = data_generator.gen_batch()
94 |             loss_value, _, \
95 |             sum1, sum2, sum3, \
96 |             sum4, sum5, \
97 |             sum6, sum7 = sess.run([loss, train_op,
98 |                                    model.loss_summary, model.snr_summary, model.summary_conv,
99 |                                    model.summary_gate, model.summary_lstm_out,
100 |                                    model.summary_layer_norm_mix, model.summary_B],
101 |                                   feed_dict={mixture: np.array(data_batch['mix']).astype(np.float32),
102 |                                              source: np.array(data_batch['s']).astype(np.float32),
103 |                                              train_lr: train_lr_value})
104 |             for summ in (sum1, sum2, sum3, sum4, sum5, sum6, sum7):
105 |                 summary_writer_train.add_summary(summ, step)
106 |             train_loss_.append(loss_value)
107 |             duration = time.time() - start_time
108 |             if np.isnan(loss_value):
109 |                 print('NAN loss: epoch %d step %d' % (last_epoch, step))
110 | 
111 |             if step % display_freq == 0:
112 |                 num_examples_per_step = batch_size
113 |                 examples_per_sec = num_examples_per_step / duration
114 |                 sec_per_batch = float(duration)
115 |                 format_str = '%s: step %d, loss = %.5f, (%.1f sp/s; %.3f s/batch, epoch %d)'
116 |                 print(format_str % (get_time_str(), step, sum(train_loss_) / len(train_loss_),
117 |                                     examples_per_sec, sec_per_batch,
118 |                                     data_generator.epoch))
119 |                 train_loss_ = []
120 | 
121 |             # ----------------------------------- VALIDATION ------------------------------------------
122 |             if last_epoch != data_generator.epoch:
123 |                 # doing validation every training epoch
124 |                 print('>> Current epoch: ', last_epoch, ', doing validation')
125 |                 val_epoch = val_generator.epoch
126 |                 count, loss_sum, sum1, sum2 = 0, 0, '', ''
127 |                 # average the validation loss
128 |                 while val_epoch == val_generator.epoch:
129 |                     count += 1
130 |                     data_batch = val_generator.gen_batch()
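                    # forward pass only (train_op is not fetched), so
                    # validation never updates the model weights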
131 |                     loss_value, sum1, sum2 = sess.run([loss, model.loss_summary, model.snr_summary],
132 |                                                       feed_dict={mixture: np.array(data_batch['mix']).astype(np.float32),
133 |                                                                  source: np.array(data_batch['s']).astype(np.float32)})
134 |                     loss_sum += loss_value
135 |                 summary_writer_val.add_summary(sum1, step)
136 |                 summary_writer_val.add_summary(sum2, step)
137 |                 val_loss_sum = (loss_sum / count)
138 |                 val_loss.append(val_loss_sum)
139 |                 format_str = 'validation: loss = %.5f'
140 |                 print(format_str % (val_loss_sum))
141 |                 np.array(val_loss).tofile(sum_dir + '/loss/val_' + str(seq_len))  # note: sum_dir + '/loss/' must already exist
142 |                 if val_loss_sum < val_loss_min:
143 |                     print('# train_net: saving model at step %d because of minimum validation loss' % step)
144 |                     save_model(sess, saver, last_epoch, step)
145 |                     val_loss_min = val_loss_sum
146 |                     val_no_best = 0
147 |                 else:
148 |                     val_no_best += 1
149 |                     if val_no_best % 3 == 0:
150 |                         train_lr_value = train_lr_value / 2
151 |                         print('# no improvement in 3 epochs, reduce learning rate to:',
152 |                               train_lr_value)
153 |                     # keep counting epochs without improvement so early stopping can trigger
154 |                     if val_no_best >= 10:
155 |                         print('# Early stop! ')
156 |                         break
157 |                 last_epoch = data_generator.epoch
158 |                 if last_epoch == max_epoch:
159 |                     save_model(sess, saver, last_epoch, step)
160 |                     print('reach max training epoch.')
161 |                     return
162 | 
163 |         sess.close()
164 | 
165 | 
166 | def get_time_str():
167 |     time = datetime.now()
168 |     timestr = '%02d%02d%02d-%02d%02d%02d' % (time.year, time.month, time.day,
169 |                                              time.hour, time.minute, time.second)
170 |     return timestr
171 | 
172 | 
173 | def save_model(sess, saver, last_epoch, step):
174 |     checkpoint_path = os.path.join(sum_dir + '/model/',
175 |                                    str(seq_len) + '-e' + str(last_epoch) + '.ckpt')
176 |     saver.save(sess, checkpoint_path, global_step=step)
177 | 
178 | 
179 | if __name__ == '__main__':
180 |     train(100)
181 | 
--------------------------------------------------------------------------------