├── .gitignore
├── png
│   └── wavenet.png
├── hyperparams.py
├── network.py
├── eval.py
├── README.md
├── data_utils.py
├── train.py
└── module.py

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
*.DS_Store
.git/*
.idea/*
data/*
logdir/*

--------------------------------------------------------------------------------
/png/wavenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soobinseo/wavenet/HEAD/png/wavenet.png

--------------------------------------------------------------------------------
/hyperparams.py:
--------------------------------------------------------------------------------
hidden_dim = 128
dilation = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
sample_rate = 16000
timestep = 6080
is_training = True
use_mulaw = True
batch_size = 1
num_epochs = 100
save_dir = './logdir'
test_data = 'test.wav'

--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
from module import *
import hyperparams as hp

def network(input_, use_mulaw=hp.use_mulaw):
    # Initial 1-D convolution lifts the raw waveform to hidden_dim channels
    input_ = conv1d(input_, output_channels=hp.hidden_dim, filter_width=3)

    # Stack of dilated residual blocks; collect their skip outputs
    skip_connections = list()
    for i in hp.dilation:
        skip, res = residual_block(input_, rate=i, scope="res_%d" % i)
        input_ = res
        skip_connections.append(skip)

    # Sum the skip outputs and map them to the final output channels
    skip_output = tf.add_n(skip_connections)
    output = skip_connection(skip_output, use_mulaw=use_mulaw)

    return output

--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
from network import *
from data_utils import *
import hyperparams as hp
import librosa

class Graph:
    def __init__(self):
        self.graph = tf.Graph()

        with self.graph.as_default():
            self.x = tf.placeholder(tf.float32, [None, hp.timestep, 1], name='X')

            output = network(self.x, use_mulaw=hp.use_mulaw)

            if hp.use_mulaw:
                self.prediction = mu_law_decode(tf.argmax(output, axis=2))
            else:
                self.prediction = tf.squeeze(output, -1)

def main():

    g = Graph()

    mixture = librosa.load('./data/' + hp.test_data, sr=hp.sample_rate)[0]
    mixture_len = len(mixture) // hp.timestep
    print(mixture_len)
    mixture = np.expand_dims(mixture[:mixture_len * hp.timestep].reshape([-1, hp.timestep]), -1)

    with g.graph.as_default():

        with tf.Session() as sess:
            saver = tf.train.Saver()
            saver.restore(sess, tf.train.latest_checkpoint(hp.save_dir))
            print("restored successfully!")
            outputs = []
            for part in mixture:
                part = np.expand_dims(part, axis=0)
                output = sess.run(g.prediction, feed_dict={g.x: part})
                output = np.squeeze(output, axis=0)  # drop the batch dimension
                outputs.append(output)

            result = np.vstack(outputs).reshape(-1)
            librosa.output.write_wav("./data/result.wav", result, sr=hp.sample_rate)

if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# wavenet



## Description
* This is a TensorFlow implementation of audio source separation (mixture to vocal) using WaveNet. The original WaveNet uses causal convolutions (only past samples may be used) to predict the next sample, but since this task is audio separation, later samples may also be used, so I use ordinary (non-causal) dilated 1-D convolutions instead. Apart from this, the network structure is the same as in the [paper](https://deepmind.com/blog/wavenet-generative-model-raw-audio/). See `hyperparams.py` for the detailed hyperparameters.

## Requirements
* NumPy >= 1.11.1
* TensorFlow >= 1.0.0
* librosa

## Data
I used the DSD100 dataset, which consists of pairs of mixture audio files and vocal audio files. The complete dataset (~14 GB) can be downloaded [here](http://liutkus.net/DSD100.zip). The data was resampled to sample_rate=16000 and divided into 380 ms units, so each network input consists of 16000 samples/s x 0.38 s = 6080 raw samples (`timestep` in `hyperparams.py`).

## File description
* `hyperparams.py` includes all hyperparameters that are needed.
* `data_utils.py` loads the training data and preprocesses it into units of raw sample sequences.
* `module.py` contains the building blocks: dilated convolutions, residual blocks, skip connections, and mu-law encoding/decoding.
* `network.py` builds the network.
* `train.py` is for training.
* `eval.py` generates a separated vocal sample.

## Training the network
* STEP 1. Adjust hyperparameters in `hyperparams.py` if necessary.
* STEP 2. Download and extract the DSD100 data into the 'data' directory as mentioned above, and run `data_utils.py`.
* STEP 3. Run `train.py`.

## Generate separated vocal audio
* Prepare a test file (its name should be set as `test_data` in `hyperparams.py`), place it in the 'data' directory, and run `eval.py`.

## Notes
* I applied an L1 loss on the raw waveform (set `use_mulaw = False`) instead of the NLL (cross-entropy) loss over mu-law quantized levels (`use_mulaw = True`).
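* For reference, a minimal NumPy sketch of the same mu-law companding and 256-level quantization that `module.py` implements in TensorFlow. It is illustrative only and not part of the training code; the helper names `mu_law_encode_np`/`mu_law_decode_np` are made up here:

```python
import numpy as np

def mu_law_encode_np(audio, quantization_channels=256):
    # Compress amplitudes with the mu-law transform, then quantize to integer levels.
    mu = quantization_channels - 1
    audio = np.clip(audio, -1.0, 1.0)
    magnitude = np.log1p(mu * np.abs(audio)) / np.log1p(mu)
    signal = np.sign(audio) * magnitude
    return ((signal + 1) / 2 * mu + 0.5).astype(np.int32)

def mu_law_decode_np(levels, quantization_channels=256):
    # Map integer levels back to [-1, 1] and invert the mu-law transform.
    mu = quantization_channels - 1
    signal = 2 * (levels.astype(np.float32) / mu) - 1
    magnitude = (1.0 / mu) * ((1 + mu) ** np.abs(signal) - 1)
    return np.sign(signal) * magnitude

# Round trip on one 380 ms window at 16 kHz (6080 samples): the quantization
# error stays small, which is why a 256-way classification target is workable.
x = np.sin(np.linspace(0, 2 * np.pi * 440 * 0.38, 6080)).astype(np.float32)
x_hat = mu_law_decode_np(mu_law_encode_np(x))
print(np.max(np.abs(x - x_hat)))
```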
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import librosa
import hyperparams as hp

def load_data():
    mixtures, vocals = list(), list()
    for path, subdirs, files in os.walk('./data/DSD100/Mixtures/Dev'):
        for name in [f for f in files if f.endswith(".wav")]:
            # a = librosa.load(os.path.join(path, name), sr=44100)[0].shape
            mixtures.append(os.path.join(path, name))

    for path, subdirs, files in os.walk('./data/DSD100/Sources/Dev'):
        for subdir in subdirs:
            vocal = os.path.join(os.path.join(path, subdir), "vocals.wav")
            vocals.append(vocal)

    num_wavs = len(mixtures)

    return mixtures, vocals, num_wavs

def get_rawwave(_input):
    return librosa.load(_input, sr=hp.sample_rate)

def make_rawdata(is_training=True, name="data"):

    m, v, n = load_data()
    arrays = []
    arrays_2 = []
    for i, j in zip(m, v):
        data = get_rawwave(i)[0]
        lens = len(data) // hp.timestep
        arrays.append(np.expand_dims(np.reshape(data[:hp.timestep * lens], [-1, hp.timestep]), -1))

        if is_training:
            # the vocal track belongs to the same song as the mixture, so reuse `lens`
            data_2 = get_rawwave(j)[0]
            arrays_2.append(np.expand_dims(np.reshape(data_2[:hp.timestep * lens], [-1, hp.timestep]), -1))
    print(np.vstack(arrays).shape)
    np.save("./data/mixtures_%s.npy" % name, np.vstack(arrays))
    if is_training:
        np.save("./data/vocals.npy", np.vstack(arrays_2))

def dataset_shuffling(x, y):
    shuffled_idx = np.arange(len(y))
    np.random.shuffle(shuffled_idx)
    return x[shuffled_idx, :], y[shuffled_idx, :]

def get_batch(x, y, curr_index, batch_size):
    batch_x = x[curr_index * batch_size: (curr_index + 1) * batch_size]
    batch_y = y[curr_index * batch_size: (curr_index + 1) * batch_size]
    return batch_x, batch_y

if __name__ == '__main__':
    make_rawdata(is_training=hp.is_training)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from network import *
from data_utils import *
import hyperparams as hp


class Graph:
    def __init__(self):
        self.graph = tf.Graph()

        with self.graph.as_default():
            self.x = tf.placeholder(tf.float32, [None, hp.timestep, 1], name='X')
            self.y = tf.placeholder(tf.float32, [None, hp.timestep, 1], name='Y')
            if hp.use_mulaw:
                # drop the channel axis so the one-hot labels have shape
                # [batch, timestep, quantization_channels], matching the logits
                label = mu_law_encode(tf.squeeze(self.y, axis=-1))
            else:
                label = self.y

            output = network(self.x, use_mulaw=hp.use_mulaw)

            if hp.use_mulaw:
                self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=output, labels=label))
            else:
                self.loss = tf.reduce_mean(tf.abs(output - label))

            self.train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(self.loss)

            tf.summary.scalar("loss", self.loss)
            self.merged = tf.summary.merge_all()

def main():
    mixture = np.load('./data/mixtures_data.npy')
    vocals = np.load('./data/vocals.npy')

    num_batch = len(mixture) // hp.batch_size

    g = Graph()

    with g.graph.as_default():
        # config = tf.ConfigProto(allow_soft_placement=True)
        # config.gpu_options.allocator_type = 'BFC'
        # config.gpu_options.per_process_gpu_memory_fraction = 0.80
        # config.gpu_options.allow_growth = True
        saver = tf.train.Saver()
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            for epoch in range(hp.num_epochs):

                mixture, vocals = dataset_shuffling(mixture, vocals)
                for i in range(num_batch):
                    batch_X, batch_Y = get_batch(mixture, vocals, i, hp.batch_size)
                    sess.run(g.train_op, feed_dict={g.x: batch_X, g.y: batch_Y})

                    if i % 100 == 0:
                        print("step %d, loss: %.4f" % (i, sess.run(g.loss, feed_dict={g.x: batch_X, g.y: batch_Y})))
                        saver.save(sess, hp.save_dir + "/model_%d.ckpt" % (epoch * num_batch + i))

if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/module.py:
--------------------------------------------------------------------------------
import tensorflow as tf

def atrous_conv1d(tensor, output_channels, is_causal=False, rate=1, pad='SAME', stddev=0.02, name="aconv1d"):
    """
    Args:
      tensor: A 3-D tensor.
      output_channels: An integer. Dimension of output channel.
      is_causal: A boolean. If true, apply causal convolution.
      rate: An integer. Dilation rate.
      pad: Either "SAME" or "VALID". If "SAME", make padding, else no padding.
      stddev: A float. Standard deviation for truncated normal initializer.
      name: A string. Name of scope.
    Returns:
      A tensor of the same shape as `tensor`, which has been
      processed through a dilated convolution layer.
    """

    # Set filter size
    size = (2 if is_causal else 3)

    # Get input dimension
    in_dim = tensor.get_shape()[-1].value

    with tf.variable_scope(name):
        # Make filter
        filter = tf.get_variable("w", [size, in_dim, output_channels],
                                 initializer=tf.truncated_normal_initializer(stddev=stddev))

        if is_causal:
            # Causal convolution: pre-pad the time axis on the left so that
            # each output depends only on current and past samples
            if pad == 'SAME':
                pad_len = (size - 1) * rate
                x = tf.expand_dims(tf.pad(tensor, [[0, 0], [pad_len, 0], [0, 0]]), axis=1, name="X")
            else:
                x = tf.expand_dims(tensor, axis=1)
            # Apply 2-D atrous convolution (expects a 4-D filter and an integer rate)
            out = tf.nn.atrous_conv2d(x, tf.expand_dims(filter, 0), rate=rate, padding='VALID')
            # Reduce the dummy height dimension
            out = tf.squeeze(out, axis=1)
        else:
            # Apply dilated 1-D convolution
            out = tf.nn.convolution(tensor,
                                    filter, dilation_rate=[rate], padding=pad, data_format='NWC')

    return out

def conv1d(input_, output_channels, filter_width=1, stride=1, stddev=0.02, name='conv1d'):
    """
    Args:
      input_: A 3-D tensor.
      output_channels: An integer. Dimension of output channel.
      filter_width: An integer. Size of filter.
      stride: An integer. Stride of convolution.
      stddev: A float. Standard deviation for truncated normal initializer.
      name: A string. Name of scope.
    Returns:
      A tensor of shape [batch size, timesteps, output channels], which has been
      processed through a 1-D convolution layer.
61 | """ 62 | 63 | # Get input dimension 64 | input_shape = input_.get_shape() 65 | input_channels = input_shape[-1].value 66 | 67 | with tf.variable_scope(name): 68 | # Make filter 69 | filter_ = tf.get_variable('w', [filter_width, input_channels, output_channels], 70 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 71 | 72 | # Convolution layer 73 | conv = tf.nn.conv1d(input_, filter_, stride = stride, padding = 'SAME') 74 | biases = tf.get_variable('biases', [output_channels], initializer=tf.constant_initializer(0.0)) 75 | 76 | # Add bias 77 | conv = tf.nn.bias_add(conv, biases) 78 | 79 | return conv 80 | 81 | def residual_block(input_, rate, scope="res"): 82 | 83 | input_dim = input_.get_shape()[-1].value 84 | 85 | with tf.variable_scope(scope): 86 | aconv_f = atrous_conv1d(input_, 87 | output_channels=input_dim // 2, 88 | rate=rate, 89 | name="filter_aconv") 90 | aconv_g = atrous_conv1d(input_, 91 | output_channels=input_dim // 2, 92 | rate=rate, 93 | name="gate_aconv") 94 | aconv = tf.multiply(aconv_f, tf.sigmoid(aconv_g)) 95 | 96 | skip_connection = conv1d(aconv, 97 | output_channels=input_dim, 98 | name="skip_connection") 99 | res_output = conv1d(aconv, 100 | output_channels=input_dim, 101 | name="res_output") 102 | 103 | return skip_connection, res_output + input_ 104 | 105 | def skip_connection(tensor, logit_dim=256, use_mulaw=True): 106 | 107 | dim = tensor.get_shape()[-1].value 108 | 109 | with tf.variable_scope("last_skip_connection"): 110 | tensor = tf.nn.relu(tensor) 111 | tensor = conv1d(tensor, dim, name="conv1") 112 | tensor = tf.nn.relu(tensor) 113 | tensor = conv1d(tensor, logit_dim if use_mulaw else 1, name="conv2") 114 | return tensor 115 | 116 | 117 | def mu_law_encode(audio, quantization_channels=256): 118 | '''Quantizes waveform amplitudes.''' 119 | with tf.name_scope('encode'): 120 | mu = tf.to_float(quantization_channels - 1) 121 | # Perform mu-law companding transformation (ITU-T, 1988). 122 | # Minimum operation is here to deal with rare large amplitudes caused 123 | # by resampling. 124 | safe_audio_abs = tf.minimum(tf.abs(audio), 1.0) 125 | magnitude = tf.log1p(mu * safe_audio_abs) / tf.log1p(mu) 126 | signal = tf.sign(audio) * magnitude 127 | # Quantize signal to the specified number of levels. 128 | return tf.one_hot(tf.to_int32((signal + 1) / 2 * mu + 0.5), quantization_channels) 129 | 130 | 131 | def mu_law_decode(output, quantization_channels=256): 132 | '''Recovers waveform from quantized values.''' 133 | with tf.name_scope('decode'): 134 | mu = quantization_channels - 1 135 | # Map values back to [-1, 1]. 136 | signal = 2 * (tf.to_float(output) / mu) - 1 137 | # Perform inverse of mu-law transformation. 138 | magnitude = (1 / mu) * ((1 + mu) ** abs(signal) - 1) 139 | return tf.sign(signal) * magnitude 140 | 141 | 142 | 143 | --------------------------------------------------------------------------------