├── .gitignore
├── CNNparameters.py
├── README
├── RNNparameters.py
├── bss_eval.py
├── build_hparams.py
├── constants.py
├── data_lib.py
├── embedding_summary.py
├── graph.py
├── helper.py
├── hparams_logs
    └── __init__.py
├── hyperparameters.py
├── inference.py
├── kmeans.py
├── loader.py
├── main.py
├── mir_bss_eval.py
├── preprocess_libri.py
├── preprocess_wsj0.py
├── pylogs
    └── last_experiment_num.log
├── rtl_loader.py
├── start_tensorboard.py
├── summaries.py
├── train.py
├── tuples.py
└── utilities.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_STORE
2 | __pycache__
3 | *.swp
4 | *.wav
5 | *.wv1
6 | *.wv2
7 | logs/*
8 | data/
9 |
--------------------------------------------------------------------------------
/CNNparameters.py:
--------------------------------------------------------------------------------
1 | """Model parameters for the CNN."""
2 | import constants
3 | import tuples
4 |
5 | class CNNParameters():
6 |     ##########################
7 |     ### NETWORK PARAMETERS ###
8 |     ##########################
9 |     filter_shape = (3, 3) # (Time, Freq)
10 |
11 |     dilation_heights = ([1, 2, 4] + [8, 16, 32] +
12 |                         [1, 2, 4] + [8, 16, 32] + [1])
13 |     dilation_widths = dilation_heights
14 |     n_c = 128
15 |     channels = ([n_c, n_c, n_c] + [n_c, n_c, n_c] +
16 |                 [n_c, n_c, n_c] + [n_c, n_c, n_c] + [-1]) # -1 replaced by embedding_size
17 |     use_residual = ([False, True, False] + [True, False, True] +
18 |                     [False, True, False] + [True, False, True] + [False])
19 |
20 |     padding = "SAME"
21 |
22 |     ###########################
23 |     ### TRAINING PARAMETERS ###
24 |     ###########################
25 |     max_steps = int(1e7)
26 |     batch_size = 8
27 |     learning_rate = 1e-3
28 |
29 |     ## Optimizer and LR decay functions
30 |     optimizer = tuples.Adam
31 |     use_exponential_decay = False # Use piecewise decay
32 |     use_batch_normalization = True
33 |
34 |     ## Piecewise decay parameters
35 |     boundaries = [10000, 50000, 100000]
36 |     rate_factors = [1.0, 0.50, 0.10, 0.01]
37 |
38 |
--------------------------------------------------------------------------------
/README:
--------------------------------------------------------------------------------
1 | Source Separation Project
2 |
3 | Run main.py to train the model. Use TensorBoard to view results.
4 |
5 | ### Data Setup ###
6 | # WSJ0
7 | 1) Download wsj0.tar.gz (https://catalog.ldc.upenn.edu/ldc93s6a)
8 | 2) tar -xf wsj0.tar.gz -C data/wsj0_sph/
9 | 3) Install sph2pipe (https://www.ldc.upenn.edu/language-resources/tools/sphere-conversion-tools)
10 | 4) python preprocess_wsj0.py (requires sph2pipe)
11 |
12 | # LibriSpeech
13 | 1) Download test-clean.tar.gz and train-clean-100.tar.gz (http://www.openslr.org/12/)
14 | 2) tar -xf *clean*.tar.gz -C data/
15 | 3) python preprocess_libri.py
16 |
17 | # RealTalkLibri (RTL)
18 | 1) Download rtl.tar.gz
19 | 2) tar -xf rtl.tar.gz -C data/
20 |
21 | ### Code Overview ###
22 | # Hyperparameters
23 | * hyperparameters - General Hyperparameters
24 | * CNNparameters - CNN specific parameters
25 | * RNNparameters - RNN specific parameters
26 |
27 | # Core
28 | * main - Construct the graph and train it
29 | * graph - Build the neural network model
30 | * train - Train the network
31 |
32 | # Data
33 | * loader.py - Make batches and prepare labels
34 | * rtl_loader.py - Make batches and prepare labels for RTL data
35 | * data_lib - transform between waveform, spectrogram, and neural network
36 |     input representations
37 | * bss_eval - Metric for calculating proxy_SDR
38 | * mir_bss_eval - Metric for calculating SDR
39 |
40 | # Model Results & Visualizations
41 | * summaries - image, audio, scalar summary plots
42 | * embedding_summary - Visualizing embeddings in PCA space (in TensorBoard)
43 |
44 | # Misc
45 | * kmeans - kmeans implementation
46 | * helper - various useful functions
47 | * utilities - save hparams, track experiment number
48 |
49 | # Versions:
50 | Python: 3.6.1
51 | TensorFlow: 1.6.0-dev20180116
52 |
CUDA: 9.0, V9.0.176
53 | cuDNN: 8.0
54 |
--------------------------------------------------------------------------------
/RNNparameters.py:
--------------------------------------------------------------------------------
1 | """Model parameters for the RNN."""
2 |
3 | import constants
4 | import tuples
5 |
6 |
7 | class RNNParameters():
8 |     ##########################
9 |     ### Network Parameters ###
10 |     ##########################
11 |     num_layers = 4
12 |     layer_size = 500 # For FWD and BWD, so hidden layer has 2 * layer_size units
13 |
14 |     ###########################
15 |     ### Training Parameters ###
16 |     ###########################
17 |     max_steps = int(1e7)
18 |     batch_size = 8
19 |     learning_rate = 1e-3
20 |     clip_gradient_norm = 200
21 |
22 |     ## Optimizer and LR decay functions
23 |     optimizer = tuples.RMSProp
24 |     use_exponential_decay = True # otherwise piecewise
25 |
26 |     ## Exponential Decay parameters
27 |     decay_steps = 2000
28 |     decay_rate = 0.95
29 |
--------------------------------------------------------------------------------
/bss_eval.py:
--------------------------------------------------------------------------------
1 | """TensorFlow implementation of Blind Source Separation metric*
2 |
3 | This code implements BSS using the method from II.B in the paper*
4 | which runs significantly faster than the time distortion one (III.B). We
5 | run this version during training to reduce computational load.
6 |
7 | * [#vincent2006performance] Emmanuel Vincent, Rémi Gribonval, and Cédric
8 | Févotte, "Performance measurement in blind audio source separation," IEEE
9 | Trans. on Audio, Speech and Language Processing, 14(4):1462-1469, 2006.
10 | """
11 |
12 | import tensorflow as tf
13 | import pdb as pdb
14 | import numpy as np
15 | import mir_bss_eval
16 |
17 |
18 | def tf_log10(x):
19 |     numerator = tf.log(x)
20 |     denominator = tf.log(tf.constant(10, dtype=numerator.dtype))
21 |     return numerator / denominator
22 |
23 | def tf_square_norm(x, axis=None):
24 |     return tf.square(tf.norm(x, axis=2 if axis is None else axis)) # axis=None keeps the old default (2)
25 |
26 | def compute_proxy_SDR(st, ei, ea):
27 |     return 10 * tf_log10(tf_square_norm(st) / tf_square_norm(ei + ea))
28 |
29 | def compute_targets(sources, source_estimates):
30 |     dotProd = tf.einsum("btw,btw->bt", sources, source_estimates)
31 |     normalizer = tf.square(tf.norm(sources, axis=2))
32 |     targets = sources * tf.expand_dims(dotProd / normalizer, -1)
33 |     return targets
34 |
35 | def compute_c(sources, source_estimates):
36 |     gram = tf.einsum("btw,buw->btu", sources, sources)
37 |     Ginv = tf.matrix_inverse(gram)
38 |     products = tf.einsum("btw,buw->btu", sources, source_estimates) # products[b,t,u] = <s_t, s_hat_u>
39 |     return tf.einsum("btu,buv->btv", Ginv, products) # column u: coefficients of s_hat_u projected onto span(s)
40 |
41 | def compute_components(sources, source_estimates):
42 |     source_targets = compute_targets(sources, source_estimates)
43 |
44 |     c = compute_c(sources, source_estimates)
45 |     subspace_projection = tf.einsum("btu,btw->buw", c, sources)
46 |     interferences = subspace_projection - source_targets
47 |
48 |     artifacts = source_estimates - subspace_projection
49 |
50 |     return source_targets, interferences, artifacts
51 |
52 | def eval_proxy_SDR(sources, source_estimates):
53 |     source_targets, interferences, artifacts = (
54 |         compute_components(sources, source_estimates))
55 |
56 |     proxy_SDR = compute_proxy_SDR(source_targets, interferences, artifacts)
57 |     return tf.reduce_mean(proxy_SDR)
58 |
--------------------------------------------------------------------------------
/build_hparams.py:
--------------------------------------------------------------------------------
1 | """Combine FLAGS, *MODEL*parameters.py, and hyperparameters together."""
2 |
3 | import
hyperparameters
4 | import RNNparameters
5 | import CNNparameters
6 | import pdb
7 |
8 | import constants
9 | import helper
10 |
11 | def transfer_variables(A, hparams):
12 |     """Transfer variables from model parameters to hparams"""
13 |     for (k,v) in vars(A).items():
14 |         if not k.startswith('__'):
15 |             hparams.add_hparam(k, v)
16 |     return hparams
17 |
18 | def find(listy, x):
19 |     """Return index of x in listy, and None if it doesn't exist"""
20 |     return listy.index(x) if x in listy else None
21 |
22 | def set_model_type(hparams, FLAGS):
23 |     """Set hparams.model to FLAGS.hparams[model] if it is specified there.
24 |     - We need to do this to load the correct hparams."""
25 |     if not FLAGS.hparams:
26 |         return
27 |
28 |     keyword = "model="
29 |     model_pos = find(FLAGS.hparams, keyword)
30 |     if model_pos is None:
31 |         return
32 |
33 |     model_name_pos = model_pos + len(keyword)
34 |     end_pos = FLAGS.hparams.find(",", model_name_pos) # Absolute index of the next comma, -1 if none
35 |     if end_pos == -1:
36 |         end_pos = len(FLAGS.hparams)
37 |
38 |     hparams.model = FLAGS.hparams[model_name_pos:end_pos]
39 |
40 | def add_model_parameters(hparams, FLAGS):
41 |     """Take the parameters in [MODEL]parameters.py and add them to hparams."""
42 |     set_model_type(hparams, FLAGS)
43 |
44 |     if helper.model_is_recurrent(hparams.model):
45 |         return transfer_variables(RNNparameters.RNNParameters, hparams)
46 |     elif helper.model_is_convolutional(hparams.model):
47 |         hparams = transfer_variables(CNNparameters.CNNParameters, hparams)
48 |         hparams.channels[-1] = hparams.embedding_size
49 |         return hparams
50 |     raise Exception("Invalid Model: %s" % hparams.model)
51 |
52 | def build_hparams(FLAGS):
53 |     """Build all hyperparameters associated with the core computation."""
54 |     hparams = add_model_parameters(hyperparameters.params, FLAGS)
55 |     hparams.training = True
56 |     if FLAGS.hparams:
57 |         hparams.parse(FLAGS.hparams)
58 |     if FLAGS.eval_model:
59 |         hparams.summary_frequency = 1
60 |         hparams.test_frequency = 1
61 |         hparams.save_frequency = 5
62 |         hparams.training = False
63 |
64 |     hparams.sdr_frequency = hparams.test_frequency * constants.AVG_SDR_ON_N_BATCHES
65 |     # See STFT scipy doc
66 |     hparams.waveform_size = (hparams.ntimebins - 1) * constants.ndiff
67 |
68 |     return hparams
69 |
70 |
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | """Constants for STFT and misc items."""
2 |
3 | import socket
4 |
5 |
6 | #######################
7 | ### STFT PARAMETERS ###
8 | #######################
9 | nperseg = 256
10 | noverlap = int(nperseg * 3/4)
11 | ndiff = nperseg - noverlap
12 | nfreqbins = int((nperseg / 2)) + 1
13 |
14 | #####################
15 | ### MISCELLANEOUS ###
16 | #####################
17 | Fs = 8000
18 | max_input_snr = 5
19 |
20 | kmeans_max_iters = 50 # Max iterations when running k-means
21 |
22 | if 'fattire' in socket.gethostname():
23 |     use_port = 54122 # Brian
24 | else:
25 |     use_port = 54621
26 |
27 | TRAIN_OP_NAME="train_op"
28 | ORACLE_SRC_EST_NAME="oracle_source_estimates"
29 | KMEANS_SRC_EST_NAME="kmeans_source_estimates"
30 |
31 | LAST_EXPERIMENT_NUM_FILE = "pylogs/last_experiment_num.log"
32 |
33 | AVG_SDR_ON_N_BATCHES = 50
34 |
--------------------------------------------------------------------------------
/data_lib.py:
--------------------------------------------------------------------------------
1 | """Transform between waveform space, spectrogram, and NN input."""
2 |
3 | import numpy as np
4 | import pdb
5 | import scipy.signal as signal
6 |
7 | import constants
8 |
9 |
10 | def stft(x):
11 |     """Compute the STFT."""
12 |     return signal.stft(x, constants.Fs, nperseg=constants.nperseg,
13 |                        noverlap=constants.noverlap)[2]
14 |
15 | def istft(X):
16 |     """Compute the iSTFT."""
17 |     return signal.istft(X, constants.Fs, nperseg=constants.nperseg,
18 |                         noverlap=constants.noverlap)[1]
19 |
20 | def apply_to_magnitude(X):
21 |     X = np.log(X + 1e-6)
22 |     return X
23 |
24 | def unapply_to_magnitude(X):
25 |     return np.exp(X) - 1e-6
26 |
27 | def wav_to_nn_representation(wav):
28 |     """Convert the waveform into the representation for the NN."""
29 |     XP = stft(wav)
30 |     X, P = np.float32(np.abs(XP)), np.angle(XP)
31 |     X = apply_to_magnitude(X)
32 |     return X.T, P.T
33 |
34 | def mag_phase_to_complex(X, phases):
35 |     return X * np.exp(phases * 1j)
36 |
37 | def nn_representation_to_wav_spect(X, phases):
38 |     """Given the NN input, return the waveform and spectrogram representation."""
39 |     X = unapply_to_magnitude(np.transpose(X, axes=[0, 2, 1]))
40 |     P = np.transpose(phases, axes=[0, 2, 1])
41 |     XP = mag_phase_to_complex(X, P)
42 |
43 |     waveforms = []
44 |     batch_size = X.shape[0]
45 |     for b in range(batch_size):
46 |         waveforms.append(istft(XP[b, :, :]))
47 |
48 |     return np.array(waveforms, dtype=np.float32), X
49 |
--------------------------------------------------------------------------------
/embedding_summary.py:
--------------------------------------------------------------------------------
1 | """Functions for setting up embedding visualization summary."""
2 |
3 | import numpy as np
4 | import os
5 | import pdb
6 | import tensorflow as tf
7 | from tensorflow.contrib.tensorboard.plugins import projector
8 |
9 | import helper
10 |
11 |
12 | class Embedding():
13 |     def __init__(self, embedding_config, embedding_assign, show_embeddings=True):
14 |         self.show_embeddings = show_embeddings
15 |         self.embedding_config = embedding_config
16 |         self.embedding_assign = embedding_assign
17 |
18 |     def get_assign_op(self):
19 |         return self.embedding_assign
20 |
21 |     def visualize_embeddings(self, train_writer):
22 |         if self.show_embeddings:
23 |             projector.visualize_embeddings(train_writer, self.embedding_config)
24 |
25 | def handle_embedding(hparams, embeddings):
26 |     if not hparams.show_embeddings:
27 |         # Embedding of noops
28 |         return Embedding(tf.constant(0), tf.constant(0), False)
29 |
30 |     with tf.name_scope("Embedding"):
31 |         embedding_act = embeddings[0, :, :]
32 |
33 |         init_value = np.zeros(embedding_act.shape, dtype=np.float32)
34 |         embedding_var = tf.Variable(init_value, name="Variable")
35 |         embedding_assign = tf.assign(embedding_var, embedding_act, name="Assign")
36 |
37 |     # Embedding
38 |     embedding_config = projector.ProjectorConfig()
39 |     embedding_obj = embedding_config.embeddings.add()
40 |     embedding_obj.tensor_name = embedding_var.name
41 |
42 |     embedding_obj.metadata_path = os.path.abspath('.') + '/' + hparams.logdir + '/train/metadata.tsv'
43 |     embedding_info = Embedding(embedding_config, embedding_assign)
44 |     return embedding_info
45 |
46 | color_to_label = {"blue": 0,
47 |                   "yellow": 1,
48 |                   "red": 2,
49 |                   "purple": 3,
50 |                   "pink": 4,
51 |                   "grey": 5,
52 |                   "turqoise": 6,
53 |                   "blue-grey": 7,
54 |                   "green": 8,
55 |                   "orange": 9}
56 |
57 | def get_label(hparams, thresholded, max_idx, i):
58 |     assert hparams.num_targets == 2, "More colors unsupported yet"
59 |     # Need these so we can make thresholded values grey
60 |     if i == 0:
61 |         return color_to_label["yellow"]
62 |     if i == 1:
63 |         return color_to_label["purple"]
64 |     if i == 2:
65 |         return color_to_label["pink"]
66 |
67 |     if thresholded == 0:
68 |         return color_to_label["grey"]
69 |     if max_idx == 0:
70 |         return color_to_label["red"] # Spk A
71 |     return color_to_label["blue"] # Spk B
72 |
73 | def write_tsv(hparams, X_mixtures, masks):
74 |     if not hparams.show_embeddings:
75 |         return
76 |
77 |     X_mixtures = helper.np_collapse_freq_into_time(X_mixtures)
78 |     masks = helper.np_collapse_freq_into_time(masks)
79 |     with open(hparams.logdir + "/train/metadata.tsv", "w") as f:
80 |         # Write data from first batch mixture
81 |         X_mixture = X_mixtures[0, :]
82 |         mask = masks[0, :, :]
83 |
84 |         threshold_mask = helper.np_get_threshold_mask(hparams, X_mixture)
85 |         max_idx = np.argmax(mask, axis=1)
86 |         for i in range(masks.shape[1]):
87 |             label = get_label(hparams, threshold_mask[i], max_idx[i], i)
88 |             f.write("%d\n" % label)
89 |
90 |
--------------------------------------------------------------------------------
/graph.py:
--------------------------------------------------------------------------------
1 | """Build the entire TensorFlow graph: network, loss, optimizer, summaries."""
2 |
3 | import tensorflow as tf
4 | import pdb
5 | import numpy as np
6 |
7 | import summaries
8 | import embedding_summary
9 | import helper
10 | import constants
11 |
12 |
13 | ##############
14 | ### INPUTS ###
15 | ##############
16 | def get_input_shapes(hparams):
17 |     X_mixtures_shape = [hparams.batch_size, hparams.ntimebins, constants.nfreqbins]
18 |     phases_shape = [hparams.batch_size, hparams.ntimebins, constants.nfreqbins]
19 |     oracle_mask_shape = [hparams.batch_size, hparams.ntimebins, constants.nfreqbins,
20 |                          hparams.num_targets]
21 |     sources_shape = [hparams.batch_size, hparams.num_targets, hparams.waveform_size]
22 |     X_sources_shape = [hparams.batch_size, hparams.ntimebins, constants.nfreqbins,
23 |                        hparams.num_targets]
24 |
25 |     return (X_mixtures_shape, phases_shape, oracle_mask_shape, sources_shape,
26 |             X_sources_shape)
27 |
28 | def build_input_placeholders(hparams):
29 |     place_holders = []
30 |     for shape in get_input_shapes(hparams):
31 |         place_holders.append(tf.placeholder(tf.float32, shape=shape))
32 |     return place_holders
33 |
34 |
35 | #################
36 | ### RNN MODEL ###
37 | #################
38 | def make_multi_rnn_cell(hparams):
39 |     cells = []
40 |     for _ in range(hparams.num_layers):
41 |         cells.append(tf.contrib.rnn.BasicLSTMCell(hparams.layer_size))
42 |     return tf.contrib.rnn.MultiRNNCell(cells)
43 |
44 | def make_rnn_net(hparams, X_mixtures):
45 |     both_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
46 |         make_multi_rnn_cell(hparams),
47 |         make_multi_rnn_cell(hparams),
48 |         X_mixtures,
49 |         dtype=tf.float32)
50 |     outputs = tf.concat(both_outputs, 2)
51 |
52 |     outputs = helper.collapse_time_into_batch(outputs)
53 |     output_size = constants.nfreqbins * hparams.embedding_size
54 |     embeddings = tf.contrib.layers.linear(outputs, output_size)
55 |     embeddings = tf.reshape(embeddings,
56 |                             [embeddings.shape[0].value, -1, hparams.embedding_size])
57 |
58 |     return helper.uncollapse_time_from_batch(hparams, embeddings)
59 |
60 |
61 | #################
62 | ### CNN MODEL ###
63 | #################
64 | def conv2d(hparams, x, i):
65 |     return tf.layers.conv2d(x,
66 |                             hparams.channels[i],
67 |                             hparams.filter_shape,
68 |                             use_bias=True,
69 |                             dilation_rate=[hparams.dilation_heights[i], hparams.dilation_widths[i]],
70 |                             padding=hparams.padding,
71 |                             kernel_initializer=None)
72 |
73 | def make_cnn_net(hparams, X_mixtures):
74 |     """Make the Dilated convolutional architecture."""
75 |     num_layers = len(hparams.channels)
76 |     prev_layer = 0.
77 |     layer = tf.expand_dims(X_mixtures, -1)
78 |     for i in range(num_layers):
79 |         layer = conv2d(hparams, layer, i)
80 |         if i == num_layers - 1: # Last layer
81 |             break
82 |
83 |         if hparams.use_batch_normalization:
84 |             layer = tf.layers.batch_normalization(layer, axis=3, training=True)
85 |         if hparams.use_residual[i] and prev_layer != 0:
86 |             layer += prev_layer
87 |
88 |         layer = tf.nn.relu(layer)
89 |
90 |         if hparams.use_residual[i]:
91 |             prev_layer = layer
92 |
93 |     return layer
94 |
95 | def print_num_parameters():
96 |     """Print the number of parameters in the model."""
97 |     num = np.sum([np.prod(v.shape.as_list()) for v in tf.trainable_variables()])
98 |     print ("Model has %d parameters" % num)
99 |
100 | def make_net(hparams, X_mixtures):
101 |     if helper.model_is_convolutional(hparams.model):
102 |         embeddings = make_cnn_net(hparams, X_mixtures)
103 |     elif helper.model_is_recurrent(hparams.model):
104 |         embeddings = make_rnn_net(hparams, X_mixtures)
105 |     else:
106 |         raise Exception("Unknown model: %s" % hparams.model)
107 |     print_num_parameters()
108 |
109 |     return embeddings
110 |
111 |
112 | ############
113 | ### LOSS ###
114 | ############
115 | def mse_loss(hparams, attractors, X_mixtures, embeddings, X_sources):
116 |     mask_estimate = tf.einsum("bik,bck->bic", embeddings, attractors)
117 |     mask_estimate = tf.nn.softmax(mask_estimate)
118 |     X_source_estimates = tf.expand_dims(X_mixtures, -1) * mask_estimate
119 |
120 |     mse_loss = tf.reduce_mean(tf.square(X_sources - X_source_estimates))
121 |     return mse_loss
122 |
123 | #################
124 | ### OPTIMIZER ###
125 | #################
126 | def build_optimizer(hparams, loss):
127 |     if not hparams.training:
128 |         tf.add_to_collection(constants.TRAIN_OP_NAME, tf.constant(0, tf.float32))
129 |         return [tf.constant(0, tf.float32)] * 2
130 |     global_step = tf.Variable(0, trainable=False)
131 |
132 |     if hparams.use_exponential_decay:
133 |         learning_rate = tf.train.exponential_decay(hparams.learning_rate, global_step,
134 |                             hparams.decay_steps, hparams.decay_rate, staircase=True)
135 |     else:
136 |         values = list(hparams.learning_rate * np.array(hparams.rate_factors))
137 |         learning_rate = tf.train.piecewise_constant(global_step, hparams.boundaries, values)
138 |
139 |     optimizer = hparams.optimizer.func(learning_rate)
140 |     variables = tf.trainable_variables()
141 |     gradients = tf.gradients(loss, variables)
142 |
143 |     if helper.model_is_recurrent(hparams.model) and hparams.clip_gradient_norm > -1:
144 |         gradients = [None if gradient is None else
145 |                      tf.clip_by_norm(gradient, hparams.clip_gradient_norm)
146 |                      for gradient in gradients]
147 |
148 |     train_op = optimizer.apply_gradients(zip(gradients, variables),
149 |                                          global_step=global_step)
150 |     tf.add_to_collection(constants.TRAIN_OP_NAME, train_op)
151 |
152 |     return global_step, learning_rate
153 |
154 |
155 | #############
156 | ### GRAPH ###
157 | #############
158 | def build_inference_graph(hparams):
159 |     X_mixtures_shape = [hparams.batch_size, hparams.ntimebins, constants.nfreqbins]
160 |     phases_shape = [hparams.batch_size, hparams.ntimebins, constants.nfreqbins]
161 |
162 |     X_mixtures = tf.placeholder(tf.float32, shape=X_mixtures_shape)
163 |     phases = tf.placeholder(tf.float32, shape=phases_shape)
164 |
165 |     threshold_mask = helper.get_threshold_mask(hparams, X_mixtures)
166 |     embeddings = make_net(hparams, X_mixtures)
167 |     embeddings = tf.nn.l2_normalize(embeddings, axis=embeddings.shape.ndims - 1)
168 |
169 |     X_mixtures_rs = helper.collapse_freq_into_time(X_mixtures)
170 |     embeddings = helper.collapse_freq_into_time(embeddings)
171 |     threshold_mask = helper.collapse_freq_into_time(threshold_mask)
172 |
173 |     inference_summaries = summaries.setup_inference_summary(hparams, threshold_mask,
174 |         X_mixtures_rs, phases, embeddings)
175 |
176 |     inference_summaries = tf.summary.merge(inference_summaries)
177 |     return X_mixtures, phases, inference_summaries
178 |
179 | def build_train_graph(hparams):
180 |     inputs = build_input_placeholders(hparams)
181 |     X_mixtures, phases, oracle_mask, sources, X_sources = inputs
182 |     threshold_mask = helper.get_threshold_mask(hparams, X_mixtures)
183 |
184 |     embeddings = make_net(hparams, X_mixtures)
185 |     embeddings = tf.nn.l2_normalize(embeddings, axis=embeddings.shape.ndims - 1)
186 |
187 |     # Put time-frequency in the same axis for clustering embeddings
188 |     X_mixtures = helper.collapse_freq_into_time(X_mixtures)
189 |     oracle_mask = helper.collapse_freq_into_time(oracle_mask)
190 |     X_sources = helper.collapse_freq_into_time(X_sources)
191 |     embeddings = helper.collapse_freq_into_time(embeddings)
192 |     threshold_mask = helper.collapse_freq_into_time(threshold_mask)
193 |
194 |     attractors = helper.get_attractors(hparams, threshold_mask, embeddings, oracle_mask)
195 |     loss = mse_loss(hparams, attractors, X_mixtures, embeddings, X_sources)
196 |     global_step, learning_rate = build_optimizer(hparams, loss)
197 |
198 |     loss_summary = tf.summary.scalar("Loss", loss)
199 |
200 |     # Summaries
201 |     learning_rate_summary = tf.summary.scalar("Learning Rate", learning_rate)
202 |     extra_data_summary = [learning_rate_summary]
203 |
204 |     train_summary, test_summary, oracle_summary, SDR_summary = (
205 |         summaries.create_summaries(hparams, threshold_mask, attractors,
206 |             X_mixtures, phases, oracle_mask, sources, X_sources, embeddings))
207 |
208 |     loss_summaries = [loss_summary]
209 |     train_summary = tf.summary.merge(loss_summaries + extra_data_summary + train_summary)
210 |     test_summary = tf.summary.merge(loss_summaries + test_summary)
211 |     oracle_summary = tf.summary.merge(oracle_summary)
212 |     SDR_summary = tf.summary.merge(SDR_summary)
213 |
214 |     embedding_info = embedding_summary.handle_embedding(hparams, embeddings)
215 |
216 |     return (inputs, embedding_info, loss, loss_summary,
217 |             [train_summary, test_summary, oracle_summary, SDR_summary])
218 |
--------------------------------------------------------------------------------
/helper.py:
--------------------------------------------------------------------------------
1 | """Various non-core functions."""
2 |
3 | import tensorflow as tf
4 | import numpy as np
5 | import pdb
6 |
7 | import constants
8 |
9 |
10 | ####################
11 | ### THRESHOLDING ###
12 | ####################
13 | def get_threshold_mask(hparams, x):
14 |     """Threshold the mixtures to 1 or 0 for each TF bin.
15 |     Input:
16 |         X_mixtures: B x T x F
17 |     Output:
18 |         X_mixtures: B x T x F with values in {0,1}
19 |     """
20 |
21 |     axis = list(range(1, x.shape.ndims))
22 |     min_val = tf.reduce_min(x, axis=axis, keepdims=True)
23 |     max_val = tf.reduce_max(x, axis=axis, keepdims=True)
24 |     thresh = min_val + hparams.threshold_factor * (max_val - min_val)
25 |     cond = tf.less(x, thresh)
26 |     return tf.where(cond, tf.zeros(tf.shape(x)), tf.ones(tf.shape(x)))
27 |
28 | def np_get_threshold_mask(hparams, x):
29 |     min_val = np.min(x)
30 |     max_val = np.max(x)
31 |     thresh = min_val + hparams.threshold_factor * (max_val - min_val)
32 |     return (x >= thresh).astype(np.int32) # >= matches get_threshold_mask (only x < thresh -> 0)
33 |
34 | def get_attractors(hparams, threshold_mask, embeddings, oracle_mask):
35 |     """Calculate the attractors of the embeddings.
36 |
37 |     Input:
38 |         threshold_mask: BxN - Binary Mask indicating non-thresholded TF bins
39 |         embeddings: BxNxK - All N K-dimensional embeddings
40 |         oracle_mask: BxNxC - Binary Mask indicating classification of each TF bin
41 |
42 |     Output:
43 |         attractors: BxCxK - C attractor points in the embedding space
44 |     """
45 |
46 |     threshold_mask = tf.expand_dims(threshold_mask, -1) * oracle_mask
47 |     bin_count = tf.reduce_sum(threshold_mask, axis=1) # Count of non-threshold TF bins
48 |     bin_count = tf.expand_dims(bin_count, -1)
49 |
50 |     unnormalized_attractors = tf.einsum("bik,bic->bck", embeddings, threshold_mask)
51 |     attractors = tf.divide(unnormalized_attractors, bin_count + 1e-6) # Don't divide by 0
52 |
53 |     return attractors
54 |
55 | ############
56 | ### MISC ###
57 | ############
58 | def np_collapse_freq_into_time(x):
59 |     """Collapse the freq and time dimensions."""
60 |     if x.ndim == 4:
61 |         return np.reshape(x, [x.shape[0], x.shape[1] * x.shape[2], -1])
62 |     return np.reshape(x, [x.shape[0], x.shape[1] * x.shape[2]])
63 |
64 | def collapse_freq_into_time(x):
65 |     """Collapse the freq and time dimensions."""
66 |     if x.shape.ndims == 4:
67 |         return tf.reshape(x, [x.shape[0], x.shape[1] * x.shape[2], -1])
68 |     return tf.reshape(x, [x.shape[0], x.shape[1] * x.shape[2]])
69 |
70 | def uncollapse_freq_into_time(hparams, x):
71 |     """Separate the freq and time dimensions (inverse of collapse_freq_into_time)."""
72 |     if x.shape.ndims == 3:
73 |         return tf.reshape(x, [x.shape[0], hparams.ntimebins, constants.nfreqbins, -1])
74 |     return tf.reshape(x, [x.shape[0], hparams.ntimebins, constants.nfreqbins])
75 |
76 | def collapse_time_into_batch(x):
77 |     """Collapse the batch and time dimensions."""
78 |     return tf.reshape(x, [-1] + x.shape.as_list()[2:])
79 |
80 | def uncollapse_time_from_batch(hparams, x):
81 |     """Separate the batch and time dimensions."""
82 |     return tf.reshape(x, [hparams.batch_size, -1] + x.shape.as_list()[1:])
83 |
84 | def model_is_recurrent(model):
85 |     return "lstm" in model.lower()
86 |
87 | def model_is_convolutional(model):
88 |     return "cnn" in model.lower()
89 |
90 | def get_oracle_waveform_savedir(hparams):
91 |     return "ORACLE_%s" % hparams.data_source
92 |
93 | def get_kmeans_waveform_savedir(hparams):
94 |     if model_is_convolutional(hparams.model):
95 |         name = "%s_%d_c%d_%s_%d" % (hparams.model, hparams.filter_shape[1],
96 |             hparams.channels[0], hparams.data_source, hparams.ntimebins)
97 |     else:
98 |         name = "%s_%s_%d" % (hparams.model, hparams.data_source, hparams.ntimebins)
99 |
100 |     if hparams.add_white_noise:
101 |         name = "white_noise_" + name
102 |     return name
103 |
104 | def flush(*args):
105 |     for arg in args:
106 |         arg.flush()
107 |
108 |
--------------------------------------------------------------------------------
/hparams_logs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShariqM/source_separation/3ecd8b26a7bbcb174eb6f133b7d5e31b181875ee/hparams_logs/__init__.py
--------------------------------------------------------------------------------
/hyperparameters.py:
--------------------------------------------------------------------------------
1 | """Hyperparameters for the entire graph, model parameters are in CNNparameters.py and RNNparameters.py"""
2 |
3 | import tensorflow as tf
4 |
5 |
6 | params = tf.contrib.training.HParams(
7 |     model = "CNN", # "BLSTM" or "CNN"
8 |     data_source = "RTL", # "WSJ0", "LIBRI", or "RTL"
9 |
10 |     num_targets = 2, # 2 speakers
11 |
12 |     threshold_factor = 0.6, # Threshold TF bins below min + 0.6 * (max - min) of each mixture
13 |     embedding_size = 20,
14 |     ntimebins = 400,
15 |
16 |     #####################
17 |     ### MISCELLANEOUS ###
18 |     #####################
19 |     add_white_noise = False,
20 |     run_inference_test = False, # Test the model on long speech data
21 |     save_estimate_waveforms = True,
22 |     save_oracle_waveforms = False,
23 |     show_embeddings = False,
24 |     summary_frequency = 50,
25 |     save_frequency = 100,
26 |     test_frequency = 50)
27 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | """Only run inference, no training, loss or labels."""
2 |
3 | import tensorflow as tf
4 | import pdb
5 | import time
6 |
7 | import loader
8 |
9 | def setup_inference(hparams, sess, init):
10 |     sess.run(init)
11 |     inference_writer = tf.summary.FileWriter(hparams.logdir + "/inference", sess.graph)
12 |
13 |     if hparams.load_experiment_num:
14 |         print ("Loading: %s" % hparams.loaddir)
15 |         saver = tf.train.Saver()
16 |         saver.restore(sess, tf.train.latest_checkpoint(hparams.loaddir))
17 |
18 |     return inference_writer
19 |
20 | def run_inference(hparams, X_mixtures_ph, phases_ph, inference_summaries):
21 |     init = tf.global_variables_initializer()
22 |     with tf.Session() as sess:
23 |         inference_writer = setup_inference(hparams, sess, init)
24 |
25 |         for step in range(hparams.max_steps):
26 |             start = time.time()
27 |             X_mixtures, phases = loader.get_inference_data(hparams)
28 |             feed_dict = {X_mixtures_ph: X_mixtures, phases_ph: phases}
29 |
30 |             raw_summary = sess.run(inference_summaries, feed_dict=feed_dict)
31 |
32 |             if step % 25 == 0:
33 |                 # Write only every 25 steps b/c TB seems to mix up audio segments...
34 |                 inference_writer.add_summary(raw_summary, step)
35 |             print ("\t%d) Elapsed = %.3f secs" % (step, (time.time() - start)))
36 |
--------------------------------------------------------------------------------
/kmeans.py:
--------------------------------------------------------------------------------
1 | """K-Means implementation."""
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 |
6 | import constants
7 |
8 |
9 | def init_centers(embeddings, num_targets):
10 |     """Pick the initial cluster positions for these embeddings.
11 |
12 |     Args:
13 |         embeddings: is a NxK matrix with K dimensional embeddings.
14 |
15 |     Use a greedy (argmax instead of sampling) variant of the k-means++ initialization algorithm.
16 |
17 |     https://en.wikipedia.org/wiki/K-means%2B%2B
18 |
19 |     1) Pick a random data point to be the first cluster point.
20 |     2) Mark each data point, x, with a distance, D(x) which is equal to
21 |        the distance from x to the nearest cluster (minimum distance)
22 |     3) The next cluster is argmax_x (D(x))
23 |
24 |     """
25 |     n_embeddings, embedding_size = embeddings.shape
26 |
27 |     centers = np.zeros((num_targets, embedding_size))
28 |     distances = np.zeros((num_targets, n_embeddings))
29 |
30 |     rand_idx = np.random.randint(embeddings.shape[0])
31 |     centers[0, :] = embeddings[rand_idx, :]
32 |
33 |     for i in range(1, num_targets):
34 |         distances[i-1, :] = np.sum(np.square(centers[i-1, :] - embeddings), axis=1) # Row j: squared distance to center j
35 |         min_distances = np.min(distances[:i, :], axis=0) # D(x): distance to the nearest center chosen so far
36 |         idx = np.argmax(min_distances)
37 |         centers[i, :] = embeddings[idx, :]
38 |
39 |     return centers
40 |
41 | def assign(centers, embeddings):
42 |     """Assign embeddings to the closest cluster."""
43 |     num_targets = centers.shape[0]
44 |     embeddings = np.expand_dims(embeddings, axis=2)
45 |     centers = np.expand_dims(centers.T, axis=0)
46 |
47 |     distances = np.linalg.norm(embeddings - centers, axis=1)
48 |     assignments = np.argmin(distances, axis=1)
49 |     return np.eye(num_targets)[assignments]
50 |
51 | def update_centers(assignments, embeddings):
52 |     """Given the assignments of embeddings update the centers."""
53 |     normalizer = np.sum(assignments, axis=0, keepdims=True).T
54 |     centers = np.dot(assignments.T, embeddings) / (normalizer + 1e-6)
55 |     return centers
56 |
57 | def get_centers_impl(embeddings, threshold_mask, num_targets):
58 |     """Run K-means on the embeddings and output the K-means centers.
59 |
60 |     Arguments:
61 |         embeddings: BxNxK - Embeddings for each batch (N index the TF bins)
62 |         threshold_mask: BxN - Binary matrix indicating if this TF-bin was not thresholded
63 |         num_targets: integer - hparams.num_targets
64 |     """
65 |
66 |     batch_size, n_embeddings, embedding_size = embeddings.shape
67 |     kmeans_centers = np.zeros((batch_size, num_targets, embedding_size),
68 |                               dtype=np.float32)
69 |
70 |     converged_at = []
71 |     for b in range(batch_size):
72 |         indices = np.where(threshold_mask[b, :] == 1.0)[0]
73 |         # Only use embeddings that passed the threshold
74 |         b_embeddings = embeddings[b, indices, :]
75 |         centers = init_centers(b_embeddings, num_targets)
76 |
77 |         converged = False
78 |         for i in range(constants.kmeans_max_iters):
79 |             assignments = assign(centers, b_embeddings)
80 |
81 |             new_centers = update_centers(assignments, b_embeddings)
82 |             if np.allclose(new_centers, centers):
83 |                 converged = True
84 |                 break
85 |             centers = new_centers
86 |         converged_at.append(i)
87 |
88 |         kmeans_centers[b, :, :] = centers
89 |     print ("K-Means converged in %.2f iters on average" % (np.mean(converged_at)))
90 |
91 |     return kmeans_centers
92 |
93 | def get_centers(hparams, embeddings, threshold_mask):
94 |     kmeans_centers = tf.py_func(get_centers_impl,
95 |         [embeddings, threshold_mask, hparams.num_targets], tf.float32)
96 |
97 |     kmeans_centers.set_shape((embeddings.shape[0].value, hparams.num_targets,
98 |                               hparams.embedding_size))
99 |
100 |     return kmeans_centers
101 |
--------------------------------------------------------------------------------
/loader.py:
--------------------------------------------------------------------------------
1 | """Create batches of data for the network to train on."""
2 |
3 | import tensorflow as tf
4 | import numpy as np
5 | import glob
6 | import pdb
7 | from scipy.io.wavfile import read
8 | from functools import partial
9 |
10 | import rtl_loader
11 | import data_lib
12 | import constants
13 |
14 |
15 | def make_X_and_mask(hparams, wavs):
16 |     X_sources = np.zeros([hparams.ntimebins, constants.nfreqbins, len(wavs)])
17 |     for (i, wav) in enumerate(wavs):
18 |         X_sources[:, :, i], _ = data_lib.wav_to_nn_representation(wav)
19 |
20 |     max_idx = np.argmax(X_sources, axis=2)
21 |     mask = (max_idx[...,None] == np.arange(hparams.num_targets)).astype(int)
22 |
23 |     return X_sources, mask
24 |
25 | ####################
26 | ### MAKE MIXTURE ###
27 | ####################
28 | def snr_to_weight(snr):
29 |     return 10 ** (snr / 20)
30 |
31 | def prepare_wavs(hparams, wav_i, wav_j):
32 |     """Apply a gain of +snr dB to wav_i and -snr dB to wav_j, snr ~ U[0, max_input_snr)."""
33 |     snr = constants.max_input_snr * np.random.random()
34 |     wav_i = snr_to_weight(snr) * wav_i
35 |     wav_j = snr_to_weight(-snr) * wav_j
36 |     return [wav_i, wav_j]
37 |
38 | def simulate_mixture(hparams, wav_i, wav_j):
39 |     wavs = prepare_wavs(hparams, wav_i, wav_j)
40 |
41 |     mixture = wavs[0] + wavs[1]
42 |     if hparams.add_white_noise:
43 |         l = constants.Fs // 4 # 0.25 s of samples; must be an int for slicing and randn()
44 |         start = int(len(mixture) / 2)
45 |         mixture[start: start + l] += 0.5 * np.max(np.abs(mixture)) * np.random.randn(l)
46 |
47 |     X_mixture, phase = data_lib.wav_to_nn_representation(mixture)
48 |     return X_mixture, phase, wavs
49 |
50 |
51 | ######################
52 | ### PICK WAV FILES ###
53 | ######################
54 | def get_wav(hparams, files):
55 |     """Return a waveform that is long enough using files."""
56 |     wav = np.array([])
57 |     while (len(wav) <= hparams.waveform_size): # [<=] so randint() (below) works
58 |         wav_file = np.random.choice(files)
59 |         Fs, x = read(wav_file)
60 |         assert Fs == constants.Fs
61 |         wav = np.concatenate((wav, x))
62 |
63 |     start = np.random.randint(len(wav) - hparams.waveform_size)
64 |     return wav[start : start + hparams.waveform_size]
65 |
66 | def pick_read_wavs(hparams, spk_folders):
67 |     spk_folder_i = np.random.choice(spk_folders)
68 |     spk_folder_j = np.random.choice(spk_folders)
69 |
70 |     while spk_folder_i == spk_folder_j:
71 |
spk_folder_j = np.random.choice(spk_folders) 72 | 73 | wav_ext = "wv1" if "WSJ0" in hparams.data_source else "wav" 74 | wav_i = get_wav(hparams, glob.glob(spk_folder_i + "/*/*.%s" % wav_ext)) 75 | wav_j = get_wav(hparams, glob.glob(spk_folder_j + "/*/*.%s" % wav_ext)) 76 | return wav_i, wav_j 77 | 78 | 79 | ######################### 80 | ### DATA CONSTRUCTION ### 81 | ######################### 82 | def mp_worker(spk_folders, hparams): 83 | np.random.seed() # b/c Multiprocessing units get same seed 84 | wav_i, wav_j = pick_read_wavs(hparams, spk_folders) 85 | 86 | X_mixture, phase, wavs = simulate_mixture(hparams, wav_i, wav_j) 87 | X_sources, mask = make_X_and_mask(hparams, wavs) 88 | 89 | sources = np.zeros((hparams.num_targets, hparams.waveform_size)) 90 | sources[0, :] = wavs[0] 91 | sources[1, :] = wavs[1] 92 | 93 | return X_mixture, phase, mask, sources, X_sources 94 | 95 | def get_inference_data(hparams): 96 | files = glob.glob("data/ours/custom/*2*.wav") 97 | X_mixtures, phases = [], [] 98 | for i in range(hparams.batch_size): 99 | mixture = get_wav(hparams, files) 100 | X_mixture, phase = data_lib.wav_to_nn_representation(mixture) 101 | X_mixtures.append(X_mixture) 102 | phases.append(phase) 103 | 104 | return np.stack(X_mixtures), np.stack(phases) 105 | 106 | def make_data(pool, hparams, test=False): 107 | if hparams.data_source == "RTL": 108 | return rtl_loader.make_data(pool, hparams, test) 109 | 110 | if hparams.data_source == "WSJ0": 111 | use_dir = "test" if test else "train" 112 | spk_path = 'data/wsj0/%s/*' % use_dir 113 | spk_folders = glob.glob(spk_path) 114 | elif hparams.data_source == "LIBRI": 115 | use_dir = "test-clean" if test else "train-clean-100" 116 | spk_path = 'data/LibriSpeech/%s/*' % use_dir 117 | spk_folders = glob.glob(spk_path) 118 | else: 119 | raise Exception("Invalid data source: %s", hparams.data_source) 120 | 121 | r = pool.map_async(partial(mp_worker, hparams=hparams), [spk_folders] * hparams.batch_size) 122 | results = 
r.get() 123 | 124 | stacked_results = [] 125 | for result in zip(*results): 126 | stacked_results.append(np.stack(result)) 127 | 128 | return stacked_results 129 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """Build the graph and train it.""" 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | from optparse import OptionParser 6 | import argparse 7 | import pdb 8 | 9 | import graph 10 | import train 11 | import inference 12 | import utilities 13 | import build_hparams 14 | 15 | 16 | parser = argparse.ArgumentParser(description='Train the Keller Net..') 17 | parser.add_argument("-l", type=int, dest="load_experiment_num", default=0, 18 | help="Load weights from Model, don't load on 0") 19 | parser.add_argument('--hparams', type=str, 20 | help='Comma separated list of "name=value" pairs.') 21 | parser.add_argument('-e', action='store_true', default=False, dest='eval_model', 22 | help='Run summary and test alot to evaluate results.') 23 | FLAGS = parser.parse_args() 24 | 25 | def initialize(): 26 | """Build hyperparameters, setup loading/saving""" 27 | hparams = build_hparams.build_hparams(FLAGS) 28 | hparams.add_hparam('load_experiment_num', FLAGS.load_experiment_num) 29 | 30 | experiment_num = utilities.prepare_experiment() 31 | hparams.add_hparam('logdir', "logs/%d" % (experiment_num)) 32 | 33 | hparams.add_hparam('loaddir', "impt_logs/%d/" % (hparams.load_experiment_num)) 34 | hparams.add_hparam('savedir', hparams.logdir + "/model.ckpt") 35 | 36 | return hparams 37 | 38 | def main(): 39 | """Build hparams, the graph, and train it.""" 40 | hparams = initialize() 41 | 42 | if hparams.run_inference_test: 43 | hparams.batch_size = 2 44 | X_mixtures, phases, inference_summaries = graph.build_inference_graph(hparams) 45 | inference.run_inference(hparams, X_mixtures, phases, inference_summaries) 46 | else: 47 | inputs, embedding_info, 
loss, loss_summary, summaries = ( 48 | graph.build_train_graph(hparams)) 49 | 50 | train.run_train(hparams, inputs, embedding_info, 51 | loss, loss_summary, summaries) 52 | 53 | if __name__ == '__main__': 54 | main() 55 | -------------------------------------------------------------------------------- /mir_bss_eval.py: -------------------------------------------------------------------------------- 1 | """Python implementation of Blind Source Separation metric* 2 | 3 | This code implements BSS using method the time distortion method from III.B 4 | in the paper*. 5 | 6 | This code is a slim version of: 7 | - https://github.com/craffel/mir_eval/blob/master/mir_eval/separation.py 8 | 9 | Which was implemented based on the following bss_eval MATLAB toolbox: 10 | http://bass-db.gforge.inria.fr/bss_eval/ 11 | 12 | - :func:`mir_eval.separation.bss_eval_sources`: Computes the bss_eval_sources 13 | metrics from bss_eval, which optionally optimally match the estimated sources 14 | to the reference sources and measure the distortion and artifacts present in 15 | the estimated sources as well as the interference between them. 16 | 17 | * [#vincent2006performance] Emmanuel Vincent, Rémi Gribonval, and Cédric 18 | Févotte, "Performance measurement in blind audio source separation," IEEE 19 | Trans. on Audio, Speech and Language Processing, 14(4):1462-1469, 2006. 20 | """ 21 | 22 | import numpy as np 23 | import scipy.fftpack 24 | from scipy.linalg import toeplitz 25 | from scipy.signal import fftconvolve 26 | import collections 27 | import itertools 28 | 29 | 30 | def _safe_db(num, den): 31 | """Properly handle the potential +Inf db SIR, instead of raising a 32 | RuntimeWarning. Only denominator is checked because the numerator can never 33 | be 0. 
34 | """ 35 | if den == 0: 36 | return np.Inf 37 | return 10 * np.log10(num / den) 38 | 39 | def _bss_source_crit(s_true, e_spat, e_interf, e_artif): 40 | """Measurement of the separation quality for a given source in terms of 41 | filtered true source, interference and artifacts. 42 | """ 43 | # energy ratios 44 | s_filt = s_true + e_spat 45 | sdr = _safe_db(np.sum(s_filt**2), np.sum((e_interf + e_artif)**2)) 46 | sir = _safe_db(np.sum(s_filt**2), np.sum(e_interf**2)) 47 | sar = _safe_db(np.sum((s_filt + e_interf)**2), np.sum(e_artif**2)) 48 | return (sdr, sir, sar) 49 | 50 | def _project(reference_sources, estimated_source, flen): 51 | """least-squares projection of estimated source on the subspace spanned by 52 | delayed versions of reference sources, with delays between 0 and flen-1 53 | """ 54 | nsrc = reference_sources.shape[0] 55 | nsampl = reference_sources.shape[1] 56 | 57 | # computing coefficients of least squares problem via fft ## 58 | # zero padding and fft of input data 59 | reference_sources = np.hstack((reference_sources, 60 | np.zeros((nsrc, flen - 1)))) 61 | estimated_source = np.hstack((estimated_source, np.zeros(flen - 1))) 62 | n_fft = int(2**np.ceil(np.log2(nsampl + flen - 1.))) 63 | sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1) 64 | sef = scipy.fftpack.fft(estimated_source, n=n_fft) 65 | # inner products between delayed versions of reference_sources 66 | g = np.zeros((nsrc * flen, nsrc * flen)) 67 | for i in range(nsrc): 68 | for j in range(nsrc): 69 | ssf = sf[i] * np.conj(sf[j]) 70 | ssf = np.real(scipy.fftpack.ifft(ssf)) 71 | ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), 72 | r=ssf[:flen]) 73 | g[i * flen: (i+1) * flen, j * flen: (j+1) * flen] = ss 74 | g[j * flen: (j+1) * flen, i * flen: (i+1) * flen] = ss.T 75 | # inner products between estimated_source and delayed versions of 76 | # reference_sources 77 | d = np.zeros(nsrc * flen) 78 | for i in range(nsrc): 79 | ssef = sf[i] * np.conj(sef) 80 | ssef = 
np.real(scipy.fftpack.ifft(ssef)) 81 | d[i * flen: (i+1) * flen] = np.hstack((ssef[0], ssef[-1:-flen:-1])) 82 | 83 | # computing projection 84 | # distortion filters 85 | try: 86 | c = np.linalg.solve(g, d).reshape(flen, nsrc, order='f') 87 | except np.linalg.linalg.linalgerror: 88 | c = np.linalg.lstsq(g, d)[0].reshape(flen, nsrc, order='f') 89 | # filtering 90 | sproj = np.zeros(nsampl + flen - 1) 91 | for i in range(nsrc): 92 | sproj += fftconvolve(c[:, i], reference_sources[i])[:nsampl + flen - 1] 93 | return sproj 94 | 95 | def _bss_decomp_mtifilt(reference_sources, estimated_source, j, flen): 96 | """Decomposition of an estimated source image into four components 97 | representing respectively the true source image, spatial (or filtering) 98 | distortion, interference and artifacts, derived from the true source 99 | images using multichannel time-invariant filters. 100 | """ 101 | nsampl = estimated_source.size 102 | # decomposition 103 | # true source image 104 | s_true = np.hstack((reference_sources[j], np.zeros(flen - 1))) 105 | # spatial (or filtering) distortion 106 | e_spat = _project(reference_sources[j, np.newaxis, :], estimated_source, 107 | flen) - s_true 108 | # interference 109 | e_interf = _project(reference_sources, 110 | estimated_source, flen) - s_true - e_spat 111 | # artifacts 112 | e_artif = -s_true - e_spat - e_interf 113 | e_artif[:nsampl] += estimated_source 114 | return (s_true, e_spat, e_interf, e_artif) 115 | 116 | def bss_eval_sources(reference_sources, estimated_sources, 117 | compute_permutation=True): 118 | """ 119 | Ordering and measurement of the separation quality for estimated source 120 | signals in terms of filtered true source, interference and artifacts. 121 | 122 | The decomposition allows a time-invariant filter distortion of length 123 | 512, as described in Section III.B of [#vincent2006performance]_. 
124 | 125 | Passing ``False`` for ``compute_permutation`` will improve the computation 126 | performance of the evaluation; however, it is not always appropriate and 127 | is not the way that the BSS_EVAL Matlab toolbox computes bss_eval_sources. 128 | 129 | Examples 130 | -------- 131 | >>> # reference_sources[n] should be an ndarray of samples of the 132 | >>> # n'th reference source 133 | >>> # estimated_sources[n] should be the same for the n'th estimated 134 | >>> # source 135 | >>> (sdr, sir, sar, 136 | ... perm) = mir_eval.separation.bss_eval_sources(reference_sources, 137 | ... estimated_sources) 138 | 139 | Parameters 140 | ---------- 141 | reference_sources : np.ndarray, shape=(nsrc, nsampl) 142 | matrix containing true sources (must have same shape as 143 | estimated_sources) 144 | estimated_sources : np.ndarray, shape=(nsrc, nsampl) 145 | matrix containing estimated sources (must have same shape as 146 | reference_sources) 147 | compute_permutation : bool, optional 148 | compute permutation of estimate/source combinations (True by default) 149 | 150 | Returns 151 | ------- 152 | sdr : np.ndarray, shape=(nsrc,) 153 | vector of Signal to Distortion Ratios (SDR) 154 | sir : np.ndarray, shape=(nsrc,) 155 | vector of Source to Interference Ratios (SIR) 156 | sar : np.ndarray, shape=(nsrc,) 157 | vector of Sources to Artifacts Ratios (SAR) 158 | perm : np.ndarray, shape=(nsrc,) 159 | vector containing the best ordering of estimated sources in 160 | the mean SIR sense (estimated source number ``perm[j]`` corresponds to 161 | true source number ``j``). Note: ``perm`` will be ``[0, 1, ..., 162 | nsrc-1]`` if ``compute_permutation`` is ``False``. 163 | 164 | References 165 | ---------- 166 | .. [#] Emmanuel Vincent, Shoko Araki, Fabian J. Theis, Guido Nolte, Pau 167 | Bofill, Hiroshi Sawada, Alexey Ozerov, B. Vikrham Gowreesunker, Dominik 168 | Lutter and Ngoc Q.K. 
Duong, "The Signal Separation Evaluation Campaign 169 | (2007-2010): Achievements and remaining challenges", Signal Processing, 170 | 92, pp. 1928-1936, 2012. 171 | 172 | """ 173 | 174 | # make sure the input is of shape (nsrc, nsampl) 175 | if estimated_sources.ndim == 1: 176 | estimated_sources = estimated_sources[np.newaxis, :] 177 | if reference_sources.ndim == 1: 178 | reference_sources = reference_sources[np.newaxis, :] 179 | 180 | #validate(reference_sources, estimated_sources) 181 | # If empty matrices were supplied, return empty lists (special case) 182 | if reference_sources.size == 0 or estimated_sources.size == 0: 183 | return np.array([]), np.array([]), np.array([]), np.array([]) 184 | 185 | nsrc = estimated_sources.shape[0] 186 | 187 | # does user desire permutations? 188 | if compute_permutation: 189 | # compute criteria for all possible pair matches 190 | sdr = np.empty((nsrc, nsrc)) 191 | sir = np.empty((nsrc, nsrc)) 192 | sar = np.empty((nsrc, nsrc)) 193 | for jest in range(nsrc): 194 | for jtrue in range(nsrc): 195 | s_true, e_spat, e_interf, e_artif = \ 196 | _bss_decomp_mtifilt(reference_sources, 197 | estimated_sources[jest], 198 | jtrue, 512) 199 | sdr[jest, jtrue], sir[jest, jtrue], sar[jest, jtrue] = \ 200 | _bss_source_crit(s_true, e_spat, e_interf, e_artif) 201 | 202 | # select the best ordering 203 | perms = list(itertools.permutations(list(range(nsrc)))) 204 | mean_sir = np.empty(len(perms)) 205 | dum = np.arange(nsrc) 206 | for (i, perm) in enumerate(perms): 207 | mean_sir[i] = np.mean(sir[perm, dum]) 208 | popt = perms[np.argmax(mean_sir)] 209 | idx = (popt, dum) 210 | return (sdr[idx], sir[idx], sar[idx], np.asarray(popt)) 211 | else: 212 | # compute criteria for only the simple correspondence 213 | # (estimate 1 is estimate corresponding to reference source 1, etc.) 
214 | sdr = np.empty(nsrc) 215 | sir = np.empty(nsrc) 216 | sar = np.empty(nsrc) 217 | for j in range(nsrc): 218 | s_true, e_spat, e_interf, e_artif = \ 219 | _bss_decomp_mtifilt(reference_sources, 220 | estimated_sources[j], 221 | j, 512) 222 | sdr[j], sir[j], sar[j] = \ 223 | _bss_source_crit(s_true, e_spat, e_interf, e_artif) 224 | 225 | # return the default permutation for compatibility 226 | popt = np.arange(nsrc) 227 | return (sdr, sir, sar, popt) 228 | 229 | def mir_eval_SDR_impl(sources, source_estimates): 230 | batch_size = sources.shape[0] 231 | 232 | SDRs = np.zeros(batch_size, dtype=np.float32) 233 | for b in range(batch_size): 234 | (sdr, sir, sar, perm) = bss_eval_sources(sources[b, :, :], source_estimates[b, :, :]) 235 | SDRs[b] = np.mean(sdr) 236 | 237 | return SDRs 238 | 239 | def mir_eval_SDR(sources, source_estimates): 240 | SDRs = tf.py_func(mir_eval_SDR_impl, [sources, source_estimates], tf.float32) 241 | SDRs.set_shape(sources.shape[0]) 242 | return tf.reduce_mean(SDRs) 243 | -------------------------------------------------------------------------------- /preprocess_libri.py: -------------------------------------------------------------------------------- 1 | """Convert FLAC files in LibriSpeech to WAV files.""" 2 | import glob 3 | import multiprocessing 4 | import subprocess 5 | import pdb 6 | import constants 7 | 8 | 9 | def mp_worker(file_loc): 10 | wav_file_loc = file_loc[:-5] + ".wav" 11 | subprocess.call("ffmpeg -y -i %s -ar %d %s" % (file_loc, constants.Fs, wav_file_loc), shell=True) 12 | subprocess.call("rm -f %s" % (file_loc), shell=True) 13 | print ("Process %s\tDONE" % file_loc) 14 | 15 | def mp_handler(): 16 | file_locs = [] 17 | file_locs.append(glob.glob('data/LibriSpeech/train-clean-100/*/*/*.flac')) 18 | file_locs.append(glob.glob('data/LibriSpeech/test-clean/*/*/*.flac')) 19 | 20 | for file_loc in file_locs: 21 | print ("Processing %d files." 
% len(file_loc)) 22 | 23 | p = multiprocessing.Pool(2 * multiprocessing.cpu_count()) 24 | p.map(mp_worker, file_loc) 25 | 26 | if __name__ == '__main__': 27 | mp_handler() 28 | -------------------------------------------------------------------------------- /preprocess_wsj0.py: -------------------------------------------------------------------------------- 1 | """Convert WV1 files in LDC-WSJ0 to WAV files.""" 2 | 3 | import glob 4 | import os 5 | import multiprocessing 6 | from functools import partial 7 | import subprocess 8 | import pdb 9 | 10 | def load_src(name, fpath): 11 | import imp 12 | return imp.load_source(name, os.path.join(os.path.dirname(__file__), fpath)) 13 | 14 | load_src("constants", "../constants.py") 15 | import constants 16 | 17 | 18 | def get_spk_name(file_name): 19 | return file_name[:3] 20 | 21 | def mp_worker(file_loc, dest_dir): 22 | file_name = file_loc.split('/')[-1] 23 | folder_name = get_spk_name(file_name) 24 | whole_dir = dest_dir + (folder_name + '/') * 2 # Use two folders (like in LibriSpeech) 25 | subprocess.call("mkdir -p %s" % whole_dir, shell=True) 26 | 27 | c = int(file_loc[-1]) # Channel 28 | wav_file_loc = whole_dir + file_name[:-4] + ".wv%d" % c 29 | tmp_2_file_loc = whole_dir + file_name[:-4] + "_tmp_2.wav" 30 | tmp_file_loc = whole_dir + file_name[:-4] + "_tmp.wav" 31 | subprocess.call("sph2pipe -f raw %s -f wav %s" % (file_loc, tmp_file_loc), shell=True) 32 | subprocess.call("ffmpeg -y -i %s -ar %d %s" % (tmp_file_loc, constants.Fs, tmp_2_file_loc), shell=True) 33 | 34 | subprocess.call("mv %s %s" % (tmp_2_file_loc, wav_file_loc), shell=True) 35 | subprocess.call("rm %s" % (tmp_file_loc), shell=True) 36 | print ("Process %s\tDONE" % file_loc) 37 | 38 | def mp_handler(): 39 | for c in (1,): # Ignore channel 2 recording 40 | train_file_locs = glob.glob('data/wsj0_sph/*/wsj0/si_tr_s/*/*.wv%d' % c) 41 | train_dest_dir = 'data/wsj0/train/' 42 | pdb.set_trace() 43 | 44 | test_file_locs = 
glob.glob('data/wsj0_sph/*/wsj0/si_*t_05/*/*.wv%d' % c) 45 | test_dest_dir = 'data/wsj0/test/' 46 | 47 | both_file_locs = (train_file_locs, test_file_locs) 48 | both_dest_dirs = (train_dest_dir, test_dest_dir) 49 | 50 | p = multiprocessing.Pool(2 * multiprocessing.cpu_count()) 51 | for (file_locs, dest_dir) in zip(both_file_locs, both_dest_dirs): 52 | print ("Processing %d files." % len(file_locs)) 53 | p.map(partial(mp_worker, dest_dir=dest_dir), file_locs) 54 | print ("*** Loop completed. ***") 55 | 56 | if __name__ == '__main__': 57 | mp_handler() 58 | -------------------------------------------------------------------------------- /pylogs/last_experiment_num.log: -------------------------------------------------------------------------------- 1 | 60001180 2 | -------------------------------------------------------------------------------- /rtl_loader.py: -------------------------------------------------------------------------------- 1 | """Create batches of RTL data for the network to train on.""" 2 | 3 | import numpy as np 4 | import glob 5 | import pdb 6 | from scipy.io.wavfile import read 7 | from functools import partial 8 | 9 | import data_lib 10 | import loader 11 | import constants 12 | 13 | 14 | #################### 15 | ### MAKE MIXTURE ### 16 | #################### 17 | def read_real_wav(wav_file): 18 | Fs, x = read(wav_file) 19 | assert (Fs == constants.Fs), "FS was: %d | %s" % (Fs, wav_file) 20 | return x 21 | 22 | def pick_mixed_wavs(hparams, folder): 23 | mixed_files = glob.glob(folder + "/*mixed*") 24 | wav_file_mixed = np.random.choice(mixed_files) 25 | dirs = wav_file_mixed.split('/') 26 | ID = "_".join(dirs[-1].split('_')[:2]) # 2_701 27 | 28 | files = glob.glob(folder + "/%s_*" % ID) 29 | files = [f for f in files if "mixed" not in f] 30 | assert len(files) == 2, (ID, files) 31 | 32 | wav_file_i = files[0] 33 | wav_file_j = files[1] 34 | 35 | wav_i = read_real_wav(wav_file_i) 36 | wav_j = read_real_wav(wav_file_j) 37 | mixture = 
read_real_wav(wav_file_mixed) 38 | assert (len(wav_i) == len(wav_j) == len(mixture)) 39 | 40 | start = np.random.randint(len(wav_i) - hparams.waveform_size) 41 | wav_i = wav_i[start: start + hparams.waveform_size] 42 | wav_j = wav_j[start: start + hparams.waveform_size] 43 | mixture = mixture[start: start + hparams.waveform_size] 44 | 45 | return wav_i, wav_j, mixture 46 | 47 | 48 | ######################### 49 | ### DATA CONSTRUCTION ### 50 | ######################### 51 | def mp_worker(folder, hparams): 52 | np.random.seed() # b/c Multiprocessing units get same seed 53 | wav_i, wav_j, mixture = pick_mixed_wavs(hparams, folder) 54 | 55 | X_mixture, phase = data_lib.wav_to_nn_representation(mixture) 56 | 57 | wavs = [wav_i, wav_j] 58 | X_sources, mask = loader.make_X_and_mask(hparams, wavs) 59 | 60 | sources = np.zeros((hparams.num_targets, hparams.waveform_size)) 61 | sources[0, :] = wavs[0] 62 | sources[1, :] = wavs[1] 63 | 64 | return X_mixture, phase, mask, sources, X_sources 65 | 66 | def make_data(pool, hparams, test=False): 67 | if test: 68 | poss_folders = [] 69 | for i in (7, 16): 70 | poss_folders.append("data/RTL/extracted_%d/libri/*" % i) 71 | else: # train 72 | poss_folders = [] 73 | for i in (2, 10, 12, 14, 15, 17, 18, 19, 20): 74 | poss_folders.append("data/RTL/extracted_%d/libri/*" % i) 75 | 76 | folders = np.random.choice(poss_folders, hparams.batch_size) 77 | r = pool.map_async(partial(mp_worker, hparams=hparams), folders) 78 | results = r.get() 79 | 80 | stacked_results = [] 81 | for result in zip(*results): 82 | stacked_results.append(np.stack(result)) 83 | 84 | return stacked_results 85 | -------------------------------------------------------------------------------- /start_tensorboard.py: -------------------------------------------------------------------------------- 1 | """Start TensorBoard with the specified logs.""" 2 | 3 | import subprocess 4 | import constants 5 | import argparse 6 | import utilities 7 | 8 | parser = 
argparse.ArgumentParser(description='Train the Keller Net..') 9 | parser.add_argument("-l", type=int, dest="load_experiment_num", default=0, 10 | help="Herro") 11 | parser.add_argument("-p", type=int, dest="port", default=0, 12 | help="Herro") 13 | FLAGS = parser.parse_args() 14 | 15 | 16 | if FLAGS.load_experiment_num: 17 | experiment_num = FLAGS.load_experiment_num 18 | offset_experiment_num = experiment_num + utilities.get_offset() 19 | log_file = "impt_logs/%d" % offset_experiment_num 20 | else: 21 | f = open(constants.LAST_EXPERIMENT_NUM_FILE, 'r') 22 | experiment_num = int(f.readline()[:-1]) 23 | offset_experiment_num = experiment_num + utilities.get_offset() 24 | log_file = "logs/%d" % offset_experiment_num 25 | f.close() 26 | 27 | 28 | port = constants.use_port 29 | if FLAGS.port: 30 | port = FLAGS.port 31 | 32 | print ("Tensorboard for", log_file) 33 | subprocess.call("tensorboard --logdir=%s --port=%d" % 34 | (log_file, port), shell=True) 35 | -------------------------------------------------------------------------------- /summaries.py: -------------------------------------------------------------------------------- 1 | """Code for setting up scalar, audio, and image summaries.""" 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import pdb 6 | from scipy.io import loadmat 7 | import itertools 8 | import string 9 | 10 | import helper 11 | import kmeans 12 | import bss_eval 13 | import mir_bss_eval 14 | import constants 15 | import data_lib 16 | 17 | ORACLE_NAME="Oracle" 18 | ATTRACTOR_NAME="Attractor" 19 | KMEANS_NAME="K-Means" 20 | SDR_PROXY_NAME="SDR_proxy.dB" 21 | 22 | ########################### 23 | ### SUMMARIES & HELPERS ### 24 | ########################### 25 | def itos(i): 26 | """Integer to string.""" 27 | letter_count = dict(zip(range(26), string.ascii_uppercase)) 28 | return letter_count[i] 29 | 30 | def scale_color_impl(x): 31 | """Nonlinear color scaling so colors are more saturated.""" 32 | return (1 / np.log(2)) * tf.log(1 + x) 
33 | 34 | def scale_color(x): 35 | """Apply twice for more saturation.""" 36 | return scale_color_impl(scale_color_impl(x)) 37 | 38 | def make_image_summary(X_source_estimates, mask): 39 | # Use mask to calculate red and blue channels respectively 40 | hue = mask[:, :, :, 0] * 1 + mask[:, :, :, 1] * (2/3) 41 | hue = tf.transpose(hue, perm=[0, 2, 1]) 42 | hue = tf.expand_dims(hue, -1) 43 | 44 | # Saturation is determined by power, make it look good 45 | saturation = sum(X_source_estimates) 46 | max_for_batch = tf.reduce_max(saturation, axis=[1, 2], keepdims=True) 47 | saturation = saturation / max_for_batch # Normalize to [0,1] 48 | saturation = scale_color(saturation) # Make colors more saturated 49 | saturation /= 0.95 50 | saturation += 0.05 51 | saturation = tf.expand_dims(saturation, -1) 52 | 53 | value = 0.3 * tf.ones(hue.shape) 54 | 55 | hsv_image = tf.squeeze(tf.stack([hue, saturation, value], axis=3), axis=[4]) 56 | hsv_image = tf.reverse(hsv_image, axis=[1]) 57 | image = tf.image.hsv_to_rgb(hsv_image) 58 | return tf.summary.image("Mixture-Color", image) 59 | 60 | def make_audio_summary(name, signal): 61 | normalizer = tf.reduce_max(tf.abs(signal), axis=1, keepdims=True) 62 | return tf.summary.audio(name, signal / normalizer, constants.Fs) 63 | 64 | 65 | ###################################### 66 | ### CALCULATIONS - ESTIMATES & SDR ### 67 | ###################################### 68 | def get_source_estimates(hparams, mask, tmp_X_source_estimates, phases): 69 | """Take the X_source_estimates (spectrogram) and obtain the raw waveform 70 | representation as well as the original spectorgram (not log magnitude).""" 71 | X_source_estimates, source_estimates = [], [] 72 | 73 | for i in range(hparams.num_targets): 74 | source_estimate, X_source_estimate = ( 75 | tf.py_func(data_lib.nn_representation_to_wav_spect, 76 | [tmp_X_source_estimates[:, :, :, i], phases], [tf.float32, tf.float32])) 77 | 78 | source_estimate.set_shape((hparams.batch_size, 
hparams.waveform_size)) 79 | X_source_estimate.set_shape((hparams.batch_size, constants.nfreqbins, hparams.ntimebins)) 80 | 81 | X_source_estimates.append(X_source_estimate) 82 | source_estimates.append(source_estimate) 83 | 84 | return X_source_estimates, source_estimates 85 | 86 | def compute_source_estimates(hparams, X_mixtures, mask, phases): 87 | X_source_estimates = tf.expand_dims(X_mixtures, -1) * mask 88 | return get_source_estimates(hparams, mask, X_source_estimates, phases) 89 | 90 | def compute_results(hparams, X_mixtures, mask, sources, phases): 91 | """Take the mask and apply it to the input mixture to obtain quantifiable 92 | results. Since K-means may permute the results we permute to the 93 | correct one.""" 94 | 95 | X_source_estimates, source_estimates = ( 96 | compute_source_estimates(hparams, X_mixtures, mask, phases)) 97 | 98 | stacked_source_estimates = tf.stack(source_estimates, axis=1) 99 | proxy_SDR = bss_eval.eval_proxy_SDR(sources, stacked_source_estimates) 100 | 101 | return X_source_estimates, source_estimates, proxy_SDR 102 | 103 | def compute_SDR_impl(savedir): 104 | savedir = savedir.decode("utf-8") # bytes to string 105 | filename = 'waveforms/%s/example_%d.mat' 106 | SDRs = [] 107 | 108 | print ("\t-- EVALUATING SDR --") 109 | n = constants.AVG_SDR_ON_N_BATCHES 110 | for i in range(n): 111 | if i == n-1 or (i and i % 20 == 0): 112 | print ("\tExample %d/%d" % (i, n)) 113 | try: 114 | data = loadmat(filename % (savedir, i)) 115 | sources = data['sources'] 116 | source_estimates = np.squeeze(data['source_estimates']) 117 | 118 | for b in range(sources.shape[0]): 119 | (sdr, sir, sar, perm) = ( 120 | mir_bss_eval.bss_eval_sources(sources[b, :, :], source_estimates[b, :, :])) 121 | SDRs.append(np.mean(sdr)) 122 | except Exception as e: 123 | print ("\t(%d) Exception in SDR calculation:" % i, e) 124 | 125 | SDRs = np.array(SDRs, dtype=np.float32) 126 | return np.mean(SDRs) if len(SDRs) > 0 else 0.0 127 | 128 | def compute_SDR(hparams): 
129 | if not hparams.save_estimate_waveforms: 130 | return tf.constant(0) 131 | name = helper.get_kmeans_waveform_savedir(hparams) 132 | return tf.py_func(compute_SDR_impl, [tf.constant(name)], tf.float32) 133 | 134 | ############### 135 | ### MASKING ### 136 | ############### 137 | def get_mask(hparams, embeddings, centers): 138 | mask_before_squash = tf.einsum("bik,bck->bic", embeddings, centers) 139 | argmax = tf.argmax(mask_before_squash, axis=2) 140 | mask = tf.one_hot(argmax, depth=hparams.num_targets, axis=2) 141 | return tf.reshape(mask, [hparams.batch_size, hparams.ntimebins, 142 | constants.nfreqbins, hparams.num_targets]) 143 | 144 | def get_attractor_mask(hparams, embeddings, attractors): 145 | return get_mask(hparams, embeddings, attractors) 146 | 147 | def permute_mask(X_mixtures, X_sources, mask): 148 | """Permute the order of the masks so it matches the sources best. 149 | We need to do this because k-means can permute the clusters.""" 150 | X_source_estimates = np.expand_dims(X_mixtures, axis=X_mixtures.ndim) * mask 151 | batch_size, _, __, num_targets = X_source_estimates.shape 152 | mask_permuted = np.zeros_like(mask, dtype=np.float32) 153 | 154 | permutations = list(itertools.permutations(range(num_targets))) 155 | for b in range(batch_size): 156 | permutation_errors = [] 157 | 158 | # Calculate the error for all possible permutations 159 | for permutation in permutations: 160 | reconstruction_error = 0 161 | for i in range(num_targets): 162 | diff = X_source_estimates[b, :, :, permutation[i]] - X_sources[b, :, :, i] 163 | reconstruction_error += np.mean(np.square(diff)) 164 | permutation_errors.append(reconstruction_error / num_targets) 165 | 166 | # Pick the permutation with the smallest error and permute 167 | argmin = np.argmin(permutation_errors) 168 | best_permutation = permutations[argmin] 169 | 170 | for i in range(num_targets): 171 | mask_permuted[b, :, :, i] = mask[b, :, :, best_permutation[i]] 172 | 173 | return mask_permuted 174 | 
def get_kmeans_mask_impl(hparams, X_mixtures, embeddings, threshold_mask):
    """Hard mask from K-means centers, with the channels in cluster order.

    The centers come from `kmeans.get_centers(hparams, embeddings, threshold_mask)`
    and are turned into a mask by `get_mask`. `X_mixtures` is not used here; it
    is kept so the signature matches the callers.
    """
    centers = kmeans.get_centers(hparams, embeddings, threshold_mask)
    return get_mask(hparams, embeddings, centers)


def get_kmeans_mask(hparams, X_mixtures, X_sources, embeddings, threshold_mask):
    """K-means mask whose channel order is aligned to the reference sources.

    `permute_mask` (run through `tf.py_func`) reorders the mask channels per
    batch item to best match `X_sources`. Because it reads the references, this
    variant needs ground-truth sources; `setup_inference_summary` uses
    `get_kmeans_mask_impl` instead.
    """
    mask = get_kmeans_mask_impl(hparams, X_mixtures, embeddings, threshold_mask)
    return tf.py_func(permute_mask, [X_mixtures, X_sources, mask], tf.float32)


############################
### SETUP (ENTRY POINTS) ###
############################
def setup_eval_result_summary(hparams, X_mixtures, phases, kmeans_mask):
    """Build the audio/image summaries of a K-means separation under KMEANS_NAME.

    Returns:
        A list of summary ops, containing:
          * the mixture audio;
          * one audio clip per estimated speaker;
          * an image summary of the estimates and `kmeans_mask`.
    """
    # Put ops back in their original shape
    X_mixtures = helper.uncollapse_freq_into_time(hparams, X_mixtures)

    mixture, _ = (tf.py_func(data_lib.nn_representation_to_wav_spect,
                             [X_mixtures, phases], [tf.float32, tf.float32]))

    X_source_estimates, source_estimates = (
        compute_source_estimates(hparams, X_mixtures, kmeans_mask, phases))

    summaries = []
    # Generate an audio summary for the input (mixture)
    with tf.name_scope(KMEANS_NAME):
        summaries.append(make_audio_summary("Mixture", mixture))
        for (i, source_estimate) in enumerate(source_estimates):
            speaker_name = "Speaker_%s" % itos(i)
            summaries.append(make_audio_summary(speaker_name, source_estimate))
        summaries.append(make_image_summary(X_source_estimates, kmeans_mask))

    return summaries


def setup_inference_summary(hparams, threshold_mask, X_mixtures, phases, embeddings):
    """Inference-time summaries: K-means mask without the source-based permutation."""
    kmeans_mask = get_kmeans_mask_impl(hparams, X_mixtures, embeddings, threshold_mask)
    return setup_eval_result_summary(hparams, X_mixtures, phases, kmeans_mask)


def create_audio_image_summaries(mask_name, mask, train_test_summ_ops,
                                 mixture, X_source_estimates, source_estimates):
    """Generate summaries for audio and images of the spectrogram.

    `train_test_summ_ops` is a `(train_ops, test_ops)` pair of lists, mutated
    in place. For the scopes "train-<mask_name>" and "test-<mask_name>", the
    matching list receives:
      * the mixture audio;
      * one audio clip per estimated speaker;
      * an image summary of `(X_source_estimates, mask)`.
    """
    scope_names = ("train-%s" % mask_name, "test-%s" % mask_name)
    for (scope_name, ops) in zip(scope_names, train_test_summ_ops):
        with tf.name_scope(scope_name):
            ops.append(make_audio_summary("Mixture", mixture))
            for (i, source_estimate) in enumerate(source_estimates):
                speaker_name = "Speaker_%s" % itos(i)
                ops.append(make_audio_summary(speaker_name, source_estimate))
            ops.append(make_image_summary(X_source_estimates, mask))


def create_summaries(hparams, threshold_mask, attractors,
                     X_mixtures, phases, oracle_mask, sources, X_sources, embeddings):
    """Build audio/image/proxy-SDR summaries for the oracle, attractor and K-means masks.

    Side effect: the stacked oracle and K-means source estimates are added to
    the graph collections `constants.ORACLE_SRC_EST_NAME` and
    `constants.KMEANS_SRC_EST_NAME`, which `train.get_saved_ops` reads back.

    Returns:
        `(train_summ_ops, test_summ_ops, oracle_summ_ops, SDR_summ_ops)` lists.
    """
    # Setup summary lists
    train_summ_ops, test_summ_ops, oracle_summ_ops, SDR_summ_ops = [], [], [], []

    # Separate the TF axis into the two original axis: T, F
    X_sources = helper.uncollapse_freq_into_time(hparams, X_sources)
    X_mixtures = helper.uncollapse_freq_into_time(hparams, X_mixtures)
    oracle_mask = helper.uncollapse_freq_into_time(hparams, oracle_mask)

    # Generate an audio summary for the input (mixture)
    mixture, _ = (tf.py_func(data_lib.nn_representation_to_wav_spect,
                             [X_mixtures, phases], [tf.float32, tf.float32]))

    train_test_summ_ops = (train_summ_ops, test_summ_ops)

    # Oracle summaries
    X_source_estimates, source_estimates, proxy_SDR = (
        compute_results(hparams, X_mixtures, oracle_mask, sources, phases))
    create_audio_image_summaries(ORACLE_NAME, oracle_mask, train_test_summ_ops,
                                 mixture, X_source_estimates, source_estimates)
    with tf.name_scope(ORACLE_NAME):
        oracle_summ_ops.append(tf.summary.scalar(SDR_PROXY_NAME, proxy_SDR))
    tf.add_to_collection(constants.ORACLE_SRC_EST_NAME, tf.stack(source_estimates, axis=1))

    # Attractor summaries (train & test)
    attractor_mask = get_attractor_mask(hparams, embeddings, attractors)
    X_source_estimates, source_estimates, proxy_SDR = (
        compute_results(hparams, X_mixtures, attractor_mask, sources, phases))
    create_audio_image_summaries(ATTRACTOR_NAME, attractor_mask, train_test_summ_ops,
                                 mixture, X_source_estimates, source_estimates)
    with tf.name_scope(ATTRACTOR_NAME):
        # train & test on same graph
        proxy_sdr_summary = tf.summary.scalar(SDR_PROXY_NAME, proxy_SDR)
        train_summ_ops.append(proxy_sdr_summary)
        test_summ_ops.append(proxy_sdr_summary)

    # K-means summaries (train & test)
    kmeans_mask = get_kmeans_mask(hparams, X_mixtures, X_sources, embeddings,
                                  threshold_mask)
    X_source_estimates, source_estimates, proxy_SDR = (
        compute_results(hparams, X_mixtures, kmeans_mask, sources, phases))
    # The image must show the mask that produced these estimates (kmeans_mask),
    # as in setup_eval_result_summary.
    create_audio_image_summaries(KMEANS_NAME, kmeans_mask, train_test_summ_ops,
                                 mixture, X_source_estimates, source_estimates)
    SDR = compute_SDR(hparams)
    with tf.name_scope(KMEANS_NAME):
        # train & test on same graph
        proxy_sdr_summary = tf.summary.scalar(SDR_PROXY_NAME, proxy_SDR)
        train_summ_ops.append(proxy_sdr_summary)
        test_summ_ops.append(proxy_sdr_summary)
        SDR_summ_ops.append(tf.summary.scalar("SDR.dB", SDR))

    tf.add_to_collection(constants.KMEANS_SRC_EST_NAME, tf.stack(source_estimates, axis=1))

    return train_summ_ops, test_summ_ops, oracle_summ_ops, SDR_summ_ops

# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
"""Train the parameters of the graph."""

import tensorflow as tf
import time
import numpy as np
import multiprocessing
import os
import pdb
from scipy.io import savemat

import loader
import helper
import embedding_summary
import constants


def save_waveforms(hparams, step, sources, oracle_src_ests, kmeans_src_ests):
    """Save reference and estimated waveforms of a test batch for `compute_SDR_impl`.

    Files rotate over `constants.AVG_SDR_ON_N_BATCHES` slots, numbered
    `(step // hparams.test_frequency) % AVG_SDR_ON_N_BATCHES`. `savemat`
    appends the ".mat" extension. Oracle and K-means estimates are saved only
    when `hparams.save_oracle_waveforms` / `hparams.save_estimate_waveforms`
    is set, respectively.
    """
    example_num = (step // hparams.test_frequency) % constants.AVG_SDR_ON_N_BATCHES

    if hparams.save_oracle_waveforms:
        oracle_name = helper.get_oracle_waveform_savedir(hparams)
        os.makedirs("waveforms/%s" % oracle_name, exist_ok=True)
        savemat("waveforms/%s/example_%d" % (oracle_name, example_num),
                {"sources": sources, "source_estimates": oracle_src_ests})

    if hparams.save_estimate_waveforms:
        kmeans_name = helper.get_kmeans_waveform_savedir(hparams)
        os.makedirs("waveforms/%s" % kmeans_name, exist_ok=True)
        savemat("waveforms/%s/example_%d" % (kmeans_name, example_num),
                {"sources": sources, "source_estimates": kmeans_src_ests})


def make_feed_dict(hparams, pool, inputs, test=False):
    """Map the five input placeholders to a fresh batch from `loader.make_data`.

    `inputs` is matched positionally to the loader outputs:
    (mixtures, phases, masks, sources, X_sources).
    """
    mixtures, phases, masks, sources, X_sources = (
        loader.make_data(pool, hparams, test))

    return {inputs[0]: mixtures, inputs[1]: phases, inputs[2]: masks,
            inputs[3]: sources, inputs[4]: X_sources}


def setup(hparams, sess, saver):
    """Initialize variables, build writers, restore model if needed.

    Creates train/test/oracle FileWriters under `hparams.logdir`. When
    `hparams.load_experiment_num` is set, the latest checkpoint in
    `hparams.loaddir` is restored.

    Returns:
        `(train_writer, test_writer, oracle_writer)`.
    """
    sess.run(tf.global_variables_initializer())

    train_writer = tf.summary.FileWriter(hparams.logdir + "/train", sess.graph)
    test_writer = tf.summary.FileWriter(hparams.logdir + "/test", sess.graph)
    oracle_writer = tf.summary.FileWriter(hparams.logdir + "/oracle", sess.graph)

    if hparams.load_experiment_num:
        print ("Loading: %s" % hparams.loaddir)
        saver.restore(sess, tf.train.latest_checkpoint(hparams.loaddir))

    return train_writer, test_writer, oracle_writer


def get_saved_ops():
    """Fetch the train op and the source-estimate tensors from their graph collections.

    Each value is the list returned by `tf.get_collection`.
    """
    train = tf.get_collection(constants.TRAIN_OP_NAME)
    oracle_src_ests_op = tf.get_collection(constants.ORACLE_SRC_EST_NAME)
    kmeans_src_ests_op = tf.get_collection(constants.KMEANS_SRC_EST_NAME)
    return train, oracle_src_ests_op, kmeans_src_ests_op


def run_train(hparams, inputs, embedding_info, loss, loss_summary, summaries):
    """Training loop with periodic summaries, evaluation, SDR and checkpoints.

    Every step takes a gradient step on a fresh training batch. In addition:
      * every `summary_frequency` steps, the train/oracle summaries and the
        embedding assign op run (otherwise only `loss_summary` is written);
      * every `test_frequency` steps (not step 0), a test batch is evaluated
        and its waveforms are saved with `save_waveforms`;
      * every `sdr_frequency` steps (not step 0), the SDR summary is evaluated;
      * every `save_frequency` steps, the embedding TSV is written and the
        model is checkpointed.

    Args:
        inputs: the five placeholders fed by `make_feed_dict`, in the order
            (mixtures, phases, masks, sources, X_sources).
        summaries: `(train, test, oracle, SDR)` summary ops.
    """
    mixtures, _, masks, sources, _ = inputs
    train_summary, test_summary, oracle_summary, SDR_summary = summaries

    embedding_assign = embedding_info.get_assign_op()
    train, oracle_src_ests_op, kmeans_src_ests_op = get_saved_ops()
    pool = multiprocessing.Pool(multiprocessing.cpu_count())
    saver = tf.train.Saver()

    # Leaving the `with` terminates the worker pool, so its processes are not
    # leaked when training ends or an exception escapes.
    with pool, tf.Session() as sess:
        train_writer, test_writer, oracle_writer = setup(hparams, sess, saver)

        for step in range(hparams.max_steps):
            start = time.time()
            feed_dict = make_feed_dict(hparams, pool, inputs)

            # Gradient step on training set
            if step % hparams.summary_frequency == 0:
                result = sess.run([loss, embedding_assign, train_summary, oracle_summary,
                                   train], feed_dict=feed_dict)
                raw_loss, _, raw_summary, raw_oracle_summary, _ = result

                oracle_writer.add_summary(raw_oracle_summary, step)
                embedding_info.visualize_embeddings(train_writer)
            else:
                result = sess.run([loss, loss_summary, train], feed_dict=feed_dict)
                raw_loss, raw_summary, _ = result

            train_writer.add_summary(raw_summary, step)

            # Evaluate on Test Set
            if step and step % hparams.test_frequency == 0:
                test_feed_dict = make_feed_dict(hparams, pool, inputs, test=True)
                result = sess.run([loss, test_summary, oracle_src_ests_op,
                                   kmeans_src_ests_op], feed_dict=test_feed_dict)

                # Keep the test loss in its own variable so that the per-step
                # "Loss" line below still reports the training loss.
                test_loss, test_raw_summary, oracle_src_ests, kmeans_src_ests = result
                test_writer.add_summary(test_raw_summary, step)
                save_waveforms(hparams, step, test_feed_dict[sources],
                               oracle_src_ests, kmeans_src_ests)
                print ("\t%d) EVAL-Loss: %.3f" % (step, test_loss))

            if step and step % hparams.sdr_frequency == 0:
                sdr_raw_summary = sess.run([SDR_summary])[0]
                test_writer.add_summary(sdr_raw_summary, step)

            # Save Model
            if step % hparams.save_frequency == 0:
                # The embeddings are read from saved weights, make sure the TSV matches
                embedding_summary.write_tsv(hparams, feed_dict[mixtures], feed_dict[masks])
                saver.save(sess, hparams.savedir, step)

                print ("Model Saved.")

            helper.flush(train_writer, test_writer, oracle_writer)
            print ("\t%d) Loss: %.3f || Elapsed = %.3f secs" % (step, raw_loss,
                                                               (time.time() - start)))

# --------------------------------------------------------------------------------
# /tuples.py:
# --------------------------------------------------------------------------------
import collections
import tensorflow as tf

# Optimizer wraps an optimizer class in `func` (presumably instantiated by the
# graph builder with the learning rate -- confirm in graph.py).
Optimizer = collections.namedtuple('Optimizer', 'func')
Adam = Optimizer(func=tf.train.AdamOptimizer)
RMSProp = Optimizer(func=tf.train.RMSPropOptimizer)
SGD = Optimizer(func=tf.train.GradientDescentOptimizer)

# --------------------------------------------------------------------------------
# /utilities.py:
# --------------------------------------------------------------------------------
"""Keep track of the experiment number and save all hyperparameters to new files."""

import subprocess
import pdb
import socket
import constants


def get_offset():
    """Return the experiment-number offset for the current machine.

    Hosts whose name contains 'fattire' get 200000000; all others get 100000000.
    """
    # Use a different offset for other machines
    if 'fattire' in socket.gethostname():
        return 200000000  # Brian
    return 100000000


def prepare_experiment():
    """Copy all hyperparameter files for this experiment to a different folder.

    Steps:
      1. Read the last experiment number from
         `constants.LAST_EXPERIMENT_NUM_FILE` and increment it.
      2. Copy hyperparameters.py, RNNparameters.py, CNNparameters.py and
         constants.py to `hparams_logs/<offset_num>_<name>.py`.
      3. Write the incremented number back.

    Returns:
        The current experiment number plus the machine offset from `get_offset()`.
    """
    with open(constants.LAST_EXPERIMENT_NUM_FILE, 'r') as f:
        # int() ignores surrounding whitespace. Slicing off the last character
        # instead would chop a digit if the file has no trailing newline.
        experiment_num = int(f.readline()) + 1

    offset_experiment_num = get_offset() + experiment_num

    fnames = ("hyperparameters", "RNNparameters", "CNNparameters", "constants")
    for fname in fnames:
        # Run synchronously and without a shell, so each snapshot is complete
        # before training starts. As with the previous `cp` call, a failed copy
        # is not raised.
        subprocess.run(['cp', '%s.py' % fname,
                        'hparams_logs/%d_%s.py' % (offset_experiment_num, fname)])

    with open(constants.LAST_EXPERIMENT_NUM_FILE, 'w') as f:
        f.write(str(experiment_num) + '\n')
    return offset_experiment_num
# --------------------------------------------------------------------------------