├── utils.py
├── README.md
├── .gitignore
├── 224s-final-paper.pdf
├── tensorboard
│   ├── 2017_05_29_05_12_22
│   │   └── train
│   │       └── events.out.tfevents.1496034754.rescomp-16-283607.stanford.edu
│   ├── 2017_05_29_05_13_10
│   │   └── train
│   │       └── events.out.tfevents.1496034796.rescomp-16-283607.stanford.edu
│   └── 2017_05_29_05_14_13
│       └── train
│           └── events.out.tfevents.1496034859.rescomp-16-283607.stanford.edu
├── data_scripts
│   ├── preprocessing
│   │   ├── remove_short_segments.py
│   │   ├── slice_clean_audio.py
│   │   ├── process_audio.py
│   │   └── split_train_dev.py
│   ├── combining
│   │   └── combine_audio.py
│   ├── test_scripts
│   │   ├── create_test_data.py
│   │   └── create_test_batch.py
│   └── masking
│       └── test_masking.py
├── config.py
├── freq_weight.py
├── model.py
├── run.py
└── evaluate.py

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# speech_separation

Hi

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/*
*.pyc
data/*
tensorboard/*
checkpoints/*

--------------------------------------------------------------------------------
/224s-final-paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hjkwon0609/speech_separation/HEAD/224s-final-paper.pdf

--------------------------------------------------------------------------------
/tensorboard/2017_05_29_05_12_22/train/events.out.tfevents.1496034754.rescomp-16-283607.stanford.edu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hjkwon0609/speech_separation/HEAD/tensorboard/2017_05_29_05_12_22/train/events.out.tfevents.1496034754.rescomp-16-283607.stanford.edu

--------------------------------------------------------------------------------
/tensorboard/2017_05_29_05_13_10/train/events.out.tfevents.1496034796.rescomp-16-283607.stanford.edu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hjkwon0609/speech_separation/HEAD/tensorboard/2017_05_29_05_13_10/train/events.out.tfevents.1496034796.rescomp-16-283607.stanford.edu

--------------------------------------------------------------------------------
/tensorboard/2017_05_29_05_14_13/train/events.out.tfevents.1496034859.rescomp-16-283607.stanford.edu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hjkwon0609/speech_separation/HEAD/tensorboard/2017_05_29_05_14_13/train/events.out.tfevents.1496034859.rescomp-16-283607.stanford.edu

--------------------------------------------------------------------------------
/data_scripts/preprocessing/remove_short_segments.py:
--------------------------------------------------------------------------------
from scipy.io import wavfile
import os

DATA_DIR = '../../data/sliced_clean/'

for f in os.listdir(DATA_DIR):
    if f[-4:] == '.wav':
        rate, data = wavfile.read(DATA_DIR + f)
        file_length = len(data) / float(rate)
        if file_length < 1:
            os.remove(DATA_DIR + f)
            print 'removed file %s which had length %f seconds' % (DATA_DIR + f, file_length)
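
# How the preprocessing scripts fit together (inferred from their input/output
# directories; this ordering is an assumption, not documented in the repo):
#   1. slice_clean_audio.py      data/raw_clean/   -> data/sliced_clean/
#   2. remove_short_segments.py  prunes data/sliced_clean/ in place (this script)
#   3. combine_audio.py / process_audio.py  mix in data/raw_noise/ and write
#      data/combined/ wavs or STFT batches in data/processed/
#   4. split_train_dev.py        data/processed/   -> train/dev batch files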
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
class Config:
    """Holds model hyperparameters and dataset information.

    Model objects are passed a Config() object at instantiation.
    """
    num_final_features = 513

    batch_size = 16
    output_size = num_final_features * 2  # clean and noise spectrogram frames, stacked
    num_hidden = 128  # unused: the GRU cells are sized by output_size instead

    num_layers = 3

    num_epochs = 50
    l2_lambda = 0.0000001
    lr = 5e-4

--------------------------------------------------------------------------------
/data_scripts/combining/combine_audio.py:
--------------------------------------------------------------------------------
from scipy.io import wavfile
import os
import numpy as np

INPUT_NOISE_DIR = '../../data/raw_noise/'
INPUT_CLEAN_DIR = '../../data/sliced_clean/'
OUTPUT_DIR = '../../data/combined/'

for clean in os.listdir(INPUT_CLEAN_DIR):
    for noise in os.listdir(INPUT_NOISE_DIR):
        if clean[-4:] == '.wav' and noise[-4:] == '.wav':
            rate_clean, data_clean = wavfile.read(INPUT_CLEAN_DIR + clean)
            rate_noise, data_noise = wavfile.read(INPUT_NOISE_DIR + noise)

            length = len(data_clean)

            data_noise = data_noise[:length]

            average = [(s1 / 2 + s2 / 2) for (s1, s2) in zip(data_clean, data_noise)]

            # include the noise name in the output file; otherwise every noise
            # iteration overwrites the same clean-named file
            filename = '%s%s_%s.wav' % (OUTPUT_DIR, clean[:-4], noise[:-4])

            wavfile.write(filename, rate_clean, np.asarray(average, dtype=np.int16))

--------------------------------------------------------------------------------
/freq_weight.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt

num_freq_bins = 512

# bin centre frequencies in Hz (assuming a 44.1 kHz sample rate, Nyquist = 22050 Hz)
frequencies = np.array([2.0 * 180 * i / num_freq_bins * 22050 / 360 for i in xrange(num_freq_bins)])
frequencies[0] = 2.0 * 180 / num_freq_bins / 2 * 22050 / 360  # avoid f = 0 (the 1000 / f term below diverges); use half the first bin's frequency
ath_val = 3.64 * np.power(1000 / frequencies, 0.8) - 6.5 * np.exp(-0.6 * np.power(frequencies / 1000 - 3.3, 2)) + np.power(0.1, 3) * np.power(frequencies / 1000, 4)
ath_shifted = (1 - np.amin(ath_val)) + ath_val  # shift all ath vals so that the minimum is 1
weights = 1 / ath_shifted
print(frequencies)
print(weights)

normalized = np.full(weights.shape, np.sqrt(np.sum(np.power(weights, 2)) / num_freq_bins))
print(np.linalg.norm(weights, ord=2))
print(np.linalg.norm(normalized, ord=2))

plt.plot(frequencies, weights)
plt.xlabel('Frequency (Hz)')
plt.ylabel('Weight')
plt.show()
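
# For reference, ath_val above is Terhardt's approximation of the absolute
# threshold of hearing, with f in kHz:
#   ATH(f) = 3.64 f^(-0.8) - 6.5 exp(-0.6 (f - 3.3)^2) + 1e-3 f^4   (dB SPL)
# The weights are the reciprocal of the threshold after shifting its minimum
# to 1, so the frequencies we hear best receive the largest weight.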
--------------------------------------------------------------------------------
/data_scripts/preprocessing/slice_clean_audio.py:
--------------------------------------------------------------------------------
from scipy.io import wavfile
import os
import numpy as np

DATA_DIR = '../../data/raw_clean/'
OUTPUT_DIR = '../../data/sliced_clean/'

for f in os.listdir(DATA_DIR):
    if f[-4:] == '.wav':
        rate, data = wavfile.read(DATA_DIR + f)
        clean_frame_threshold = 1300
        window_size = 10000

        frame_slice_ix = []
        silent_frame = True

        moving_average = np.average(np.absolute(data[0:window_size]))

        for i in xrange(window_size, len(data) - window_size):
            moving_average = moving_average * (1 - 1.0 / window_size) + np.absolute(data[i]) / float(window_size)

            if silent_frame:
                if moving_average > clean_frame_threshold:
                    silent_frame = False
                    frame_slice_ix.append(i)
            else:
                if moving_average < clean_frame_threshold:
                    silent_frame = True
                    frame_slice_ix.append(i)

        # stop at len - 1 so a file that ends mid-speech (odd number of slice
        # indices) does not index past the end of frame_slice_ix
        for i in xrange(0, len(frame_slice_ix) - 1, 2):
            filename = '%s%s_%d.wav' % (OUTPUT_DIR, f[:-4], i / 2)
            wavfile.write(filename, rate, data[frame_slice_ix[i]:frame_slice_ix[i + 1]])

--------------------------------------------------------------------------------
/data_scripts/test_scripts/create_test_data.py:
--------------------------------------------------------------------------------
from scipy.io import wavfile
import os
import numpy as np
import stft

INPUT_NOISE_DIR = '../../data/raw_noise/'
INPUT_CLEAN_DIR = '../../data/sliced_clean/'
OUTPUT_DIR = '../../data/test_combined/'

CLEAN_FILE = INPUT_CLEAN_DIR + 'f10_script2_clean_113.wav'
NOISE_FILE = INPUT_NOISE_DIR + 'noise1_1.wav'

def writeWav(fn, fs, data):
    data = data * 1.5 / np.max(np.abs(data))
    wavfile.write(fn, fs, data)

if __name__ == '__main__':
    spectrogram_args = {'framelength': 512}
    rate_clean, data_clean = wavfile.read(CLEAN_FILE)
    rate_noise, data_noise = wavfile.read(NOISE_FILE)

    data_len = len(data_clean)
    data_noise = data_noise[:data_len]

    print data_clean.dtype
    print data_noise.dtype

    data_combined = np.array([s1 / 2 + s2 / 2 for (s1, s2) in zip(data_clean, data_noise)], dtype=np.int16)
    # data_combined = data_noise

    print data_combined.dtype

    wavfile.write('%scombined.wav' % (OUTPUT_DIR), rate_clean, data_combined)

    Sx_clean = stft.spectrogram(data_clean, **spectrogram_args)
    Sx_noise = stft.spectrogram(data_noise, **spectrogram_args)

    reverted_clean = stft.ispectrogram(Sx_clean)
    reverted_noise = stft.ispectrogram(Sx_noise)

    writeWav('%soriginal_clean.wav' % (OUTPUT_DIR), rate_clean, reverted_clean)
    writeWav('%soriginal_noise.wav' % (OUTPUT_DIR), rate_noise, reverted_noise)

--------------------------------------------------------------------------------
/data_scripts/preprocessing/process_audio.py:
--------------------------------------------------------------------------------
from scipy.io import wavfile
import os
import numpy as np
import stft
import h5py

INPUT_NOISE_DIR = '../../data/raw_noise/'
INPUT_CLEAN_DIR = '../../data/sliced_clean/'
OUTPUT_DIR = '../../data/processed/'

def pad_data(data):
    """Left-pad every sample with zero rows so the whole batch has the
    length of its longest sample."""
    num_samples = len(data)
    max_rows_in_sample = max(len(data[i]) for i in xrange(num_samples))
    num_cols_in_row = data[0][0].size
    new_data = np.zeros((num_samples, max_rows_in_sample, num_cols_in_row))
    for i, sample in enumerate(data):
        num_rows = len(sample)
        for j, row in enumerate(sample):
            for k, c in enumerate(row):
                new_data[i][max_rows_in_sample - num_rows + j][k] = c
    return new_data
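
# A shorter equivalent using np.pad (untested sketch; assumes each sample is a
# 2-D array of shape [num_rows, num_cols]; unlike pad_data above it preserves
# the input dtype instead of casting into a float array):
#
#   def pad_data_np(data):
#       max_rows = max(len(s) for s in data)
#       return np.stack([np.pad(s, ((max_rows - len(s), 0), (0, 0)), 'constant')
#                        for s in data])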

if __name__ == '__main__':
    processed_data = []

    noise_data = [wavfile.read(INPUT_NOISE_DIR + noise)[1] for noise in os.listdir(INPUT_NOISE_DIR) if noise[-4:] == '.wav']
    noise_data = noise_data[:5]

    batch_size = 200
    curr = 0
    curr_batch = 0

    for i, clean in enumerate(os.listdir(INPUT_CLEAN_DIR)):

        if clean[-4:] == '.wav':
            rate_clean, data_clean = wavfile.read(INPUT_CLEAN_DIR + clean)
            for noise in noise_data:
                data_noise = noise[:]

                length = len(data_clean)
                data_noise = data_noise[:length][:]

                data_combined = np.array([(s1 / 2 + s2 / 2) for (s1, s2) in zip(data_clean, data_noise)])

                # ad-hoc scaling constant to keep spectrogram magnitudes small
                Sx_clean = stft.spectrogram(data_clean).transpose() / 100000
                Sx_noise = stft.spectrogram(data_noise).transpose() / 100000
                Sx_combined = stft.spectrogram(data_combined).transpose() / 100000
                # Sx_clean = pretty_spectrogram(data_clean.astype('float64'), fft_size=fft_size, step_size=step_size, thresh=spec_thresh)
                # Sx_noise = pretty_spectrogram(data_noise.astype('float64'), fft_size=fft_size, step_size=step_size, thresh=spec_thresh)
                # Sx_combined = pretty_spectrogram(data_combined.astype('float64'), fft_size=fft_size, step_size=step_size, thresh=spec_thresh)

                # Sx_target = np.concatenate((Sx_clean, Sx_noise), axis=0)

                processed_data.append([Sx_combined, Sx_clean, Sx_noise])

                curr_batch += 1
                if curr_batch == batch_size:
                    # renamed from clean/noise to avoid shadowing the loop variables
                    combined, clean_specs, noise_specs = zip(*processed_data)

                    combined_padded = pad_data(combined)
                    clean_padded = pad_data(clean_specs)
                    noise_padded = pad_data(noise_specs)

                    processed_data = np.array([combined_padded, clean_padded, noise_padded])

                    # np.savez_compressed('%sdata%d' % (OUTPUT_DIR, curr), processed_data)
                    f = h5py.File('%sdata%d' % (OUTPUT_DIR, curr), 'w')
                    f.create_dataset('data', data=processed_data, compression="gzip", compression_opts=9)
                    f.close()
                    print('Saved batch curr %d' % (curr))
                    processed_data = []
                    curr += 1
                    curr_batch = 0

        print('Finished processing %d clean slice files' % (i + 1))

    # np.savez_compressed('%sdata' % (OUTPUT_DIR), processed_data)

    # hkl.dump(processed_data, OUTPUT_DIR + 'data.hkl')

--------------------------------------------------------------------------------
/data_scripts/test_scripts/create_test_batch.py:
--------------------------------------------------------------------------------
import h5py
import numpy as np
from scipy.io import wavfile
import os
import stft
import pickle

# earlier version that re-batched an existing processed chunk, kept for reference:
# def create_batch(input_data, target_data, batch_size):
#     input_batches = []
#     target_batches = []
#
#     for i in xrange(0, len(target_data), batch_size):
#         input_batches.append(input_data[i:i + batch_size])
#         target_batches.append(target_data[i:i + batch_size])
#
#     return input_batches, target_batches
#
# if __name__ == '__main__':
#     DIR = '../../data/processed/'
#     data = h5py.File('%sdata%d' % (DIR, 5))['data'].value
#
#     combined, clean, noise = zip(data)
#     combined = combined[0]
#     clean = clean[0]
#     noise = noise[0]
#     target = np.concatenate((clean, noise), axis=2)
#
#     combined_batch, target_batch = create_batch(combined, target, 50)
#
#     f = h5py.File('%stest_batch' % (DIR), 'w')
#     f.create_dataset('combined_batch', data=combined_batch[0], compression="gzip", compression_opts=9)
#     f.create_dataset('target_batch', data=target_batch[0], compression="gzip", compression_opts=9)

INPUT_NOISE_DIR = '../../data/raw_noise/'
INPUT_CLEAN_DIR = '../../data/sliced_clean/'
OUTPUT_DIR = '../../data/processed/'

def pad_data(data):
    """Left-pad every sample with zero rows so the whole batch has the
    length of its longest sample."""
    num_samples = len(data)
    max_rows_in_sample = max(len(data[i]) for i in xrange(num_samples))
    num_cols_in_row = data[0][0].size
    new_data = np.zeros((num_samples, max_rows_in_sample, num_cols_in_row))
    for i, sample in enumerate(data):
        num_rows = len(sample)
        for j, row in enumerate(sample):
            for k, c in enumerate(row):
                new_data[i][max_rows_in_sample - num_rows + j][k] = c
    return new_data

if __name__ == '__main__':
    processed_data = []

    noise_data = [wavfile.read(INPUT_NOISE_DIR + noise)[1] for noise in os.listdir(INPUT_NOISE_DIR) if noise[-4:] == '.wav']
    noise_data = noise_data[:5]

    batch_size = 50
    curr = 0
    curr_batch = 0

    for i, clean in enumerate(os.listdir(INPUT_CLEAN_DIR)):
        if i > 800:
            continue

        if clean[-4:] == '.wav':
            rate_clean, data_clean = wavfile.read(INPUT_CLEAN_DIR + clean)
            for noise in noise_data:
                data_noise = noise[:]

                length = len(data_clean)
                data_noise = data_noise[:length][:]

                data_combined = np.array([(s1 / 2 + s2 / 2) for (s1, s2) in zip(data_clean, data_noise)])

                Sx_clean = stft.spectrogram(data_clean).transpose() / 100000
                Sx_noise = stft.spectrogram(data_noise).transpose() / 100000
                Sx_combined = stft.spectrogram(data_combined).transpose() / 100000

                # Sx_clean = pretty_spectrogram(data_clean.astype('float64'), fft_size=fft_size, step_size=step_size, thresh=spec_thresh)
                # Sx_noise = pretty_spectrogram(data_noise.astype('float64'), fft_size=fft_size, step_size=step_size, thresh=spec_thresh)
                # Sx_combined = pretty_spectrogram(data_combined.astype('float64'), fft_size=fft_size, step_size=step_size, thresh=spec_thresh)

                # Sx_target = np.concatenate((Sx_clean, Sx_noise), axis=0)

                # note: this mutates Sx_combined's own settings dict, so the
                # appended settings below already include 'orig_length'
                settings = Sx_combined.stft_settings
                orig_length = len(Sx_clean)
                settings['orig_length'] = orig_length

                processed_data.append([Sx_combined, Sx_clean, Sx_noise, Sx_combined.stft_settings])

                curr_batch += 1
                if curr_batch == batch_size:
                    combined, clean_specs, noise_specs, stft_settings = zip(*processed_data)
                    stft_settings = list(stft_settings)

                    combined_padded = pad_data(combined)
                    clean_padded = pad_data(clean_specs)
                    noise_padded = pad_data(noise_specs)

                    processed_data = np.array([combined_padded, clean_padded, noise_padded])

                    # hf is closed before the name f is reused for the pickle file
                    hf = h5py.File('%stest_batch' % (OUTPUT_DIR), 'w')
                    hf.create_dataset('data', data=processed_data, compression="gzip", compression_opts=9)
                    hf.close()

                    with open('%stest_settings.pkl' % (OUTPUT_DIR), 'wb') as f:
                        pickle.dump(stft_settings, f, pickle.HIGHEST_PROTOCOL)

                    print('Finished processing %d clean slice files' % (i + 1))
                    break
    print('Finished processing %d clean slice files' % (i + 1))
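
# Reading the artifacts back (a sketch mirroring what run.py's model_batch_test
# does; paths relative to this script):
#
#   import h5py, pickle
#   data = h5py.File('../../data/processed/test_batch')['data'].value
#   combined, clean, noise = data[0], data[1], data[2]
#   with open('../../data/processed/test_settings.pkl', 'rb') as f:
#       settings = pickle.load(f)  # per-sample stft settings, incl. 'orig_length'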
--------------------------------------------------------------------------------
/data_scripts/preprocessing/split_train_dev.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('../../')
import numpy as np
from config import Config
import random

DIR = '../../data/processed/'

MAKE_SMALLER = True

def create_batch(input_data, target_data, batch_size):
    input_batches = []
    target_batches = []

    for i in xrange(0, len(target_data), batch_size):
        input_batches.append(input_data[i:i + batch_size])
        target_batches.append(target_data[i:i + batch_size])

    return input_batches, target_batches

def pad_data(data):
    num_samples = len(data)
    print(num_samples)
    max_rows_in_sample = max(len(data[i]) for i in xrange(num_samples))
    print([len(data[i]) for i in xrange(num_samples)])
    print(max_rows_in_sample)
    num_cols_in_row = data[0][0].size
    print(num_cols_in_row)
    print(data[0][0])
    new_data = np.zeros((num_samples, max_rows_in_sample, num_cols_in_row))
    for i, sample in enumerate(data):
        num_rows = len(sample)
        for j, row in enumerate(sample):
            print(row)
            for k, c in enumerate(row):
                new_data[i][max_rows_in_sample - num_rows + j][k] = c
    return new_data
    # earlier per-batch padding attempt, kept for reference:
    # padded_batches = []
    # for i in xrange(len(data)):
    #     batch = data[i]
    #     max_len = max(s[0].size for s in batch)
    #     padded_batch = []
    #     for s in batch:
    #         if max_len - len(s) > 0:
    #             padded_batch.append(np.pad(s, ((max_len - len(s[0]), 0), (0, 0)), 'constant'))
    #         else:
    #             padded_batch.append(s)
    #     padded_batches.append(padded_batch)
    # return padded_batches

if __name__ == '__main__':

    processed_data = []
    # this script targets the earlier .npz output of process_audio.py (the
    # np.savez_compressed path, now commented out there); np.load returns an
    # NpzFile, so the array has to be pulled out of its default key
    for i in xrange(4):
        processed_data.append(np.load('%sdata%d.npz' % (DIR, i))['arr_0'])
    print('finished loading data')
    num_data = len(processed_data)

    ###############################################################################
    # preprocess for smaller data to get model working (BEGIN)
    ###############################################################################
    if MAKE_SMALLER:
        # note: with fewer than 100 items loaded, num_data / 100 == 0 and this keeps nothing
        dev_ix = set(random.sample(xrange(num_data), num_data / 100))
        processed_data = [l for i, l in enumerate(processed_data) if i in dev_ix]
        num_data = len(processed_data)
    ###############################################################################
    # preprocess for smaller data to get model working (END)
    ###############################################################################

    dev_ix = set(random.sample(xrange(num_data), num_data / 5))

    processed_data = [np.transpose(s) for s in processed_data]

    inp, clean, noise = zip(*processed_data)
    padded_input = pad_data(inp)
    padded_clean = pad_data(clean)
    padded_noise = pad_data(noise)

    train_padded_input = [s for i, s in enumerate(padded_input) if i not in dev_ix]
    train_padded_clean = [s for i, s in enumerate(padded_clean) if i not in dev_ix]
    train_padded_noise = [s for i, s in enumerate(padded_noise) if i not in dev_ix]
    dev_padded_input = [s for i, s in enumerate(padded_input) if i in dev_ix]
    dev_padded_clean = [s for i, s in enumerate(padded_clean) if i in dev_ix]
    dev_padded_noise = [s for i, s in enumerate(padded_noise) if i in dev_ix]

    # train_input, train_clean, train_noise = zip(*train_data)
    # dev_input, dev_clean, dev_noise = zip(*dev_data)

    # train_padded_input = pad_data(train_input)
    # train_padded_clean = pad_data(train_clean)
    # train_padded_noise = pad_data(train_noise)
    # dev_padded_input = pad_data(dev_input)
    # dev_padded_clean = pad_data(dev_clean)
    # dev_padded_noise = pad_data(dev_noise)

    # print('train_padded_clean.shape: %s' % (train_padded_clean.shape))

    train_target = np.concatenate((train_padded_clean, train_padded_noise), axis=1)
    dev_target = np.concatenate((dev_padded_clean, dev_padded_noise), axis=1)

    # print('train_target.shape: %s' % (train_target.shape))

    train_input_batches, train_target_batches = create_batch(train_padded_input, train_target, Config.batch_size)
    dev_input_batches, dev_target_batches = create_batch(dev_padded_input, dev_target, Config.batch_size)

    print(np.array(train_input_batches))

    train_input_batch_name = 'train_input_batch'
    train_target_batch_name = 'train_target_batch'
    dev_input_batch_name = 'dev_input_batch'
    dev_target_batch_name = 'dev_target_batch'

    if MAKE_SMALLER:
        train_input_batch_name = 'smaller_' + train_input_batch_name
        train_target_batch_name = 'smaller_' + train_target_batch_name
        dev_input_batch_name = 'smaller_' + dev_input_batch_name
        dev_target_batch_name = 'smaller_' + dev_target_batch_name

    np.save(DIR + train_input_batch_name, train_input_batches)
    np.save(DIR + train_target_batch_name, train_target_batches)
    np.save(DIR + dev_input_batch_name, dev_input_batches)
    np.save(DIR + dev_target_batch_name, dev_target_batches)

--------------------------------------------------------------------------------
/data_scripts/masking/test_masking.py:
--------------------------------------------------------------------------------
import sys
import os
import stft
import numpy as np
from numpy.linalg import inv
from numpy.linalg import svd
from scipy.io import wavfile
from stft.types import SpectrogramArray
import pickle
import scipy
from matplotlib import pyplot as plt

raw1 = '../../data/sliced_clean/f10_script2_clean_113.wav'
raw2 = "../../data/raw_noise/noise11_1.wav"
merged = '../../data/test_combined/combined.wav'
m_dir = "results/"
separated_dir = "results/"

# ASSUMPTION: len(spec.shape) <= 3
def squeeze(spec):
    if len(spec.shape) > 2:
        spec = np.delete(spec, 1, axis=2)
        spec = spec.squeeze(2)
    return spec

def createSpectrogram(arr, orig):
    x = SpectrogramArray(arr, stft_settings={
        'framelength': orig.stft_settings['framelength'],
        'hopsize': orig.stft_settings['hopsize'],
        'overlap': orig.stft_settings['overlap'],
        'centered': orig.stft_settings['centered'],
        'window': orig.stft_settings['window'],
        'halved': orig.stft_settings['halved'],
        'transform': orig.stft_settings['transform'],
        'padding': orig.stft_settings['padding'],
        'outlength': orig.stft_settings['outlength'],
    })
    return x

def writeWav(fn, fs, data):
    # output scaling disabled: data = data * 1.5 / np.max(np.abs(data))
    wavfile.write(fn, fs, data)

def createMatrix():
    # spectrogram_arguments = {'framelength': 512, 'overlap': 512, 'window': scipy.signal.hamming(512)}
    def saveFile(fn, data):
        f = open(fn, 'wb')
        pickle.dump(data, f)
        f.close()

    fs1, data1 = wavfile.read(raw1)
    fs2, data2 = wavfile.read(raw2)

    minlen = min(len(data1), len(data2))
    data1 = data1[:minlen]
    data2 = data2[:minlen]

    spec1 = stft.spectrogram(data1)
    spec2 = stft.spectrogram(data2)

    # Reduce dimension
    spec1 = squeeze(spec1)
    spec2 = squeeze(spec2)

    # same dimensions
    a = np.zeros(spec1.shape)
    b = np.zeros(spec2.shape)

    # hard mask: each time-frequency cell goes to whichever source is louder
    for i in range(len(spec1)):
        for j in range(len(spec1[0])):
            if abs(spec1[i][j]) < abs(spec2[i][j]):
                b[i][j] = 1.0
            else:
                a[i][j] = 1.0

    # soft mask alternative:
    # for i in range(len(spec1)):
    #     for j in range(len(spec1[0])):
    #         if (abs(spec1[i][j]) + abs(spec2[i][j])) == 0:
    #             continue
    #         a[i][j] = abs(spec1[i][j]) / (abs(spec1[i][j]) + abs(spec2[i][j]))
    #         b[i][j] = abs(spec2[i][j]) / (abs(spec1[i][j]) + abs(spec2[i][j]))

    def plotfft(data, sr, ylim=None):
        plt.plot(np.abs(data))
        if ylim is not None:
            plt.ylim(ylim)
        plt.show()

    fs, data = wavfile.read(merged)
    spec = stft.spectrogram(data)
    spec = squeeze(spec)

    # debug plot of the mixture spectrogram
    # ax1 = plt.subplot(211)
    time = np.arange(0, 7.6382, 0.0001)
    # plt.plot(time, data1)
    plt.xlim([0, 2])
    # plt.subplot(212)
    Pxx, freqs, bins, im = plt.specgram(data, NFFT=200, Fs=fs, noverlap=100, cmap=plt.cm.gist_heat)
    plt.show()

    output_a = createSpectrogram(np.multiply(a, spec), spec)
    output_b = createSpectrogram(np.multiply(b, spec), spec)

    output_a2 = stft.ispectrogram(output_a)
    output_b2 = stft.ispectrogram(output_b)

    writeWav(separated_dir + "a.wav", fs1, output_a2)
    writeWav(separated_dir + "b.wav", fs1, output_b2)

def divide():
    def loadFile(fn):
        f = open(fn, 'rb')
        data = pickle.load(f)
        f.close()
        return data

    fs, data = wavfile.read(merged)
    spec = stft.spectrogram(data, framelength=512)
    spec = squeeze(spec)
    # key mask files by the wav's base name so the pickle path lands inside results/
    Ma = loadFile(m_dir + "M_" + os.path.basename(raw1)[:-4])
    Mb = loadFile(m_dir + "M_" + os.path.basename(raw2)[:-4])
    a = createSpectrogram(np.dot(Ma, spec), spec)
    b = createSpectrogram(np.dot(Mb, spec), spec)

    output_a = stft.ispectrogram(a)
    output_b = stft.ispectrogram(b)

    writeWav(separated_dir + "a.wav", fs, output_a)
    writeWav(separated_dir + "b.wav", fs, output_b)

if __name__ == "__main__":
    # argparse later
    c = sys.argv[1]
    if c == "a":
        createMatrix()
    elif c == "b":
        divide()
    else:
        print "Unknown"

def trash1():
    # leftover SVD experiment: relies on spec1/spec2/a/b/saveFile from
    # createMatrix's scope, so it will not run as-is
    def getMatrix(s, a, b):
        s_mat = np.zeros([a, b])
        s_mat[:min(a, b), :min(a, b)] = np.diag(s)
        return s_mat

    def getInverseMatrix(s, a, b):
        s = inv(np.diag(s)).diagonal()
        s_mat = np.zeros([b, a])
        s_mat[:min(a, b), :min(a, b)] = np.diag(s)
        return s_mat

    def originalValuesInverse(spec):
        u, s_lin, v = svd(spec)
        s = getInverseMatrix(s_lin, len(u), len(v))
        return inv(u), s, inv(v)

    def newValues(spec):
        u, s_lin, v = svd(spec)
        s = getMatrix(s_lin, len(u), len(v))
        return u, s, v

    ua, sa, va = originalValuesInverse(spec1)
    ub, sb, vb = originalValuesInverse(spec2)

    una, sna, vna = newValues(a)
    unb, snb, vnb = newValues(b)

    # M * A_orig = A_new
    Ma = np.dot(np.dot(np.dot(np.dot(np.dot(una, sna), vna), va), sa), ua)
    Mb = np.dot(np.dot(np.dot(np.dot(np.dot(unb, snb), vnb), vb), sb), ub)

    Ma = createSpectrogram(Ma, spec1)
    Mb = createSpectrogram(Mb, spec2)

    saveFile(m_dir + "M_" + os.path.basename(raw1)[:-4], Ma)
    saveFile(m_dir + "M_" + os.path.basename(raw2)[:-4], Mb)
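
# Usage sketch: `python test_masking.py a` builds hard masks from the two
# reference wavs and writes separated estimates into results/;
# `python test_masking.py b` instead applies mask matrices previously pickled
# under results/ (trash1 above shows how those were once produced).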
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
# Compatibility imports
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
import argparse
import math
import random
import os
# uncomment this line to suppress Tensorflow warnings
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import numpy as np
from six.moves import xrange as range

from utils import *
import pdb
from time import gmtime, strftime

from config import Config

class SeparationModel():
    """
    Bidirectional multi-layer GRU network that maps a mixed-signal spectrogram
    to concatenated [clean, noise] spectrogram estimates, trained with an
    (optionally frequency-weighted) L2 loss.
    """

    def add_placeholders(self):
        """Generates placeholder variables to represent the input tensors.

        These placeholders are used as inputs by the rest of the model building
        and will be fed data during training. "None" dimensions are flexible,
        so different batch sizes and sequence lengths work without rebuilding
        the model.

        inputs_placeholder: mixed spectrogram frames, shape
            (None, None, num_final_features), type tf.float32
        targets_placeholder: concatenated clean and noise frames, shape
            (None, None, output_size), type tf.float32
        """
        self.inputs_placeholder = tf.placeholder(tf.float32, shape=(None, None, Config.num_final_features), name='inputs')
        self.targets_placeholder = tf.placeholder(tf.float32, shape=(None, None, Config.output_size), name='targets')

    def create_feed_dict(self, inputs_batch, targets_batch):
        """Creates the feed_dict for one training or evaluation step.

        Args:
            inputs_batch: A batch of input data.
            targets_batch: A batch of targets data.
        Returns:
            feed_dict: The feed dictionary mapping from placeholders to values.
        """
        feed_dict = {
            self.inputs_placeholder: inputs_batch,
            self.targets_placeholder: targets_batch,
        }

        return feed_dict

    def add_prediction_op(self):
        """Runs a (multi-layer) bidirectional GRU over the inputs.

        Each direction produces a Tensor of shape [batch_size, max_timesteps,
        output_size]; only the forward direction's output is kept as the
        prediction.
        """
        if Config.num_layers > 1:
            # multi layer
            cell = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.GRUCell(Config.output_size,
                input_size=Config.num_final_features) for _ in range(Config.num_layers)], state_is_tuple=False)
            cell_bw = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.GRUCell(Config.output_size,
                input_size=Config.num_final_features) for _ in range(Config.num_layers)], state_is_tuple=False)
        else:
            cell = tf.contrib.rnn.GRUCell(Config.output_size, input_size=Config.num_final_features)
            cell_bw = tf.contrib.rnn.GRUCell(Config.output_size, input_size=Config.num_final_features)

        # output, state = tf.nn.dynamic_rnn(cell, self.inputs_placeholder, dtype=tf.float32)
        output, state = tf.nn.bidirectional_dynamic_rnn(cell, cell_bw, self.inputs_placeholder, dtype=tf.float32)
        output = output[0]  # keep the forward direction only

        # output_seq_length = tf.shape(output)[1]
        # last_output = output[:, output_seq_length - 1, :]

        self.output = output
        # self.output = tf.Print(self.output, [self.output, tf.shape(self.output)])

    def add_loss_op(self, freq_weighted):
        l2_cost = 0.0  # weight-decay term, currently unused

        weighted_differences = self.output - self.targets_placeholder

        num_freq_bins = Config.num_final_features
        # bin centre frequencies in Hz (44.1 kHz sample rate assumed)
        frequencies = np.array([2.0 * 180 * i / (num_freq_bins - 1) * 22050 / 360 for i in xrange(num_freq_bins)])
        frequencies[0] = 2.0 * 180 / (num_freq_bins - 1) / 2 * 22050 / 360  # avoid f = 0, which makes the ATH term below diverge
        ath_val = 3.64 * np.power(1000 / frequencies, 0.8) - 6.5 * np.exp(-0.6 * np.power(frequencies / 1000 - 3.3, 2)) + np.power(0.1, 3) * np.power(frequencies / 1000, 4)

        ath_shifted = (1 - np.amin(ath_val)) + ath_val  # shift all ath vals so that the minimum is 1
        weights = np.tile(1 / ath_shifted, 2)  # same weights for the clean and noise halves of the output

        if freq_weighted:
            weighted_differences = weights * weighted_differences
        else:
            normalized = np.full(weights.shape, np.sqrt(np.sum(np.power(weights, 2)) / num_freq_bins))
            weighted_differences = normalized * weighted_differences

        # despite the name, this is the (unsquared) L2 norm of the weighted error
        squared_error = tf.norm(weighted_differences, ord=2)
        self.loss = Config.l2_lambda * l2_cost + squared_error

        tf.summary.scalar("squared_error", squared_error)
        tf.summary.scalar("loss", self.loss)

    def add_training_op(self):
        """Creates an Adam optimizer op that minimizes self.loss; running
        self.optimizer in sess.run() performs one training step. See

        https://www.tensorflow.org/versions/r0.7/api_docs/python/train.html#Optimizer

        for more information.
        """
        self.optimizer = tf.train.AdamOptimizer(learning_rate=Config.lr).minimize(self.loss)

    def add_summary_op(self):
        self.merged_summary_op = tf.summary.merge_all()

    # This actually builds the computational graph
    def build(self, freq_weighted):
        self.add_placeholders()
        self.add_prediction_op()
        self.add_loss_op(freq_weighted)
        self.add_training_op()
        self.add_summary_op()

    def train_on_batch(self, session, train_inputs_batch, train_targets_batch, train=True):
        feed = self.create_feed_dict(train_inputs_batch, train_targets_batch)
        output, batch_cost, summary = session.run([self.output, self.loss, self.merged_summary_op], feed)

        if math.isnan(batch_cost):  # basically all examples in this batch have been skipped
            return output, 0, summary
        if train:
            _ = session.run([self.optimizer], feed)

        return output, batch_cost, summary

    def print_results(self, session, train_inputs_batch, train_targets_batch):
        # dead code from the starter project: self.decoded_sequence and
        # compare_predicted_to_true are not defined in this repo
        train_feed = self.create_feed_dict(train_inputs_batch, train_targets_batch)
        train_first_batch_preds = session.run(self.decoded_sequence, feed_dict=train_feed)
        compare_predicted_to_true(train_first_batch_preds, train_targets_batch)

    def __init__(self, freq_weighted=None):
        self.build(freq_weighted)
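
# Minimal driving loop (a sketch only; run.py contains the real training code):
#
#   with tf.Graph().as_default():
#       model = SeparationModel(freq_weighted=True)
#       with tf.Session() as sess:
#           sess.run(tf.global_variables_initializer())
#           out, cost, summary = model.train_on_batch(sess, inputs, targets, train=True)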
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import time
import argparse
import math
import random
import os
import distutils.util
# uncomment this line to suppress Tensorflow warnings
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import numpy as np
from six.moves import xrange as range
from scipy.io import wavfile

from utils import *
import pdb
from time import gmtime, strftime

from config import Config
from model import SeparationModel
import h5py
from stft.types import SpectrogramArray
import stft

from evaluate import bss_eval_sources
import pickle

import copy

DIR = 'data/processed/'

INPUT_NOISE_DIR = 'data/raw_noise/'
INPUT_CLEAN_DIR = 'data/sliced_clean/'

CLEAN_FILE = INPUT_CLEAN_DIR + 'f10_script2_clean_113.wav'
NOISE_FILE = INPUT_NOISE_DIR + 'noise1_1.wav'

def clean_data(data):
    # hack for now so that I don't have to preprocess again
    num_batches = len(data)
    print(num_batches)
    num_samples_in_batch = len(data[0])
    print(num_samples_in_batch)
    num_rows_in_sample = len(data[0][0])
    print(num_rows_in_sample)
    num_cols_in_row = len(data[0][0][0])
    print(num_cols_in_row)
    new_data = np.zeros((num_batches, num_samples_in_batch, num_rows_in_sample, num_cols_in_row))
    for i, batch in enumerate(data):
        for j, sample in enumerate(batch):
            for k, r in enumerate(sample):
                for l, c in enumerate(r):
                    new_data[i][j][k][l] = c
    return new_data

def create_batch(input_data, target_data, batch_size):
    input_batches = []
    target_batches = []

    for i in xrange(0, len(target_data), batch_size):
        input_batches.append(input_data[i:i + batch_size])
        target_batches.append(target_data[i:i + batch_size])

    return input_batches, target_batches

def model_train(freq_weighted):
    logs_path = "tensorboard/" + strftime("%Y_%m_%d_%H_%M_%S", gmtime())

    TESTING_MODE = True

    data = h5py.File('%sdata%d' % (DIR, 0))['data'].value
    # np.append returns a new array, so its result has to be kept; concatenating
    # along the sample axis assumes both chunks were padded to the same length
    data = np.concatenate((data, h5py.File('%sdata%d' % (DIR, 1))['data'].value), axis=1)

    combined, clean, noise = zip(data)
    combined = combined[0]
    clean = clean[0]
    noise = noise[0]

    target = np.concatenate((clean, noise), axis=2)

    num_data = len(combined)
    random.seed(1)
    dev_ix = set(random.sample(xrange(num_data), num_data / 5))

    train_input = [s for i, s in enumerate(combined) if i not in dev_ix]
    train_target = [s for i, s in enumerate(target) if i not in dev_ix]
    dev_input = [s for i, s in enumerate(combined) if i in dev_ix]
    dev_target = [s for i, s in enumerate(target) if i in dev_ix]

    train_input_batch, train_target_batch = create_batch(train_input, train_target, Config.batch_size)
    dev_input_batch, dev_target_batch = create_batch(dev_input, dev_target, Config.batch_size)

    # count the batches that were actually created so the final partial batch
    # is included (integer division of the sample count would drop it)
    num_batches_per_epoch = len(train_input_batch)
    num_dev_batches_per_epoch = len(dev_input_batch)

    with tf.Graph().as_default():
        model = SeparationModel(freq_weighted=freq_weighted)
        init = tf.global_variables_initializer()

        saver = tf.train.Saver(tf.trainable_variables())

        with tf.Session() as session:
            session.run(init)

            # if args.load_from_file is not None:
            #     new_saver = tf.train.import_meta_graph('%s.meta' % args.load_from_file, clear_devices=True)
            #     new_saver.restore(session, args.load_from_file)

            train_writer = tf.summary.FileWriter(logs_path + '/train', session.graph)

            global_start = time.time()

            step_ii = 0

            for curr_epoch in range(Config.num_epochs):
                total_train_cost = 0
                total_train_examples = 0

                start = time.time()

                for batch in random.sample(range(num_batches_per_epoch), num_batches_per_epoch):
                    cur_batch_size = len(train_target_batch[batch])
                    total_train_examples += cur_batch_size

                    _, batch_cost, summary = model.train_on_batch(session,
                                                                  train_input_batch[batch],
                                                                  train_target_batch[batch],
                                                                  train=True)

                    total_train_cost += batch_cost * cur_batch_size
                    train_writer.add_summary(summary, step_ii)

                    step_ii += 1

                train_cost = total_train_cost / total_train_examples

                num_dev_batches = len(dev_target_batch)
                total_batch_cost = 0
                total_batch_examples = 0

                # val_batch_cost, _ = model.train_on_batch(session, dev_input_batch[0], dev_target_batch[0], train=False)
                for batch in random.sample(range(num_dev_batches_per_epoch), num_dev_batches_per_epoch):
                    cur_batch_size = len(dev_target_batch[batch])
                    total_batch_examples += cur_batch_size

                    _, _val_batch_cost, _ = model.train_on_batch(session, dev_input_batch[batch], dev_target_batch[batch], train=False)

                    total_batch_cost += cur_batch_size * _val_batch_cost

                try:
                    val_batch_cost = total_batch_cost / total_batch_examples
                except ZeroDivisionError:
                    val_batch_cost = 0

                log = "Epoch {}/{}, train_cost = {:.3f}, val_cost = {:.3f}, time = {:.3f}"
                print(log.format(curr_epoch + 1, Config.num_epochs, train_cost, val_batch_cost, time.time() - start))

                # if args.print_every is not None and (curr_epoch + 1) % args.print_every == 0:
                #     batch_ii = 0
                #     model.print_results(train_feature_minibatches[batch_ii], train_labels_minibatches[batch_ii])

                if (curr_epoch + 1) % 10 == 0:
                    checkpoint_name = 'checkpoints/%dlayer_%flr_model' % (Config.num_layers, Config.lr)
                    if freq_weighted:
                        checkpoint_name = checkpoint_name + '_freq_weighted'
                    saver.save(session, checkpoint_name, global_step=curr_epoch + 1)


def model_test(test_input):

    test_rate, test_audio = wavfile.read(test_input)
    clean_rate, clean_audio = wavfile.read(CLEAN_FILE)
    noise_rate, noise_audio = wavfile.read(NOISE_FILE)

    length = len(clean_audio)
    noise_audio = noise_audio[:length]

    clean_spec = stft.spectrogram(clean_audio)
    noise_spec = stft.spectrogram(noise_audio)
    test_spec = stft.spectrogram(test_audio)

    reverted_clean = stft.ispectrogram(clean_spec)
    reverted_noise = stft.ispectrogram(noise_spec)

    test_data = np.array([test_spec.transpose() / 100000])  # make data a batch of 1

    with tf.Graph().as_default():
        model = SeparationModel()
        saver = tf.train.Saver(tf.trainable_variables())

        with tf.Session() as session:
            ckpt = tf.train.get_checkpoint_state('checkpoints/')
            if ckpt:
                print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
                saver.restore(session, ckpt.model_checkpoint_path)
            else:
                print("Created model with fresh parameters.")
                session.run(tf.initialize_all_variables())

            test_data_shape = np.shape(test_data)
            dummy_target = np.zeros((test_data_shape[0], test_data_shape[1], 2 * test_data_shape[2]))

            output, _, _ = model.train_on_batch(session, test_data, dummy_target, train=False)

            num_freq_bin = output.shape[2] / 2
            clean_output = output[0, :, :num_freq_bin]
            noise_output = output[0, :, num_freq_bin:]

            clean_mask, noise_mask = create_mask(clean_output, noise_output)

            clean_spec = createSpectrogram(np.multiply(clean_mask.transpose(), test_spec), test_spec.stft_settings)
            noise_spec = createSpectrogram(np.multiply(noise_mask.transpose(), test_spec), test_spec.stft_settings)

            clean_wav = stft.ispectrogram(clean_spec)
            noise_wav = stft.ispectrogram(noise_spec)

            sdr, sir, sar, _ = bss_eval_sources(np.array([reverted_clean, reverted_noise]), np.array([clean_wav, noise_wav]), False)
            print(sdr, sir, sar)

            # write at the mixture's own sample rate rather than assuming 44.1 kHz
            writeWav('data/test_combined/output_clean.wav', test_rate, clean_wav)
            writeWav('data/test_combined/output_noise.wav', test_rate, noise_wav)

def model_batch_test():

    test_batch = h5py.File('%stest_batch' % (DIR))
    data = test_batch['data'].value

    with open('%stest_settings.pkl' % (DIR), 'rb') as f:
        settings = pickle.load(f)

    # print(settings[:2])

    combined, clean, noise = zip(data)
    combined = combined[0]
    clean = clean[0]
    noise = noise[0]
    target = np.concatenate((clean, noise), axis=2)

    # test_rate, test_audio = wavfile.read('data/test_combined/combined.wav')
    # test_spec = stft.spectrogram(test_audio)

    combined_batch, target_batch = create_batch(combined, target, 50)

    original_combined_batch = [copy.deepcopy(batch) for batch in combined_batch]

    with tf.Graph().as_default():
        model = SeparationModel()
        saver = tf.train.Saver(tf.trainable_variables())

        with tf.Session() as session:
            ckpt = tf.train.get_checkpoint_state('checkpoints/')
            if ckpt:
                print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
                saver.restore(session, ckpt.model_checkpoint_path)
            else:
                print("Created model with fresh parameters.")
                session.run(tf.initialize_all_variables())

            curr_mask_array = []
            prev_mask_array = None
            diff = float('inf')
            iters = 0

            # the commented-out logic below was for iteratively refining the
            # masks; with the `break` at the end of this loop body, a single
            # pass is performed
            while True:
                iters += 1
                output, _, _ = model.train_on_batch(session, combined_batch[0], target_batch[0], train=False)

                num_freq_bin = output.shape[2] / 2
                clean_outputs = output[:, :, :num_freq_bin]
                noise_outputs = output[:, :, num_freq_bin:]

                # clean = [target[:, :num_freq_bin] for target in target_batch]
                # noise = [target[:, num_freq_bin:] for target in target_batch]

                num_outputs = len(clean_outputs)

                results = []

                for i in xrange(num_outputs):
                    orig_clean_output = clean_outputs[i]
                    orig_noise_output = noise_outputs[i]

                    stft_settings = copy.deepcopy(settings[i])
                    orig_length = stft_settings['orig_length']
                    stft_settings.pop('orig_length', None)
                    clean_output = orig_clean_output[-orig_length:]
                    noise_output = orig_noise_output[-orig_length:]

                    clean_mask, noise_mask = create_mask(clean_output, noise_output)
                    orig_clean_mask, orig_noise_mask = create_mask(orig_clean_output, orig_noise_output)

                    curr_mask_array.append(clean_mask)
                    # if i == 0:
                    #     print clean_mask[10:20, 10:20]
                    curr_mask_array.append(noise_mask)

                    clean_spec = createSpectrogram(np.multiply(clean_mask.transpose(), original_combined_batch[0][i][-orig_length:].transpose()), settings[i])
                    noise_spec = createSpectrogram(np.multiply(noise_mask.transpose(), original_combined_batch[0][i][-orig_length:].transpose()), settings[i])

                    # print '-' * 20
                    # print original_combined_batch[0][i]
                    # print '=' * 20
                    combined_batch[0][i] += np.multiply(orig_clean_mask, original_combined_batch[0][i]) * 0.1
                    # print combined_batch[0][i]
                    # print '=' * 20
                    # print original_combined_batch[0][i]
                    # print '-' * 20

                    estimated_clean_wav = stft.ispectrogram(clean_spec)
                    estimated_noise_wav = stft.ispectrogram(noise_spec)

                    reference_clean_wav = stft.ispectrogram(SpectrogramArray(clean[i][-orig_length:], stft_settings).transpose())
                    reference_noise_wav = stft.ispectrogram(SpectrogramArray(noise[i][-orig_length:], stft_settings).transpose())

                    try:
                        sdr, sir, sar, _ = bss_eval_sources(np.array([reference_clean_wav, reference_noise_wav]), np.array([estimated_clean_wav, estimated_noise_wav]), False)
                        results.append((sdr[0], sdr[1], sir[0], sir[1], sar[0], sar[1]))
                        # print('%f, %f, %f, %f, %f, %f' % (sdr[0], sdr[1], sir[0], sir[1], sar[0], sar[1]))
                    except ValueError:
                        print('error')
                        continue
                break

                # diff = 1
                # if prev_mask_array is not None:
                #     # print curr_mask_array[0]
                #     # print prev_mask_array[0]
                #     diff = sum(np.sum(np.abs(curr_mask_array[i] - prev_mask_array[i])) for i in xrange(len(prev_mask_array)))
                #     print('Changes after iteration %d: %d' % (iters, diff))

                # sdr_cleans, sdr_noises, sir_cleans, sir_noises, sar_cleans, sar_noises = zip(*results)
                # print('Avg sdr_cleans: %f, sdr_noises: %f, sir_cleans: %f, sir_noises: %f, sar_cleans: %f, sar_noises: %f' % (np.mean(sdr_cleans), np.mean(sdr_noises), np.mean(sir_cleans), np.mean(sir_noises), np.mean(sar_cleans), np.mean(sar_noises)))

                # prev_mask_array = [copy.deepcopy(mask[:, :]) for mask in curr_mask_array]

                # if diff == 0:
                #     break

            results_filename = '%sresults_%d_%f' % ('data/results/', Config.num_layers, Config.lr)
            # results_filename += 'freq_weighted'

            with open(results_filename + '.csv', 'w+') as f:
                for sdr_1, sdr_2, sir_1, sir_2, sar_1, sar_2 in results:
                    f.write('%f,%f,%f,%f,%f,%f\n' % (sdr_1, sdr_2, sir_1, sir_2, sar_1, sar_2))

            # f = h5py.File(results_filename, 'w')
            # f.create_dataset('result', data=results, compression="gzip", compression_opts=9)


def writeWav(fn, fs, data):
    data = data * 1.5 / np.max(np.abs(data))
    wavfile.write(fn, fs, data)


def create_mask(clean_output, noise_output, hard=True):
    clean_mask = np.zeros(clean_output.shape)
    noise_mask = np.zeros(noise_output.shape)

    if hard:
        for i in range(len(clean_output)):
            for j in range(len(clean_output[0])):
                if abs(clean_output[i][j]) < abs(noise_output[i][j]):
                    noise_mask[i][j] = 1.0
                else:
                    clean_mask[i][j] = 1.0
    else:
        for i in range(len(clean_output)):
            for j in range(len(clean_output[0])):
                denom = abs(clean_output[i][j]) + abs(noise_output[i][j])
                if denom == 0:  # leave both masks at zero for silent cells
                    continue
                clean_mask[i][j] = abs(clean_output[i][j]) / denom
                noise_mask[i][j] = abs(noise_output[i][j]) / denom

    return clean_mask, noise_mask

def createSpectrogram(arr, settings):
    x = SpectrogramArray(arr, stft_settings={
        'framelength': settings['framelength'],
        'hopsize': settings['hopsize'],
        'overlap': settings['overlap'],
        'centered': settings['centered'],
        'window': settings['window'],
        'halved': settings['halved'],
        'transform': settings['transform'],
        'padding': settings['padding'],
        'outlength': settings['outlength'],
    })
    return x

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', nargs='?', default=True, type=distutils.util.strtobool)
    parser.add_argument('--test_single_input', nargs='?', default='data/test_combined/combined.wav', type=str)
    parser.add_argument('--freq_weighted', nargs='?', default=True, type=distutils.util.strtobool)
    parser.add_argument('--test_batch', nargs='?', default=False, type=distutils.util.strtobool)
    args = parser.parse_args()

    if args.test_batch:
        model_batch_test()
    elif args.train:
        model_train(args.freq_weighted)
    else:
        model_test(args.test_single_input)
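
# Command-line usage (a sketch based on the flags above):
#   python run.py --train=true --freq_weighted=true       # train the model
#   python run.py --train=false --test_single_input=data/test_combined/combined.wav
#                                                         # separate one mixture
#   python run.py --test_batch=true                       # batched SDR/SIR/SAR evaluation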
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Source separation algorithms attempt to extract recordings of individual 4 | sources from a recording of a mixture of sources. Evaluation methods for 5 | source separation compare the extracted sources from reference sources and 6 | attempt to measure the perceptual quality of the separation. 7 | See also the bss_eval MATLAB toolbox: 8 | http://bass-db.gforge.inria.fr/bss_eval/ 9 | Conventions 10 | ----------- 11 | An audio signal is expected to be in the format of a 1-dimensional array where 12 | the entries are the samples of the audio signal. When providing a group of 13 | estimated or reference sources, they should be provided in a 2-dimensional 14 | array, where the first dimension corresponds to the source number and the 15 | second corresponds to the samples. 16 | Metrics 17 | ------- 18 | * :func:`mir_eval.separation.bss_eval_sources`: Computes the bss_eval_sources 19 | metrics from bss_eval, which optionally optimally match the estimated sources 20 | to the reference sources and measure the distortion and artifacts present in 21 | the estimated sources as well as the interference between them. 22 | * :func:`mir_eval.separation.bss_eval_sources_framewise`: Computes the 23 | bss_eval_sources metrics on a frame-by-frame basis. 24 | * :func:`mir_eval.separation.bss_eval_images`: Computes the bss_eval_images 25 | metrics from bss_eval, which includes the metrics in 26 | :func:`mir_eval.separation.bss_eval_sources` plus the image to spatial 27 | distortion ratio. 28 | * :func:`mir_eval.separation.bss_eval_images_framewise`: Computes the 29 | bss_eval_images metrics on a frame-by-frame basis. 30 | References 31 | ---------- 32 | .. [#vincent2006performance] Emmanuel Vincent, Rémi Gribonval, and Cédric 33 | Févotte, "Performance measurement in blind audio source separation," IEEE 34 | Trans. on Audio, Speech and Language Processing, 14(4):1462-1469, 2006. 35 | ''' 36 | 37 | import numpy as np 38 | import scipy.fftpack 39 | from scipy.linalg import toeplitz 40 | from scipy.signal import fftconvolve 41 | import collections 42 | import itertools 43 | import warnings 44 | 45 | 46 | # The maximum allowable number of sources (prevents insane computational load) 47 | MAX_SOURCES = 100 48 | 49 | 50 | def validate(reference_sources, estimated_sources): 51 | """Checks that the input data to a metric are valid, and throws helpful 52 | errors if not. 53 | Parameters 54 | ---------- 55 | reference_sources : np.ndarray, shape=(nsrc, nsampl) 56 | matrix containing true sources 57 | estimated_sources : np.ndarray, shape=(nsrc, nsampl) 58 | matrix containing estimated sources 59 | """ 60 | 61 | if reference_sources.shape != estimated_sources.shape: 62 | raise ValueError('The shape of estimated sources and the true ' 63 | 'sources should match. reference_sources.shape ' 64 | '= {}, estimated_sources.shape ' 65 | '= {}'.format(reference_sources.shape, 66 | estimated_sources.shape)) 67 | 68 | if reference_sources.ndim > 3 or estimated_sources.ndim > 3: 69 | raise ValueError('The number of dimensions is too high (must be less ' 70 | 'than 3). reference_sources.ndim = {}, ' 71 | 'estimated_sources.ndim ' 72 | '= {}'.format(reference_sources.ndim, 73 | estimated_sources.ndim)) 74 | 75 | if reference_sources.size == 0: 76 | warnings.warn("reference_sources is empty, should be of size " 77 | "(nsrc, nsample). 
sdr, sir, sar, and perm will all " 78 | "be empty np.ndarrays") 79 | elif _any_source_silent(reference_sources): 80 | raise ValueError('All the reference sources should be non-silent (not ' 81 | 'all-zeros), but at least one of the reference ' 82 | 'sources is all 0s, which introduces ambiguity to the' 83 | ' evaluation. (Otherwise we can add infinitely many ' 84 | 'all-zero sources.)') 85 | 86 | if estimated_sources.size == 0: 87 | warnings.warn("estimated_sources is empty, should be of size " 88 | "(nsrc, nsample). sdr, sir, sar, and perm will all " 89 | "be empty np.ndarrays") 90 | elif _any_source_silent(estimated_sources): 91 | raise ValueError('All the estimated sources should be non-silent (not ' 92 | 'all-zeros), but at least one of the estimated ' 93 | 'sources is all 0s. Since we require each reference ' 94 | 'source to be non-silent, having a silent estimated ' 95 | 'source will result in an underdetermined system.') 96 | 97 | if (estimated_sources.shape[0] > MAX_SOURCES or 98 | reference_sources.shape[0] > MAX_SOURCES): 99 | raise ValueError('The supplied matrices should be of shape (nsrc,' 100 | ' nsampl) but reference_sources.shape[0] = {} and ' 101 | 'estimated_sources.shape[0] = {} which is greater ' 102 | 'than mir_eval.separation.MAX_SOURCES = {}. To ' 103 | 'override this check, set ' 104 | 'mir_eval.separation.MAX_SOURCES to a ' 105 | 'larger value.'.format(reference_sources.shape[0], 106 | estimated_sources.shape[0], 107 | MAX_SOURCES)) 108 | 109 | 110 | def _any_source_silent(sources): 111 | """Returns true if the parameter sources has any silent first dimensions""" 112 | return np.any(np.all(np.sum( 113 | sources, axis=tuple(range(2, sources.ndim))) == 0, axis=1)) 114 | 115 | 116 | def bss_eval_sources(reference_sources, estimated_sources, 117 | compute_permutation=True): 118 | """ 119 | Ordering and measurement of the separation quality for estimated source 120 | signals in terms of filtered true source, interference and artifacts. 121 | The decomposition allows a time-invariant filter distortion of length 122 | 512, as described in Section III.B of [#vincent2006performance]_. 123 | Passing ``False`` for ``compute_permutation`` will improve the computation 124 | performance of the evaluation; however, it is not always appropriate and 125 | is not the way that the BSS_EVAL Matlab toolbox computes bss_eval_sources. 126 | Examples 127 | -------- 128 | >>> # reference_sources[n] should be an ndarray of samples of the 129 | >>> # n'th reference source 130 | >>> # estimated_sources[n] should be the same for the n'th estimated 131 | >>> # source 132 | >>> (sdr, sir, sar, 133 | ... perm) = mir_eval.separation.bss_eval_sources(reference_sources, 134 | ... 
135 |     Parameters
136 |     ----------
137 |     reference_sources : np.ndarray, shape=(nsrc, nsampl)
138 |         matrix containing true sources (must have same shape as
139 |         estimated_sources)
140 |     estimated_sources : np.ndarray, shape=(nsrc, nsampl)
141 |         matrix containing estimated sources (must have same shape as
142 |         reference_sources)
143 |     compute_permutation : bool, optional
144 |         compute permutation of estimate/source combinations (True by default)
145 |     Returns
146 |     -------
147 |     sdr : np.ndarray, shape=(nsrc,)
148 |         vector of Signal to Distortion Ratios (SDR)
149 |     sir : np.ndarray, shape=(nsrc,)
150 |         vector of Source to Interference Ratios (SIR)
151 |     sar : np.ndarray, shape=(nsrc,)
152 |         vector of Sources to Artifacts Ratios (SAR)
153 |     perm : np.ndarray, shape=(nsrc,)
154 |         vector containing the best ordering of estimated sources in
155 |         the mean SIR sense (estimated source number ``perm[j]`` corresponds to
156 |         true source number ``j``). Note: ``perm`` will be ``[0, 1, ...,
157 |         nsrc-1]`` if ``compute_permutation`` is ``False``.
158 |     References
159 |     ----------
160 |     .. [#] Emmanuel Vincent, Shoko Araki, Fabian J. Theis, Guido Nolte, Pau
161 |         Bofill, Hiroshi Sawada, Alexey Ozerov, B. Vikrham Gowreesunker, Dominik
162 |         Lutter and Ngoc Q.K. Duong, "The Signal Separation Evaluation Campaign
163 |         (2007-2010): Achievements and remaining challenges", Signal Processing,
164 |         92, pp. 1928-1936, 2012.
165 |     """
166 | 
167 |     # make sure the input is of shape (nsrc, nsampl)
168 |     if estimated_sources.ndim == 1:
169 |         estimated_sources = estimated_sources[np.newaxis, :]
170 |     if reference_sources.ndim == 1:
171 |         reference_sources = reference_sources[np.newaxis, :]
172 | 
173 |     validate(reference_sources, estimated_sources)
174 |     # If empty matrices were supplied, return empty lists (special case)
175 |     if reference_sources.size == 0 or estimated_sources.size == 0:
176 |         return np.array([]), np.array([]), np.array([]), np.array([])
177 | 
178 |     nsrc = estimated_sources.shape[0]
179 | 
180 |     # does user desire permutations?
181 |     if compute_permutation:
182 |         # compute criteria for all possible pair matches
183 |         sdr = np.empty((nsrc, nsrc))
184 |         sir = np.empty((nsrc, nsrc))
185 |         sar = np.empty((nsrc, nsrc))
186 |         for jest in range(nsrc):
187 |             for jtrue in range(nsrc):
188 |                 s_true, e_spat, e_interf, e_artif = \
189 |                     _bss_decomp_mtifilt(reference_sources,
190 |                                         estimated_sources[jest],
191 |                                         jtrue, 512)
192 |                 sdr[jest, jtrue], sir[jest, jtrue], sar[jest, jtrue] = \
193 |                     _bss_source_crit(s_true, e_spat, e_interf, e_artif)
194 | 
195 |         # select the best ordering
196 |         perms = list(itertools.permutations(list(range(nsrc))))
197 |         mean_sir = np.empty(len(perms))
198 |         dum = np.arange(nsrc)  # fixed true-source indices
199 |         for (i, perm) in enumerate(perms):
200 |             mean_sir[i] = np.mean(sir[perm, dum])
201 |         popt = perms[np.argmax(mean_sir)]
202 |         idx = (popt, dum)
203 |         return (sdr[idx], sir[idx], sar[idx], np.asarray(popt))
204 |     else:
205 |         # compute criteria for only the simple correspondence
206 |         # (estimate 1 is estimate corresponding to reference source 1, etc.)
207 |         sdr = np.empty(nsrc)
208 |         sir = np.empty(nsrc)
209 |         sar = np.empty(nsrc)
210 |         for j in range(nsrc):
211 |             s_true, e_spat, e_interf, e_artif = \
212 |                 _bss_decomp_mtifilt(reference_sources,
213 |                                     estimated_sources[j],
214 |                                     j, 512)
215 |             sdr[j], sir[j], sar[j] = \
216 |                 _bss_source_crit(s_true, e_spat, e_interf, e_artif)
217 | 
218 |         # return the default permutation for compatibility
219 |         popt = np.arange(nsrc)
220 |         return (sdr, sir, sar, popt)
221 | 
222 | 
223 | def bss_eval_sources_framewise(reference_sources, estimated_sources,
224 |                                window=30*44100, hop=15*44100,
225 |                                compute_permutation=False):
226 |     """Framewise computation of bss_eval_sources
227 |     Please be aware that this function does not compute permutations (by
228 |     default) on the possible relations between reference_sources and
229 |     estimated_sources due to the dangers of a changing permutation. Therefore
230 |     (by default), it assumes that ``reference_sources[i]`` corresponds to
231 |     ``estimated_sources[i]``. To enable computing permutations please set
232 |     ``compute_permutation`` to be ``True`` and check that the returned ``perm``
233 |     is identical for all windows.
234 |     NOTE: if ``reference_sources`` and ``estimated_sources`` would be evaluated
235 |     using only a single window or are shorter than the window length, the
236 |     result of :func:`mir_eval.separation.bss_eval_sources` called on
237 |     ``reference_sources`` and ``estimated_sources`` (with the
238 |     ``compute_permutation`` parameter passed to
239 |     :func:`mir_eval.separation.bss_eval_sources`) is returned.
240 |     Examples
241 |     --------
242 |     >>> # reference_sources[n] should be an ndarray of samples of the
243 |     >>> # n'th reference source
244 |     >>> # estimated_sources[n] should be the same for the n'th estimated
245 |     >>> # source
246 |     >>> (sdr, sir, sar,
247 |     ...  perm) = mir_eval.separation.bss_eval_sources_framewise(
248 |     ...      reference_sources,
249 |     ...      estimated_sources)
250 |     Parameters
251 |     ----------
252 |     reference_sources : np.ndarray, shape=(nsrc, nsampl)
253 |         matrix containing true sources (must have the same shape as
254 |         ``estimated_sources``)
255 |     estimated_sources : np.ndarray, shape=(nsrc, nsampl)
256 |         matrix containing estimated sources (must have the same shape as
257 |         ``reference_sources``)
258 |     window : int, optional
259 |         Window length for framewise evaluation (default value is 30s at a
260 |         sample rate of 44.1kHz)
261 |     hop : int, optional
262 |         Hop size for framewise evaluation (default value is 15s at a
263 |         sample rate of 44.1kHz)
264 |     compute_permutation : bool, optional
265 |         compute permutation of estimate/source combinations for all windows
266 |         (False by default)
267 |     Returns
268 |     -------
269 |     sdr : np.ndarray, shape=(nsrc, nframes)
270 |         vector of Signal to Distortion Ratios (SDR)
271 |     sir : np.ndarray, shape=(nsrc, nframes)
272 |         vector of Source to Interference Ratios (SIR)
273 |     sar : np.ndarray, shape=(nsrc, nframes)
274 |         vector of Sources to Artifacts Ratios (SAR)
275 |     perm : np.ndarray, shape=(nsrc, nframes)
276 |         vector containing the best ordering of estimated sources in
277 |         the mean SIR sense (estimated source number ``perm[j]`` corresponds to
278 |         true source number ``j``). Note: ``perm`` will be ``range(nsrc)`` for
279 |         all windows if ``compute_permutation`` is ``False``
280 |     """
281 | 
282 |     # make sure the input is of shape (nsrc, nsampl)
283 |     if estimated_sources.ndim == 1:
284 |         estimated_sources = estimated_sources[np.newaxis, :]
285 |     if reference_sources.ndim == 1:
286 |         reference_sources = reference_sources[np.newaxis, :]
287 | 
288 |     validate(reference_sources, estimated_sources)
289 |     # If empty matrices were supplied, return empty lists (special case)
290 |     if reference_sources.size == 0 or estimated_sources.size == 0:
291 |         return np.array([]), np.array([]), np.array([]), np.array([])
292 | 
293 |     nsrc = reference_sources.shape[0]
294 | 
295 |     nwin = int(
296 |         np.floor((reference_sources.shape[1] - window + hop) / hop)
297 |     )
298 |     # if fewer than 2 windows would be evaluated, return the sources result
299 |     if nwin < 2:
300 |         result = bss_eval_sources(reference_sources,
301 |                                   estimated_sources,
302 |                                   compute_permutation)
303 |         return [np.expand_dims(score, -1) for score in result]
304 | 
305 |     # compute the criteria across all windows
306 |     sdr = np.empty((nsrc, nwin))
307 |     sir = np.empty((nsrc, nwin))
308 |     sar = np.empty((nsrc, nwin))
309 |     perm = np.empty((nsrc, nwin))
310 | 
311 |     # k iterates across all the windows
312 |     for k in range(nwin):
313 |         win_slice = slice(k * hop, k * hop + window)
314 |         ref_slice = reference_sources[:, win_slice]
315 |         est_slice = estimated_sources[:, win_slice]
316 |         # check for a silent frame
317 |         if (not _any_source_silent(ref_slice) and
318 |                 not _any_source_silent(est_slice)):
319 |             sdr[:, k], sir[:, k], sar[:, k], perm[:, k] = bss_eval_sources(
320 |                 ref_slice, est_slice, compute_permutation
321 |             )
322 |         else:
323 |             # if we have a silent frame set results as np.nan
324 |             sdr[:, k] = sir[:, k] = sar[:, k] = perm[:, k] = np.nan
325 | 
326 |     return sdr, sir, sar, perm
327 | 
328 | 
329 | def bss_eval_images(reference_sources, estimated_sources,
330 |                     compute_permutation=True):
331 |     """Implementation of the bss_eval_images function from the
332 |     BSS_EVAL Matlab toolbox.
333 |     Ordering and measurement of the separation quality for estimated source
334 |     signals in terms of filtered true source, interference and artifacts.
335 |     This method also provides the ISR measure.
336 |     The decomposition allows a time-invariant filter distortion of length
337 |     512, as described in Section III.B of [#vincent2006performance]_.
338 |     Passing ``False`` for ``compute_permutation`` will improve the computation
339 |     performance of the evaluation; however, it is not always appropriate and
340 |     is not the way that the BSS_EVAL Matlab toolbox computes bss_eval_images.
341 |     Examples
342 |     --------
343 |     >>> # reference_sources[n] should be an ndarray of samples of the
344 |     >>> # n'th reference source
345 |     >>> # estimated_sources[n] should be the same for the n'th estimated
346 |     >>> # source
347 |     >>> (sdr, isr, sir, sar,
348 |     ...  perm) = mir_eval.separation.bss_eval_images(reference_sources,
349 |     ...                                              estimated_sources)
350 |     Parameters
351 |     ----------
352 |     reference_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
353 |         matrix containing true sources
354 |     estimated_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
355 |         matrix containing estimated sources
356 |     compute_permutation : bool, optional
357 |         compute permutation of estimate/source combinations (True by default)
358 |     Returns
359 |     -------
360 |     sdr : np.ndarray, shape=(nsrc,)
361 |         vector of Signal to Distortion Ratios (SDR)
362 |     isr : np.ndarray, shape=(nsrc,)
363 |         vector of source Image to Spatial distortion Ratios (ISR)
364 |     sir : np.ndarray, shape=(nsrc,)
365 |         vector of Source to Interference Ratios (SIR)
366 |     sar : np.ndarray, shape=(nsrc,)
367 |         vector of Sources to Artifacts Ratios (SAR)
368 |     perm : np.ndarray, shape=(nsrc,)
369 |         vector containing the best ordering of estimated sources in
370 |         the mean SIR sense (estimated source number ``perm[j]`` corresponds to
371 |         true source number ``j``). Note: ``perm`` will be ``[0, 1, ...,
372 |         nsrc-1]`` if ``compute_permutation`` is ``False``.
373 |     References
374 |     ----------
375 |     .. [#] Emmanuel Vincent, Shoko Araki, Fabian J. Theis, Guido Nolte, Pau
376 |         Bofill, Hiroshi Sawada, Alexey Ozerov, B. Vikrham Gowreesunker, Dominik
377 |         Lutter and Ngoc Q.K. Duong, "The Signal Separation Evaluation Campaign
378 |         (2007-2010): Achievements and remaining challenges", Signal Processing,
379 |         92, pp. 1928-1936, 2012.
380 |     """
381 | 
382 |     # make sure the input has 3 dimensions
383 |     # assuming input is in shape (nsampl) or (nsrc, nsampl)
384 |     estimated_sources = np.atleast_3d(estimated_sources)
385 |     reference_sources = np.atleast_3d(reference_sources)
386 |     # we will ensure input doesn't have more than 3 dimensions in validate
387 | 
388 |     validate(reference_sources, estimated_sources)
389 |     # If empty matrices were supplied, return empty lists (special case)
390 |     if reference_sources.size == 0 or estimated_sources.size == 0:
391 |         return np.array([]), np.array([]), np.array([]), \
392 |             np.array([]), np.array([])
393 | 
394 |     # determine size parameters
395 |     nsrc = estimated_sources.shape[0]
396 |     nsampl = estimated_sources.shape[1]
397 |     nchan = estimated_sources.shape[2]
398 | 
399 |     # does the user desire permutation?
400 |     if compute_permutation:
401 |         # compute criteria for all possible pair matches
402 |         sdr = np.empty((nsrc, nsrc))
403 |         isr = np.empty((nsrc, nsrc))
404 |         sir = np.empty((nsrc, nsrc))
405 |         sar = np.empty((nsrc, nsrc))
406 |         for jest in range(nsrc):
407 |             for jtrue in range(nsrc):
408 |                 s_true, e_spat, e_interf, e_artif = \
409 |                     _bss_decomp_mtifilt_images(
410 |                         reference_sources,
411 |                         np.reshape(
412 |                             estimated_sources[jest],
413 |                             (nsampl, nchan),
414 |                             order='F'
415 |                         ),
416 |                         jtrue,
417 |                         512
418 |                     )
419 |                 sdr[jest, jtrue], isr[jest, jtrue], \
420 |                     sir[jest, jtrue], sar[jest, jtrue] = \
421 |                     _bss_image_crit(s_true, e_spat, e_interf, e_artif)
422 | 
423 |         # select the best ordering
424 |         perms = list(itertools.permutations(range(nsrc)))
425 |         mean_sir = np.empty(len(perms))
426 |         dum = np.arange(nsrc)  # fixed true-source indices
427 |         for (i, perm) in enumerate(perms):
428 |             mean_sir[i] = np.mean(sir[perm, dum])
429 |         popt = perms[np.argmax(mean_sir)]
430 |         idx = (popt, dum)
431 |         return (sdr[idx], isr[idx], sir[idx], sar[idx], np.asarray(popt))
432 |     else:
433 |         # compute criteria for only the simple correspondence
434 |         # (estimate 1 is estimate corresponding to reference source 1, etc.)
435 |         sdr = np.empty(nsrc)
436 |         isr = np.empty(nsrc)
437 |         sir = np.empty(nsrc)
438 |         sar = np.empty(nsrc)
439 |         Gj = [0] * nsrc  # prepare G matrices with zeroes
440 |         G = np.zeros(1)
441 |         for j in range(nsrc):
442 |             # save G matrix to avoid recomputing it every call
443 |             s_true, e_spat, e_interf, e_artif, Gj_temp, G = \
444 |                 _bss_decomp_mtifilt_images(reference_sources,
445 |                                            np.reshape(estimated_sources[j],
446 |                                                       (nsampl, nchan),
447 |                                                       order='F'),
448 |                                            j, 512, Gj[j], G)
449 |             Gj[j] = Gj_temp
450 |             sdr[j], isr[j], sir[j], sar[j] = \
451 |                 _bss_image_crit(s_true, e_spat, e_interf, e_artif)
452 | 
453 |         # return the default permutation for compatibility
454 |         popt = np.arange(nsrc)
455 |         return (sdr, isr, sir, sar, popt)
456 | 
457 | 
458 | def bss_eval_images_framewise(reference_sources, estimated_sources,
459 |                               window=30*44100, hop=15*44100,
460 |                               compute_permutation=False):
461 |     """Framewise computation of bss_eval_images
462 |     Please be aware that this function does not compute permutations (by
463 |     default) on the possible relations between ``reference_sources`` and
464 |     ``estimated_sources`` due to the dangers of a changing permutation.
465 |     Therefore (by default), it assumes that ``reference_sources[i]``
466 |     corresponds to ``estimated_sources[i]``. To enable computing permutations
467 |     please set ``compute_permutation`` to be ``True`` and check that the
468 |     returned ``perm`` is identical for all windows.
469 |     NOTE: if ``reference_sources`` and ``estimated_sources`` would be evaluated
470 |     using only a single window or are shorter than the window length, the
471 |     result of ``bss_eval_images`` called on ``reference_sources`` and
472 |     ``estimated_sources`` (with the ``compute_permutation`` parameter passed to
473 |     ``bss_eval_images``) is returned.
474 |     Examples
475 |     --------
476 |     >>> # reference_sources[n] should be an ndarray of samples of the
477 |     >>> # n'th reference source
478 |     >>> # estimated_sources[n] should be the same for the n'th estimated
479 |     >>> # source
480 |     >>> (sdr, isr, sir, sar,
481 |     ...  perm) = mir_eval.separation.bss_eval_images_framewise(
482 |     ...      reference_sources,
483 |     ...      estimated_sources,
484 |     ...      window,
485 |     ...      hop)
486 |     Parameters
487 |     ----------
488 |     reference_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
489 |         matrix containing true sources (must have the same shape as
490 |         ``estimated_sources``)
491 |     estimated_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
492 |         matrix containing estimated sources (must have the same shape as
493 |         ``reference_sources``)
494 |     window : int, optional
495 |         Window length for framewise evaluation
496 |     hop : int, optional
497 |         Hop size for framewise evaluation
498 |     compute_permutation : bool, optional
499 |         compute permutation of estimate/source combinations for all windows
500 |         (False by default)
501 |     Returns
502 |     -------
503 |     sdr : np.ndarray, shape=(nsrc, nframes)
504 |         vector of Signal to Distortion Ratios (SDR)
505 |     isr : np.ndarray, shape=(nsrc, nframes)
506 |         vector of source Image to Spatial distortion Ratios (ISR)
507 |     sir : np.ndarray, shape=(nsrc, nframes)
508 |         vector of Source to Interference Ratios (SIR)
509 |     sar : np.ndarray, shape=(nsrc, nframes)
510 |         vector of Sources to Artifacts Ratios (SAR)
511 |     perm : np.ndarray, shape=(nsrc, nframes)
512 |         vector containing the best ordering of estimated sources in
513 |         the mean SIR sense (estimated source number ``perm[j]`` corresponds to
514 |         true source number ``j``).
515 |         Note: ``perm`` will be ``range(nsrc)`` for all windows if
516 |         ``compute_permutation`` is ``False``
517 |     """
518 | 
519 |     # make sure the input has 3 dimensions
520 |     # assuming input is in shape (nsampl) or (nsrc, nsampl)
521 |     estimated_sources = np.atleast_3d(estimated_sources)
522 |     reference_sources = np.atleast_3d(reference_sources)
523 |     # we will ensure input doesn't have more than 3 dimensions in validate
524 | 
525 |     validate(reference_sources, estimated_sources)
526 |     # If empty matrices were supplied, return empty lists (special case)
527 |     if reference_sources.size == 0 or estimated_sources.size == 0:
528 |         return np.array([]), np.array([]), np.array([]), np.array([])
529 | 
530 |     nsrc = reference_sources.shape[0]
531 | 
532 |     nwin = int(
533 |         np.floor((reference_sources.shape[1] - window + hop) / hop)
534 |     )
535 |     # if fewer than 2 windows would be evaluated, return the images result
536 |     if nwin < 2:
537 |         result = bss_eval_images(reference_sources,
538 |                                  estimated_sources,
539 |                                  compute_permutation)
540 |         return [np.expand_dims(score, -1) for score in result]
541 | 
542 |     # compute the criteria across all windows
543 |     sdr = np.empty((nsrc, nwin))
544 |     isr = np.empty((nsrc, nwin))
545 |     sir = np.empty((nsrc, nwin))
546 |     sar = np.empty((nsrc, nwin))
547 |     perm = np.empty((nsrc, nwin))
548 | 
549 |     # k iterates across all the windows
550 |     for k in range(nwin):
551 |         win_slice = slice(k * hop, k * hop + window)
552 |         ref_slice = reference_sources[:, win_slice, :]
553 |         est_slice = estimated_sources[:, win_slice, :]
554 |         # check for a silent frame
555 |         if (not _any_source_silent(ref_slice) and
556 |                 not _any_source_silent(est_slice)):
557 |             sdr[:, k], isr[:, k], sir[:, k], sar[:, k], perm[:, k] = \
558 |                 bss_eval_images(
559 |                     ref_slice, est_slice, compute_permutation
560 |                 )
561 |         else:
562 |             # if we have a silent frame set results as np.nan
563 |             sdr[:, k] = isr[:, k] = sir[:, k] = sar[:, k] = perm[:, k] = np.nan
564 | 
565 |     return sdr, isr, sir, sar, perm
566 | 
567 | 
568 | def _bss_decomp_mtifilt(reference_sources, estimated_source, j, flen):
569 |     """Decomposition of an estimated source image into four components
570 |     representing respectively the true source image, spatial (or filtering)
571 |     distortion, interference and artifacts, derived from the true source
572 |     images using multichannel time-invariant filters.
573 |     """
574 |     nsampl = estimated_source.size
575 |     # decomposition
576 |     # true source image
577 |     s_true = np.hstack((reference_sources[j], np.zeros(flen - 1)))
578 |     # spatial (or filtering) distortion
579 |     e_spat = _project(reference_sources[j, np.newaxis, :], estimated_source,
580 |                       flen) - s_true
581 |     # interference
582 |     e_interf = _project(reference_sources,
583 |                         estimated_source, flen) - s_true - e_spat
584 |     # artifacts
585 |     e_artif = -s_true - e_spat - e_interf
586 |     e_artif[:nsampl] += estimated_source
587 |     return (s_true, e_spat, e_interf, e_artif)
588 | 
589 | 
590 | def _bss_decomp_mtifilt_images(reference_sources, estimated_source, j, flen,
591 |                                Gj=None, G=None):
592 |     """Decomposition of an estimated source image into four components
593 |     representing respectively the true source image, spatial (or filtering)
594 |     distortion, interference and artifacts, derived from the true source
595 |     images using multichannel time-invariant filters.
596 |     Adapted version to work with multichannel sources.
597 |     Improved performance can be gained by passing Gj and G parameters initially
598 |     as all zeros. These parameters store the results from the computation of
599 |     the G matrix in _project_images and then return them for subsequent calls
600 |     to this function. This only works when not computing permutations.
601 |     """
602 |     nsampl = np.shape(estimated_source)[0]
603 |     nchan = np.shape(estimated_source)[1]
604 |     # are we saving the Gj and G parameters?
605 |     saveg = Gj is not None and G is not None
606 |     # decomposition
607 |     # true source image
608 |     s_true = np.hstack((np.reshape(reference_sources[j],
609 |                                    (nsampl, nchan),
610 |                                    order="F").transpose(),
611 |                         np.zeros((nchan, flen - 1))))
612 |     # spatial (or filtering) distortion
613 |     if saveg:
614 |         e_spat, Gj = _project_images(reference_sources[j, np.newaxis, :],
615 |                                      estimated_source, flen, Gj)
616 |     else:
617 |         e_spat = _project_images(reference_sources[j, np.newaxis, :],
618 |                                  estimated_source, flen)
619 |     e_spat = e_spat - s_true
620 |     # interference
621 |     if saveg:
622 |         e_interf, G = _project_images(reference_sources,
623 |                                       estimated_source, flen, G)
624 |     else:
625 |         e_interf = _project_images(reference_sources,
626 |                                    estimated_source, flen)
627 |     e_interf = e_interf - s_true - e_spat
628 |     # artifacts
629 |     e_artif = -s_true - e_spat - e_interf
630 |     e_artif[:, :nsampl] += estimated_source.transpose()
631 |     # return Gj and G only if they were passed in
632 |     if saveg:
633 |         return (s_true, e_spat, e_interf, e_artif, Gj, G)
634 |     else:
635 |         return (s_true, e_spat, e_interf, e_artif)
636 | 
637 | 
638 | def _project(reference_sources, estimated_source, flen):
639 |     """Least-squares projection of estimated source on the subspace spanned by
640 |     delayed versions of reference sources, with delays between 0 and flen-1
641 |     """
642 |     nsrc = reference_sources.shape[0]
643 |     nsampl = reference_sources.shape[1]
644 | 
645 |     # compute the coefficients of the least-squares problem via FFT
646 |     # zero padding and FFT of input data
647 |     reference_sources = np.hstack((reference_sources,
648 |                                    np.zeros((nsrc, flen - 1))))
649 |     estimated_source = np.hstack((estimated_source, np.zeros(flen - 1)))
650 |     n_fft = int(2**np.ceil(np.log2(nsampl + flen - 1.)))  # next power of two
651 |     sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1)
652 |     sef = scipy.fftpack.fft(estimated_source, n=n_fft)
653 |     # inner products between delayed versions of reference_sources
654 |     G = np.zeros((nsrc * flen, nsrc * flen))
655 |     for i in range(nsrc):
656 |         for j in range(nsrc):
657 |             ssf = sf[i] * np.conj(sf[j])
658 |             ssf = np.real(scipy.fftpack.ifft(ssf))
659 |             ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])),
660 |                           r=ssf[:flen])
661 |             G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss
662 |             G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T
663 |     # inner products between estimated_source and delayed versions of
664 |     # reference_sources
665 |     D = np.zeros(nsrc * flen)
666 |     for i in range(nsrc):
667 |         ssef = sf[i] * np.conj(sef)
668 |         ssef = np.real(scipy.fftpack.ifft(ssef))
669 |         D[i * flen: (i+1) * flen] = np.hstack((ssef[0], ssef[-1:-flen:-1]))
670 | 
671 |     # Computing projection
672 |     # Distortion filters
673 |     try:
674 |         C = np.linalg.solve(G, D).reshape(flen, nsrc, order='F')
675 |     except np.linalg.linalg.LinAlgError:
676 |         C = np.linalg.lstsq(G, D)[0].reshape(flen, nsrc, order='F')
677 |     # Filtering
678 |     sproj = np.zeros(nsampl + flen - 1)
679 |     for i in range(nsrc):
680 |         sproj += fftconvolve(C[:, i], reference_sources[i])[:nsampl + flen - 1]
681 |     return sproj
682 | 
683 | 
684 | def _project_images(reference_sources, estimated_source, flen, G=None):
685 |     """Least-squares projection of estimated source on the subspace spanned by
686 |     delayed versions of reference sources, with delays between 0 and flen-1.
687 |     Passing G as all zeros will populate the G matrix and return it so it can
688 |     be passed into the next call to avoid recomputing G (this only works
689 |     if not computing permutations).
690 |     """
691 |     nsrc = reference_sources.shape[0]
692 |     nsampl = reference_sources.shape[1]
693 |     nchan = reference_sources.shape[2]
694 |     reference_sources = np.reshape(np.transpose(reference_sources, (2, 0, 1)),
695 |                                    (nchan*nsrc, nsampl), order='F')
696 | 
697 |     # compute the coefficients of the least-squares problem via FFT
698 |     # zero padding and FFT of input data
699 |     reference_sources = np.hstack((reference_sources,
700 |                                    np.zeros((nchan*nsrc, flen - 1))))
701 |     estimated_source = \
702 |         np.hstack((estimated_source.transpose(), np.zeros((nchan, flen - 1))))
703 |     n_fft = int(2**np.ceil(np.log2(nsampl + flen - 1.)))
704 |     sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1)
705 |     sef = scipy.fftpack.fft(estimated_source, n=n_fft)
706 | 
707 |     # inner products between delayed versions of reference_sources
708 |     if G is None:
709 |         saveg = False
710 |         G = np.zeros((nchan * nsrc * flen, nchan * nsrc * flen))
711 |         for i in range(nchan * nsrc):
712 |             for j in range(i+1):
713 |                 ssf = sf[i] * np.conj(sf[j])
714 |                 ssf = np.real(scipy.fftpack.ifft(ssf))
715 |                 ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])),
716 |                               r=ssf[:flen])
717 |                 G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss
718 |                 G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T
719 |     else:  # avoid recomputing G (only works if no permutation is desired)
720 |         saveg = True  # return G
721 |         if np.all(G == 0):  # only compute G if passed as 0
722 |             G = np.zeros((nchan * nsrc * flen, nchan * nsrc * flen))
723 |             for i in range(nchan * nsrc):
724 |                 for j in range(i+1):
725 |                     ssf = sf[i] * np.conj(sf[j])
726 |                     ssf = np.real(scipy.fftpack.ifft(ssf))
727 |                     ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])),
728 |                                   r=ssf[:flen])
729 |                     G[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss
730 |                     G[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T
731 | 
732 |     # inner products between estimated_source and delayed versions of
733 |     # reference_sources
734 |     D = np.zeros((nchan * nsrc * flen, nchan))
735 |     for k in range(nchan * nsrc):
736 |         for i in range(nchan):
737 |             ssef = sf[k] * np.conj(sef[i])
738 |             ssef = np.real(scipy.fftpack.ifft(ssef))
739 |             D[k * flen: (k+1) * flen, i] = \
740 |                 np.hstack((ssef[0], ssef[-1:-flen:-1])).transpose()
741 | 
742 |     # Computing projection
743 |     # Distortion filters
744 |     try:
745 |         C = np.linalg.solve(G, D).reshape(flen, nchan*nsrc, nchan, order='F')
746 |     except np.linalg.linalg.LinAlgError:
747 |         C = np.linalg.lstsq(G, D)[0].reshape(flen, nchan*nsrc, nchan,
748 |                                              order='F')
749 |     # Filtering
750 |     sproj = np.zeros((nchan, nsampl + flen - 1))
751 |     for k in range(nchan * nsrc):
752 |         for i in range(nchan):
753 |             sproj[i] += fftconvolve(C[:, k, i].transpose(),
754 |                                     reference_sources[k])[:nsampl + flen - 1]
755 |     # return G only if it was passed in
756 |     if saveg:
757 |         return sproj, G
758 |     else:
759 |         return sproj
760 | 
761 | 
762 | def _bss_source_crit(s_true, e_spat, e_interf, e_artif):
763 |     """Measurement of the separation quality for a given source in terms of
764 |     filtered true source, interference and artifacts.
765 |     """
766 |     # energy ratios
767 |     s_filt = s_true + e_spat
768 |     sdr = _safe_db(np.sum(s_filt**2), np.sum((e_interf + e_artif)**2))
769 |     sir = _safe_db(np.sum(s_filt**2), np.sum(e_interf**2))
770 |     sar = _safe_db(np.sum((s_filt + e_interf)**2), np.sum(e_artif**2))
771 |     return (sdr, sir, sar)
772 | 
773 | 
774 | def _bss_image_crit(s_true, e_spat, e_interf, e_artif):
775 |     """Measurement of the separation quality for a given image in terms of
776 |     filtered true source, spatial error, interference and artifacts.
777 |     """
778 |     # energy ratios
779 |     sdr = _safe_db(np.sum(s_true**2), np.sum((e_spat+e_interf+e_artif)**2))
780 |     isr = _safe_db(np.sum(s_true**2), np.sum(e_spat**2))
781 |     sir = _safe_db(np.sum((s_true+e_spat)**2), np.sum(e_interf**2))
782 |     sar = _safe_db(np.sum((s_true+e_spat+e_interf)**2), np.sum(e_artif**2))
783 |     return (sdr, isr, sir, sar)
784 | 
785 | 
786 | def _safe_db(num, den):
787 |     """Properly handle the potential +Inf dB SIR, instead of raising a
788 |     RuntimeWarning. Only the denominator is checked because the numerator can
789 |     never be 0.
790 |     """
791 |     if den == 0:
792 |         return np.Inf
793 |     return 10 * np.log10(num / den)
794 | 
795 | 
796 | def evaluate(reference_sources, estimated_sources, **kwargs):
797 |     """Compute all metrics for the given reference and estimated signals.
798 |     NOTE: This will always compute :func:`mir_eval.separation.bss_eval_images`
799 |     for any valid input and will additionally compute
800 |     :func:`mir_eval.separation.bss_eval_sources` for valid input with fewer
801 |     than 3 dimensions.
802 |     Examples
803 |     --------
804 |     >>> # reference_sources[n] should be an ndarray of samples of the
805 |     >>> # n'th reference source
806 |     >>> # estimated_sources[n] should be the same for the n'th estimated source
807 |     >>> scores = mir_eval.separation.evaluate(reference_sources,
808 |     ...                                       estimated_sources)
809 |     Parameters
810 |     ----------
811 |     reference_sources : np.ndarray, shape=(nsrc, nsampl[, nchan])
812 |         matrix containing true sources
813 |     estimated_sources : np.ndarray, shape=(nsrc, nsampl[, nchan])
814 |         matrix containing estimated sources
815 |     kwargs
816 |         Additional keyword arguments which will be passed to the
817 |         appropriate metric or preprocessing functions.
818 |     Returns
819 |     -------
820 |     scores : dict
821 |         Dictionary of scores, where the key is the metric name (str) and
822 |         the value is the (float) score achieved.
823 |     """
824 |     from mir_eval import util  # assumed pip-installed; provides filter_kwargs
825 |     # Compute all the metrics
826 | 
827 |     sdr, isr, sir, sar, perm = util.filter_kwargs(
828 |         bss_eval_images,
829 |         reference_sources,
830 |         estimated_sources,
831 |         **kwargs
832 |     )
833 |     scores['Images - Source to Distortion'] = sdr.tolist()
834 |     scores['Images - Image to Spatial'] = isr.tolist()
835 |     scores['Images - Source to Interference'] = sir.tolist()
836 |     scores['Images - Source to Artifact'] = sar.tolist()
837 |     scores['Images - Source permutation'] = perm.tolist()
838 | 
839 |     sdr, isr, sir, sar, perm = util.filter_kwargs(
840 |         bss_eval_images_framewise,
841 |         reference_sources,
842 |         estimated_sources,
843 |         **kwargs
844 |     )
845 |     scores['Images Frames - Source to Distortion'] = sdr.tolist()
846 |     scores['Images Frames - Image to Spatial'] = isr.tolist()
847 |     scores['Images Frames - Source to Interference'] = sir.tolist()
848 |     scores['Images Frames - Source to Artifact'] = sar.tolist()
849 |     scores['Images Frames - Source permutation'] = perm.tolist()
850 | 
851 |     # Verify we can compute sources on this input
852 |     if reference_sources.ndim < 3 and estimated_sources.ndim < 3:
853 |         sdr, sir, sar, perm = util.filter_kwargs(
854 |             bss_eval_sources_framewise,
855 |             reference_sources,
856 |             estimated_sources,
857 |             **kwargs
858 |         )
859 |         scores['Sources Frames - Source to Distortion'] = sdr.tolist()
860 |         scores['Sources Frames - Source to Interference'] = sir.tolist()
861 |         scores['Sources Frames - Source to Artifact'] = sar.tolist()
862 |         scores['Sources Frames - Source permutation'] = perm.tolist()
863 | 
864 |         sdr, sir, sar, perm = util.filter_kwargs(
865 |             bss_eval_sources,
866 |             reference_sources,
867 |             estimated_sources,
868 |             **kwargs
869 |         )
870 |         scores['Sources - Source to Distortion'] = sdr.tolist()
871 |         scores['Sources - Source to Interference'] = sir.tolist()
872 |         scores['Sources - Source to Artifact'] = sar.tolist()
873 |         scores['Sources - Source permutation'] = perm.tolist()
874 | 
875 |     return scores
876 | 
877 | 
878 | # result = evaluate(reference, estimated)
879 | # print(result)
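880 | 
881 | 
882 | # A minimal smoke test, an illustrative sketch added alongside this vendored
883 | # module (not part of the original mir_eval code). It relies only on the
884 | # functions above: two random, hence non-silent, mono sources stand in for
885 | # real references, and the "estimates" are the references with a little of
886 | # each other mixed in, so SIR should come out around 20 dB per source.
887 | if __name__ == '__main__':
888 |     np.random.seed(0)
889 |     reference = np.random.randn(2, 44100)  # 2 sources, 1 second at 44.1 kHz
890 |     estimated = reference + 0.1 * reference[::-1]  # cross-contaminate sources
891 |     sdr, sir, sar, perm = bss_eval_sources(reference, estimated)
892 |     print('SDR: %s  SIR: %s  SAR: %s  perm: %s' % (sdr, sir, sar, perm))
--------------------------------------------------------------------------------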