├── models
│   ├── gan.pyc
│   ├── ops.pyc
│   ├── hrnn.pyc
│   ├── __init__.pyc
│   ├── audio_reader.pyc
│   ├── discriminator.pyc
│   ├── __init__.py
│   ├── discriminator.py
│   ├── audio_reader.py
│   ├── ops.py
│   ├── gan.py
│   └── hrnn.py
├── data
│   ├── val
│   │   ├── nb
│   │   │   ├── p225_355_nb.wav
│   │   │   ├── p225_356_nb.wav
│   │   │   └── p225_357_nb.wav
│   │   └── wb
│   │       ├── p225_355_wb.wav
│   │       ├── p225_356_wb.wav
│   │       └── p225_357_wb.wav
│   └── train
│       ├── nb
│       │   ├── p225_001_nb.wav
│       │   ├── p225_002_nb.wav
│       │   └── p225_003_nb.wav
│       └── wb
│           ├── p225_001_wb.wav
│           ├── p225_002_wb.wav
│           └── p225_003_wb.wav
├── evaluate.py
├── train_hrnn.py
└── train_gan.py
/models/gan.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/models/gan.pyc -------------------------------------------------------------------------------- /models/ops.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/models/ops.pyc -------------------------------------------------------------------------------- /models/hrnn.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/models/hrnn.pyc -------------------------------------------------------------------------------- /models/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/models/__init__.pyc -------------------------------------------------------------------------------- /models/audio_reader.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/models/audio_reader.pyc -------------------------------------------------------------------------------- /models/discriminator.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/models/discriminator.pyc -------------------------------------------------------------------------------- /data/val/nb/p225_355_nb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/data/val/nb/p225_355_nb.wav -------------------------------------------------------------------------------- /data/val/nb/p225_356_nb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/data/val/nb/p225_356_nb.wav -------------------------------------------------------------------------------- /data/val/nb/p225_357_nb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/data/val/nb/p225_357_nb.wav -------------------------------------------------------------------------------- /data/val/wb/p225_355_wb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/data/val/wb/p225_355_wb.wav -------------------------------------------------------------------------------- /data/val/wb/p225_356_wb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/data/val/wb/p225_356_wb.wav --------------------------------------------------------------------------------
/data/val/wb/p225_357_wb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/data/val/wb/p225_357_wb.wav -------------------------------------------------------------------------------- /data/train/nb/p225_001_nb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/data/train/nb/p225_001_nb.wav -------------------------------------------------------------------------------- /data/train/nb/p225_002_nb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/data/train/nb/p225_002_nb.wav -------------------------------------------------------------------------------- /data/train/nb/p225_003_nb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/data/train/nb/p225_003_nb.wav -------------------------------------------------------------------------------- /data/train/wb/p225_001_wb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/data/train/wb/p225_001_wb.wav -------------------------------------------------------------------------------- /data/train/wb/p225_002_wb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/data/train/wb/p225_002_wb.wav -------------------------------------------------------------------------------- /data/train/wb/p225_003_wb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berthyf96/audio_sr/HEAD/data/train/wb/p225_003_wb.wav -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ops import * 2 | from .hrnn import HRNN 3 | from .gan import HRNN_GAN 4 | from .discriminator import Discriminator 5 | from .audio_reader import AudioReader 6 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import tensorflow as tf 3 | from tensorflow.contrib.layers import batch_norm, fully_connected, flatten 4 | from tensorflow.contrib.layers import xavier_initializer 5 | from .ops import downconv, leakyrelu, prelu, conv1d 6 | import numpy as np 7 | 8 | class Discriminator(object): 9 | 10 | def __init__(self, bias_D_conv, name): 11 | self.bias_D_conv = bias_D_conv 12 | self.d_num_fmaps = \ 13 | [16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024] 14 | self.name = name 15 | 16 | def logits_Discriminator(self, d_input, reuse=False): 17 | in_dims = d_input.get_shape().as_list() 18 | hi = d_input 19 | if len(in_dims) == 2: 20 | hi = tf.expand_dims(d_input, -1) 21 | elif len(in_dims) < 2 or len(in_dims) > 3: 22 | raise ValueError('Discriminator input must be 2-D or 3-D') 23 | 24 | with tf.variable_scope('Discriminator'): 25 | with tf.variable_scope(self.name) as scope: 26 | 27 | if reuse: 28 | scope.reuse_variables() 29 | def disc_block(block_idx, input_, kwidth, 30 | nfmaps, bnorm, activation, 31 | pooling=2): 32 | with 
tf.variable_scope('d_block_{}'.format(block_idx)): 33 | bias_init = None 34 | if self.bias_D_conv: 35 | bias_init = tf.constant_initializer(0.) 36 | downconv_init = \ 37 | tf.truncated_normal_initializer(stddev=0.02) 38 | 39 | # downconvolution 40 | hi_a = downconv(input_, nfmaps, kwidth=kwidth, 41 | pool=pooling, init=downconv_init, 42 | bias_init=bias_init) 43 | 44 | # VBN 45 | 46 | # activation 47 | if activation == 'leakyrelu': 48 | hi = leakyrelu(hi_a) 49 | elif activation == 'relu': 50 | hi = tf.nn.relu(hi_a) 51 | else: 52 | raise ValueError('Unrecognized activation {} ' 53 | 'in D'.format(activation)) 54 | return hi 55 | 56 | # [removed] apply input noisy layer to real and fake samples 57 | 58 | for block_idx, fmaps in enumerate(self.d_num_fmaps): 59 | hi = disc_block(block_idx, hi, 31, 60 | self.d_num_fmaps[block_idx], 61 | True, 'leakyrelu') 62 | if not reuse: 63 | print('Discriminator deconved shape: ', hi.get_shape()) 64 | #hi_f = flatten(hi) 65 | d_logit_out = conv1d( 66 | hi, kwidth=1, num_kernels=1, 67 | init=tf.truncated_normal_initializer(stddev=0.02), 68 | name='logits_conv') 69 | d_logit_out = tf.squeeze(d_logit_out) 70 | d_logit_out = tf.expand_dims(d_logit_out, 1) 71 | d_logit_out = fully_connected(d_logit_out, 1, activation_fn=None) 72 | 73 | if not reuse: 74 | print('Discriminator output shape: ', d_logit_out.get_shape()) 75 | print('*****************************') 76 | return d_logit_out -------------------------------------------------------------------------------- /models/audio_reader.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | import os 3 | import random 4 | import re 5 | import threading 6 | import librosa 7 | import sys 8 | import copy 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | def randomize_files(files): 13 | files_idx = [i for i in xrange(len(files))] 14 | random.shuffle(files_idx) 15 | for idx in xrange(len(files)): 16 | yield files[files_idx[idx]] 17 | 18 | def find_files(directory, pattern='*.wav'): 19 | files = [] 20 | for root, dirnames, filenames in os.walk(directory): 21 | for filename in fnmatch.filter(filenames, pattern): 22 | files.append(os.path.join(root, filename)) 23 | return files 24 | 25 | def load_generic_audio(directory, sample_rate): 26 | files = find_files(directory) 27 | print('Files length: {}'.format(len(files))) 28 | randomized_files = randomize_files(files) 29 | for filename in randomized_files: 30 | wb_filename = filename.replace('nb', 'wb') 31 | # Yield both nb_audio and wb_audio given filename 32 | print("Found: {}, {}".format(filename, wb_filename)) 33 | nb_audio, _ = librosa.load(filename, sr=sample_rate, mono=True) 34 | wb_audio, _ = librosa.load(wb_filename, sr=sample_rate, mono=True) 35 | nb_audio = nb_audio.reshape(-1, 1) 36 | wb_audio = wb_audio.reshape(-1, 1) 37 | yield nb_audio, wb_audio, filename 38 | 39 | def trim_silence(audio, threshold): 40 | '''Removes silence at the beginning and end of a sample''' 41 | energy = librosa.feature.rmse(audio) 42 | frames = np.nonzero(energy > threshold) 43 | indices = librosa.core.frames_to_samples(frames)[1] 44 | 45 | return audio[indices[0]:indices[-1]] if indices.size else audio[0:0] 46 | 47 | 48 | class AudioReader(object): 49 | '''Generic background audio reader that preprocesses audio files 50 | and enqueues them into a TensorFlow queue''' 51 | 52 | def __init__(self, 53 | nb_audio_dir, 54 | wb_audio_dir, 55 | coord, 56 | sample_rate, 57 | sample_size=None, 58 | silence_threshold=None, 59 | queue_size=32): 60
| self.nb_audio_dir = nb_audio_dir 61 | self.wb_audio_dir = wb_audio_dir 62 | self.sample_rate = sample_rate 63 | self.coord = coord 64 | self.sample_size = sample_size 65 | self.silence_threshold = silence_threshold 66 | self.threads = [] 67 | self.nb_sample_placeholder = tf.placeholder(dtype=tf.float32, shape=None) 68 | self.wb_sample_placeholder = tf.placeholder(dtype=tf.float32, shape=None) 69 | self.queue = tf.PaddingFIFOQueue(queue_size, 70 | ['float32', 'float32'], 71 | shapes=[(None, 1), (None, 1)]) 72 | self.enqueue = self.queue.enqueue( 73 | [self.nb_sample_placeholder, self.wb_sample_placeholder]) 74 | 75 | nb_files = find_files(nb_audio_dir) 76 | wb_files = find_files(wb_audio_dir) 77 | if not nb_files: 78 | raise ValueError("No audio files found in '{}'".format(nb_audio_dir)) 79 | if not wb_files: 80 | raise ValueError("No audio files found in '{}'".format(wb_audio_dir)) 81 | return 82 | 83 | def dequeue(self, num_elements): 84 | return self.queue.dequeue_many(num_elements) 85 | 86 | def thread_main(self, sess): 87 | stop = False 88 | nb_audio_list = [] 89 | wb_audio_list = [] 90 | 91 | # load_generic_audio takes NB directory and yields nb_audio, wb_audio, 92 | # and nb_filename 93 | iterator = load_generic_audio(self.nb_audio_dir, self.sample_rate) 94 | for nb_audio, wb_audio, _ in iterator: 95 | nb_audio_list.append(nb_audio) 96 | wb_audio_list.append(wb_audio) 97 | print('Compiled audio') 98 | while not stop: 99 | for nb_audio_copy, wb_audio_copy in zip(nb_audio_list, wb_audio_list): 100 | nb_audio = copy.deepcopy(nb_audio_copy) 101 | wb_audio = copy.deepcopy(wb_audio_copy) 102 | if self.coord.should_stop(): 103 | stop = True 104 | break 105 | if self.silence_threshold is not None: 106 | # Remove silence 107 | nb_audio = trim_silence(nb_audio[:, 0], self.silence_threshold) 108 | nb_audio = nb_audio.reshape(-1, 1) 109 | wb_audio = trim_silence(wb_audio[:, 0], self.silence_threshold) 110 | wb_audio = wb_audio.reshape(-1, 1) 111 | #if nb_audio.size == 0 or wb_audio.size == 0: 112 | #print('An audio file was dropped.') 113 | 114 | pad_elements = \ 115 | self.sample_size - 1 \ 116 | - (nb_audio.shape[0] + self.sample_size - 1) \ 117 | % self.sample_size 118 | nb_audio = np.concatenate( 119 | [nb_audio, np.full((pad_elements, 1), 0.0, dtype='float32')], 120 | axis=0) 121 | wb_audio = np.concatenate( 122 | [wb_audio, np.full((pad_elements, 1), 0.0, dtype='float32')], 123 | axis=0) 124 | 125 | if self.sample_size: 126 | # Keep taking chunks of size sample_size 127 | while len(nb_audio) > self.sample_size: 128 | nb_piece = nb_audio[:self.sample_size, :] 129 | wb_piece = wb_audio[:self.sample_size, :] 130 | sess.run(self.enqueue, 131 | feed_dict={self.nb_sample_placeholder: nb_piece, 132 | self.wb_sample_placeholder: wb_piece}) 133 | nb_audio = nb_audio[self.sample_size:, :] 134 | wb_audio = wb_audio[self.sample_size:, :] 135 | else: 136 | sess.run(self.enqueue, 137 | feed_dict={self.nb_sample_placeholder: nb_audio, 138 | self.wb_sample_placeholder: wb_audio}) 139 | 140 | def start_threads(self, sess, n_threads=1): 141 | for _ in range(n_threads): 142 | thread = threading.Thread(target=self.thread_main, args=(sess,)) 143 | thread.daemon = True 144 | thread.start() 145 | self.threads.append(thread) 146 | return self.threads 147 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import librosa 3 | import numpy as np 4 | import 
argparse 5 | from models import HRNN, HRNN_GAN, Discriminator 6 | from models import write_wav, log_mel_spectrograms 7 | import os 8 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 9 | 10 | def get_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--batch_size', type=int, default=1) 13 | parser.add_argument('--big_frame_size', type=int, default=8) 14 | parser.add_argument('--frame_size', type=int, default=2) 15 | parser.add_argument('--q_levels', type=int, default=256) 16 | parser.add_argument('--rnn_type', type=str, default='LSTM') 17 | parser.add_argument('--dim', type=int, default=1024) 18 | parser.add_argument('--n_rnn', type=int, default=1) 19 | parser.add_argument('--seq_len', type=int, default=520) 20 | parser.add_argument('--emb_size', type=int, default=256) 21 | parser.add_argument('--spec_loss_weight', type=float, required=False) 22 | parser.add_argument('--l1_reg_strength', type=float, default=0.0) 23 | parser.add_argument('--sample_rate', type=int, default=16000) 24 | parser.add_argument('--method', type=str, required=True) 25 | parser.add_argument('--step', type=int, default=700) 26 | parser.add_argument('--logdir', type=str, required=True) 27 | parser.add_argument('--inp_file', type=str, required=True) 28 | return parser.parse_args() 29 | 30 | def create_hrnn(args): 31 | net = HRNN(args) 32 | return net 33 | 34 | def create_gan(args): 35 | net = HRNN_GAN(batch_size=args.batch_size, 36 | big_frame_size=args.big_frame_size, 37 | frame_size=args.frame_size, 38 | q_levels=args.q_levels, 39 | rnn_type=args.rnn_type, 40 | dim=args.dim, 41 | n_rnn=args.n_rnn, 42 | seq_len=args.seq_len, 43 | emb_size=args.emb_size) 44 | return net 45 | 46 | def load_step(saver, sess, logdir, step): 47 | print("Trying to restore saved checkpoints from {} ...".format(logdir)) 48 | ckpt = tf.train.get_checkpoint_state(logdir) 49 | if ckpt: 50 | print("Checkpoint found: {}".format(ckpt.model_checkpoint_path)) 51 | global_step = int(ckpt.model_checkpoint_path.split('/')[-1] 52 | .split('-')[-1]) 53 | model_path = '{}/model.ckpt-{}'.format(logdir, step) 54 | saver.restore(sess, model_path) 55 | print("Restored model from global step {}".format(step)) 56 | return global_step 57 | else: 58 | print("No checkpoint found") 59 | return None 60 | return None 61 | 62 | def load_audio(nb_file, wb_file): 63 | nb_audio, _ = librosa.load(nb_file, sr=16000, mono=True) 64 | wb_audio, _ = librosa.load(wb_file, sr=16000, mono=True) 65 | return nb_audio, wb_audio 66 | 67 | def crossfade(s1, s2, overlap): 68 | s1_stop = len(s1) - overlap 69 | res = np.zeros(len(s1) + len(s2) - overlap) 70 | res[:s1_stop] = s1[:s1_stop] 71 | for i in range(overlap): 72 | alpha = float(i) / (overlap - 1) 73 | res[s1_stop + i] = (alpha * s2[i]) + ((1 - alpha) * s1[s1_stop + i]) 74 | res[s1_stop + overlap: ] = s2[overlap: ] 75 | return res 76 | 77 | def l1_loss(pred, target): 78 | return np.mean(np.absolute(pred - target)) 79 | 80 | def log_spectral_distance(pred, target, sess): 81 | pred_spectrogram_tensor = log_mel_spectrograms(pred, 16000) 82 | target_spectrogram_tensor = log_mel_spectrograms(target, 16000) 83 | pred_spectrogram, target_spectrogram = sess.run( 84 | [pred_spectrogram_tensor, target_spectrogram_tensor]) 85 | squared_distance = np.square(pred_spectrogram - target_spectrogram) 86 | return np.mean(squared_distance) 87 | 88 | def evaluate(args): 89 | if args.method == 'baseline' or args.method == 'spec': 90 | args.spec_loss_weight = 0.0 91 | net = create_hrnn(args) 92 | elif args.method == 'gan': 93 | net = 
create_gan(args) 94 | else: 95 | raise ValueError('Please specify a method (baseline, spec, or gan).') 96 | 97 | # input placeholders 98 | nb_input_batch = tf.Variable( 99 | tf.zeros([net.batch_size, net.seq_len, 1]), 100 | trainable=False, 101 | dtype=tf.float32) 102 | wb_input_batch = tf.Variable( 103 | tf.zeros([net.batch_size, net.seq_len, 1]), 104 | trainable=False, 105 | dtype=tf.float32) 106 | 107 | # initial lstm states 108 | train_big_frame_state = net.big_cell.zero_state( 109 | net.batch_size, tf.float32) 110 | train_frame_state = net.cell.zero_state(net.batch_size, tf.float32) 111 | final_big_frame_state_spec = net.big_cell.zero_state( 112 | net.batch_size, tf.float32) 113 | final_frame_state_spec = net.cell.zero_state(net.batch_size, tf.float32) 114 | 115 | # output variables 116 | if args.method == 'baseline' or args.method == 'spec': 117 | loss, prediction, final_big_frame_state, final_frame_state = \ 118 | net.forward( 119 | nb_input_batch, wb_input_batch, train_big_frame_state, 120 | train_frame_state, inference_only=True) 121 | else: 122 | loss, final_big_frame_state, final_frame_state, _, prediction = \ 123 | net.loss_SampleRNN( 124 | nb_input_batch, wb_input_batch, train_big_frame_state, 125 | train_frame_state) 126 | 127 | # configure session 128 | tf_config = tf.ConfigProto(allow_soft_placement=True) 129 | tf_config.gpu_options.allow_growth = True 130 | sess = tf.Session(config=tf_config) 131 | 132 | # load saved model 133 | saver = tf.train.Saver(var_list=tf.trainable_variables()) 134 | logdir = args.logdir 135 | load_step(saver, sess, logdir, args.step) 136 | 137 | test_nb_file = args.inp_file 138 | test_wb_file = test_nb_file.replace('nb', 'wb') 139 | nb_audio, wb_audio = load_audio(test_nb_file, test_wb_file) 140 | result = np.zeros(len(nb_audio) - 8) 141 | nb_audio = nb_audio.reshape(-1, 1) 142 | wb_audio = wb_audio.reshape(-1, 1) 143 | output_list = [loss, prediction, final_big_frame_state, final_frame_state] 144 | sample_size = len(nb_audio) 145 | seq_len = 520 146 | stride = 256 147 | overlap = 256 148 | print('Running model...') 149 | for i in range(0, sample_size, stride): 150 | if (i + seq_len) >= len(nb_audio): break 151 | inp_dict = {} 152 | inp_dict[nb_input_batch] = [nb_audio[i:i + seq_len]] 153 | inp_dict[wb_input_batch] = [wb_audio[i:i + seq_len]] 154 | inp_dict[train_big_frame_state] = sess.run(net.big_initial_state) 155 | inp_dict[train_frame_state] = sess.run(net.initial_state) 156 | test_loss, pred, final_big_frame_s, final_frame_s = sess.run( 157 | output_list, feed_dict=inp_dict) 158 | output = np.asarray(pred).reshape(-1) 159 | if i == 0: 160 | result = output 161 | continue 162 | result = crossfade(result, output, overlap) 163 | 164 | target = np.squeeze(wb_audio)[8: 8 + len(result)] 165 | pred = result 166 | 167 | l1 = l1_loss(pred, target) 168 | lsd = log_spectral_distance(pred, target, sess) 169 | write_wav(result, 16000, args.method + '.wav') 170 | 171 | print('Mean L1 loss = {}'.format(l1)) 172 | print('Mean LSD = {}'.format(lsd)) 173 | 174 | return 175 | 176 | args = get_args() 177 | evaluate(args) -------------------------------------------------------------------------------- /models/ops.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import tensorflow as tf 3 | import os 4 | import sys 5 | import numpy as np 6 | import librosa 7 | import fnmatch 8 | 9 | # ------------------------------- 10 | # MODEL OPS 11 | # ------------------------------- 12 | def 
create_adam_optimizer(learning_rate, momentum): 13 | return tf.train.AdamOptimizer( 14 | learning_rate=learning_rate, epsilon=1e-4) 15 | 16 | def create_sgd_optimizer(learning_rate, momentum): 17 | return tf.train.MomentumOptimizer( 18 | learning_rate=learning_rate, momentum=momentum) 19 | 20 | def create_rmsprop_optimizer(learning_rate, momentum): 21 | return tf.train.RMSPropOptimizer( 22 | learning_rate=learning_rate, momentum=momentum, epsilon=1e-5) 23 | 24 | optimizer_factory = {'adam': create_adam_optimizer, 25 | 'sgd': create_sgd_optimizer, 26 | 'rmsprop': create_rmsprop_optimizer} 27 | 28 | def save(saver, sess, logdir, step): 29 | model_name = 'model.ckpt' 30 | checkpoint_path = os.path.join(logdir, model_name) 31 | print("Storing checkpoint to {} ...".format(logdir)) 32 | sys.stdout.flush() 33 | if not os.path.exists(logdir): 34 | os.makedirs(logdir) 35 | saver.save(sess, checkpoint_path, global_step=step) 36 | return 37 | 38 | def load(saver, sess, logdir): 39 | print("Trying to restore saved checkpoints from {} ...".format(logdir)) 40 | ckpt = tf.train.get_checkpoint_state(logdir) 41 | if ckpt: 42 | print("Checkpoint found: {}".format( 43 | ckpt.model_checkpoint_path)) 44 | global_step = int(ckpt.model_checkpoint_path.split('/')[-1] 45 | .split('-')[-1]) 46 | saver.restore(sess, ckpt.model_checkpoint_path) 47 | print("Restored model from global step {}".format( 48 | global_step)) 49 | return global_step 50 | else: 51 | print("No checkpoint found") 52 | return None 53 | return None 54 | 55 | def scalar_summary(name, x): 56 | try: 57 | summ = tf.summary.scalar(name, x) 58 | except AttributeError: 59 | summ = tf.scalar_summary(name, x) 60 | return summ 61 | 62 | # ------------------------------- 63 | # I/O OPS 64 | # ------------------------------- 65 | def find_files(directory, pattern='*.wav'): 66 | files = [] 67 | for root, dirnames, filenames in os.walk(directory): 68 | for filename in fnmatch.filter(filenames, pattern): 69 | files.append(os.path.join(root, filename)) 70 | return files 71 | 72 | def get_test_batches(files, batch_size, sample_rate): 73 | # Grab batch_size number of audio files 74 | nb_audio_batch = [] 75 | wb_audio_batch = [] 76 | for i in range(batch_size): 77 | nb_filename = np.random.choice(files) 78 | wb_filename = nb_filename.replace('nb', 'wb') 79 | nb_audio, _ = librosa.load( 80 | nb_filename, sr=sample_rate, mono=True) 81 | wb_audio, _ = librosa.load( 82 | wb_filename, sr=sample_rate, mono=True) 83 | nb_audio = nb_audio.reshape(-1, 1) 84 | wb_audio = wb_audio.reshape(-1, 1) 85 | nb_audio_batch.append(nb_audio) 86 | wb_audio_batch.append(wb_audio) 87 | nb_audio_batch = np.asarray(nb_audio_batch) 88 | wb_audio_batch = np.asarray(wb_audio_batch) 89 | nb_audio_batch = nb_audio_batch.reshape(batch_size, -1, 1) 90 | wb_audio_batch = wb_audio_batch.reshape(batch_size, -1, 1) 91 | return nb_audio_batch, wb_audio_batch 92 | 93 | def write_wav(waveform, sample_rate, filename): 94 | y = np.array(waveform) 95 | librosa.output.write_wav(filename, y, sample_rate) 96 | print('Updated wav file at {}'.format(filename)) 97 | return 98 | 99 | # ------------------------------- 100 | # MATH OPS 101 | # ------------------------------- 102 | def log10(x): 103 | numerator = tf.log(x) 104 | denominator = tf.log(tf.constant(10, dtype=numerator.dtype)) 105 | return numerator / denominator 106 | 107 | def average_gradients(tower_grads): 108 | '''Returns: List of pairs of (gradient, variable), where the gradient 109 | has been averaged across all towers''' 110 | average_grads = [] 
111 | for grad_and_vars in zip(*tower_grads): 112 | grads = [] 113 | for g, _ in grad_and_vars: 114 | expanded_g = tf.expand_dims(g, 0) 115 | grads.append(expanded_g) 116 | grad = tf.concat(axis=0, values=grads) 117 | grad = tf.reduce_mean(grad, 0) 118 | 119 | v = grad_and_vars[0][1] 120 | grad_and_var = (grad, v) 121 | average_grads.append(grad_and_var) 122 | return average_grads 123 | 124 | # ------------------------------- 125 | # SIGNAL PROCESSING OPS 126 | # ------------------------------- 127 | def mu_law_encode(audio, quantization_channels): 128 | '''Quantizes waveform amplitudes.''' 129 | with tf.name_scope('encode'): 130 | mu = quantization_channels - 1 131 | # Perform mu-law companding transformation (ITU-T, 1988). 132 | safe_audio_abs = tf.minimum(tf.abs(audio), 1.0) 133 | magnitude = tf.log(1 + mu * safe_audio_abs) / tf.log(1. + mu) 134 | signal = tf.sign(audio) * magnitude 135 | # Quantize signal to the specified number of levels. 136 | return tf.cast((signal + 1.0) / 2 * mu + 0.5, tf.int32) 137 | 138 | def mu_law_decode(output, quantization_channels): 139 | '''Recovers waveform from quantized values.''' 140 | with tf.name_scope('decode'): 141 | mu = quantization_channels - 1 142 | # Map values back to [-1, 1]. 143 | casted = tf.cast(output, tf.float64) 144 | signal = 2 * (casted / mu) - 1 145 | # Perform inverse of mu-law transformation. 146 | magnitude = (1.0 / mu) * ((1 + mu)**abs(signal) - 1) 147 | return tf.sign(signal) * magnitude 148 | 149 | def log_mel_spectrograms(signals, sample_rate): 150 | fft_length = 512 151 | signals = tf.cast(signals, dtype=tf.float32) 152 | stfts = tf.contrib.signal.stft( 153 | signals, frame_length=512, frame_step=64, 154 | fft_length=fft_length) 155 | magnitude_spectrograms = tf.abs(stfts) 156 | # Warp the linear-scale, magnitude spectrograms into mel-scale. 
157 | num_spectrogram_bins = fft_length // 2 + 1 158 | lower_edge_hertz, upper_edge_hertz, num_mel_bins = \ 159 | 80.0, 7600.0, 64 160 | linear_to_mel_weight_matrix = \ 161 | tf.contrib.signal.linear_to_mel_weight_matrix( 162 | num_mel_bins, num_spectrogram_bins, sample_rate, 163 | lower_edge_hertz, upper_edge_hertz) 164 | mel_spectrograms = tf.tensordot( 165 | magnitude_spectrograms, linear_to_mel_weight_matrix, 1) 166 | mel_spectrograms.set_shape( 167 | magnitude_spectrograms.shape[:-1].concatenate( 168 | linear_to_mel_weight_matrix.shape[-1:])) 169 | return log10(mel_spectrograms) 170 | 171 | # ------------------------------- 172 | # NEURAL NET OPS 173 | # ------------------------------- 174 | def downconv(x, output_dim, kwidth=5, pool=2, init=None, uniform=False, 175 | bias_init=None, name='downconv'): 176 | """ Downsampled convolution 1d """ 177 | x2d = tf.expand_dims(x, 2) 178 | w_init = init 179 | if w_init is None: 180 | w_init = xavier_initializer(uniform=uniform) 181 | with tf.variable_scope(name): 182 | W = tf.get_variable( 183 | 'W', [kwidth, 1, x.get_shape()[-1], output_dim], 184 | initializer=w_init) 185 | conv = tf.nn.conv2d( 186 | x2d, W, strides=[1, pool, 1, 1], padding='SAME') 187 | if bias_init is not None: 188 | b = tf.get_variable( 189 | 'b', [output_dim], initializer=bias_init) 190 | conv = tf.reshape( 191 | tf.nn.bias_add(conv, b), conv.get_shape()) 192 | else: 193 | conv = tf.reshape(conv, conv.get_shape()) 194 | conv = tf.reshape(conv, conv.get_shape().as_list()[:2] + 195 | [conv.get_shape().as_list()[-1]]) 196 | return conv 197 | 198 | def leakyrelu(x, alpha=0.3, name='lrelu'): 199 | return tf.maximum(x, alpha * x, name=name) 200 | 201 | def prelu(x, name='prelu', ref=False): 202 | in_shape = x.get_shape().as_list() 203 | with tf.variable_scope(name): 204 | # make one alpha per feature 205 | alpha = tf.get_variable( 206 | 'alpha', in_shape[-1], 207 | initializer=tf.constant_initializer(0.), 208 | dtype=tf.float32) 209 | pos = tf.nn.relu(x) 210 | neg = alpha * (x - tf.abs(x)) * .5 211 | if ref: 212 | # return ref to alpha vector 213 | return pos + neg, alpha 214 | else: 215 | return pos + neg 216 | 217 | def conv1d(x, kwidth=5, num_kernels=1, 218 | init=None, uniform=False, bias_init=None, 219 | name='conv1d', padding='SAME'): 220 | input_shape = x.get_shape() 221 | in_channels = input_shape[-1] 222 | assert len(input_shape) >= 3 223 | w_init = init 224 | if w_init is None: 225 | w_init = xavier_initializer(uniform=uniform) 226 | with tf.variable_scope(name): 227 | # filter shape: [kwidth, in_channels, num_kernels] 228 | W = tf.get_variable('W', [kwidth, in_channels, num_kernels], 229 | initializer=w_init 230 | ) 231 | conv = tf.nn.conv1d(x, W, stride=1, padding=padding) 232 | if bias_init is not None: 233 | b = tf.get_variable( 234 | 'b', [num_kernels], 235 | initializer=tf.constant_initializer(bias_init)) 236 | conv = conv + b 237 | return conv 238 | 239 | -------------------------------------------------------------------------------- /train_hrnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import tensorflow as tf 3 | from tensorflow import AggregationMethod as aggreg 4 | from tensorflow.python.client import timeline 5 | from models import HRNN, AudioReader 6 | from models import optimizer_factory, find_files, get_test_batches, average_gradients, load, save 7 | import argparse 8 | import os 9 | import time 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser() 13 | 
parser.add_argument('--num_gpus', type=int, default=1) 14 | parser.add_argument('--batch_size', type=int, required=True) 15 | parser.add_argument('--nb_data_dir', type=str, required=True) 16 | parser.add_argument('--wb_data_dir', type=str, required=True) 17 | parser.add_argument('--test_nb_data_dir', type=str, required=True) 18 | parser.add_argument('--test_wb_data_dir', type=str, required=True) 19 | parser.add_argument('--logdir', type=str, required=True) 20 | parser.add_argument('--ckpt_every', type=int, default=20) 21 | parser.add_argument('--num_steps', type=int, default=10000) 22 | parser.add_argument('--learning_rate', type=float, required=True) 23 | parser.add_argument('--sample_size', type=int, default=48000) 24 | parser.add_argument('--sample_rate', type=int, default=16000) 25 | parser.add_argument('--l1_reg_strength', type=float, default=0.0) 26 | parser.add_argument('--silence_threshold', type=float, default=0.0) 27 | parser.add_argument('--optimizer', type=str, default='adam', 28 | choices=optimizer_factory.keys()) 29 | parser.add_argument('--momentum', type=float, default=0.9) 30 | parser.add_argument('--seq_len', type=int, default=520) 31 | parser.add_argument('--big_frame_size', type=int, default=8) 32 | parser.add_argument('--frame_size', type=int, default=2) 33 | parser.add_argument('--q_levels', type=int, default=256) 34 | parser.add_argument('--dim', type=int, required=True) 35 | parser.add_argument('--n_rnn', type=int, choices=xrange(1,6), default=1) 36 | parser.add_argument('--emb_size', type=int, default=256) 37 | parser.add_argument('--rnn_type', choices=['LSTM', 'GRU'], default='LSTM') 38 | parser.add_argument('--max_checkpoints', type=int, default=10) 39 | parser.add_argument('--spec_loss_weight', type=float, required=True) 40 | return parser.parse_args() 41 | 42 | def create_model(args): 43 | '''Set up model, global step, and optimizer''' 44 | model = HRNN(args) 45 | global_step = tf.get_variable( 46 | 'global_step', [], initializer=tf.constant_initializer(0), 47 | trainable=False) 48 | optim = optimizer_factory[args.optimizer]( 49 | learning_rate=args.learning_rate, momentum=args.momentum) 50 | return model, global_step, optim 51 | 52 | def train(): 53 | # ----------------------------------------- 54 | # SETUP 55 | # ----------------------------------------- 56 | args = get_args() 57 | seq_len = args.seq_len 58 | if args.l1_reg_strength == 0.0: 59 | args.l1_reg_strength = None 60 | logdir = os.path.join(args.logdir, 'train') 61 | logdir_test = os.path.join(args.logdir, 'test') 62 | coord = tf.train.Coordinator() 63 | # get test files 64 | test_files = find_files(args.test_nb_data_dir) 65 | # create inputs 66 | with tf.name_scope('create_inputs'): 67 | reader = AudioReader(args.nb_data_dir, 68 | args.wb_data_dir, 69 | coord, 70 | sample_rate=args.sample_rate, 71 | sample_size=args.sample_size, 72 | silence_threshold=args.silence_threshold) 73 | nb_audio_batch, wb_audio_batch = reader.dequeue(args.batch_size) 74 | # create model 75 | net, global_step, optim = create_model(args) 76 | 77 | # set up placeholders and variables on each GPU 78 | nb_input_batch = [] 79 | wb_input_batch = [] 80 | train_big_frame_state = [] 81 | train_frame_state = [] 82 | final_big_frame_state = [] 83 | final_frame_state = [] 84 | losses = [] 85 | tower_grads = [] 86 | for i in range(args.num_gpus): 87 | with tf.device('/gpu:%d' % i): 88 | # create input placeholders 89 | nb_input_batch.append(tf.Variable( 90 | tf.zeros([net.batch_size, seq_len, 1]), 91 | trainable=False, 92 | 
name='nb_input_batch_rnn', 93 | dtype=tf.float32)) 94 | wb_input_batch.append(tf.Variable( 95 | tf.zeros([net.batch_size, seq_len, 1]), 96 | trainable=False, 97 | name='wb_input_batch_rnn', 98 | dtype=tf.float32)) 99 | # create initial states 100 | train_big_frame_state.append( 101 | net.big_cell.zero_state(net.batch_size, tf.float32)) 102 | final_big_frame_state.append( 103 | net.big_cell.zero_state(net.batch_size, tf.float32)) 104 | train_frame_state.append( 105 | net.cell.zero_state(net.batch_size, tf.float32)) 106 | final_frame_state.append( 107 | net.cell.zero_state(net.batch_size, tf.float32)) 108 | 109 | # network output variables 110 | with tf.variable_scope(tf.get_variable_scope()): 111 | for i in range(args.num_gpus): 112 | with tf.device('/gpu:%d' % i): 113 | with tf.name_scope('TOWER_%d' % i) as scope: 114 | # create variables 115 | print('Creating model on GPU:%d' % i) 116 | loss, final_big_frame_state[i], final_frame_state[i] = \ 117 | net.forward(nb_input_batch[i], 118 | wb_input_batch[i], 119 | train_big_frame_state[i], 120 | train_frame_state[i], 121 | l1_reg_strength=args.l1_reg_strength) 122 | tf.get_variable_scope().reuse_variables() 123 | losses.append(loss) 124 | # reuse variables for the next tower 125 | trainable = tf.trainable_variables() 126 | gradients = optim.compute_gradients( 127 | loss, trainable, 128 | aggregation_method=aggreg.EXPERIMENTAL_ACCUMULATE_N) 129 | tower_grads.append(gradients) 130 | 131 | # backpropagation 132 | grad_vars = average_gradients(tower_grads) 133 | grads, vars = zip(*grad_vars) 134 | grads_clipped, _ = tf.clip_by_global_norm(grads, 5.0) 135 | grad_vars = zip(grads_clipped, vars) 136 | apply_gradient_op = optim.apply_gradients( 137 | grad_vars, global_step=global_step) 138 | 139 | # configure session 140 | writer = tf.summary.FileWriter(logdir) 141 | test_writer = tf.summary.FileWriter(logdir_test) 142 | writer.add_graph(tf.get_default_graph()) 143 | test_writer.add_graph(tf.get_default_graph()) 144 | summaries = tf.summary.merge_all() 145 | tf_config = tf.ConfigProto(allow_soft_placement=True) 146 | tf_config.gpu_options.allow_growth = True 147 | sess = tf.Session(config=tf_config) 148 | init = tf.global_variables_initializer() 149 | sess.run(init) 150 | 151 | # load checkpoint 152 | saver = tf.train.Saver( 153 | var_list=tf.trainable_variables(), 154 | max_to_keep=args.max_checkpoints) 155 | try: 156 | saved_global_step = load(saver, sess, logdir) 157 | if saved_global_step is None: saved_global_step = -1 158 | except: 159 | raise ValueError('Something went wrong while restoring checkpoint') 160 | 161 | # start queue runners 162 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 163 | reader.start_threads(sess) 164 | 165 | # ----------------------------------------- 166 | # TRAIN + VAL 167 | # ----------------------------------------- 168 | print('Starting training...') 169 | step = None 170 | last_saved_step = saved_global_step 171 | for step in range(saved_global_step + 1, args.num_steps + 1): 172 | # initialize cells 173 | final_big_s = [] 174 | final_s = [] 175 | for g in range(args.num_gpus): 176 | final_big_s.append(sess.run(net.big_initial_state)) 177 | final_s.append(sess.run(net.initial_state)) 178 | start_time = time.time() 179 | 180 | # get input batches 181 | nb_inputs_list = [] 182 | wb_inputs_list = [] 183 | for _ in range(args.num_gpus): 184 | nb_inputs, wb_inputs = sess.run( 185 | [nb_audio_batch, wb_audio_batch]) 186 | nb_inputs_list.append(nb_inputs) 187 | wb_inputs_list.append(wb_inputs) 188 | 189 | 
# run BPTT 190 | audio_length = args.sample_size - args.big_frame_size 191 | bptt_length = seq_len - args.big_frame_size 192 | stateful_rnn_length = audio_length / bptt_length 193 | loss_sum = 0 194 | idx_begin = 0 195 | output_list = [summaries, 196 | losses, 197 | apply_gradient_op, 198 | final_big_frame_state, 199 | final_frame_state] 200 | for i in range(stateful_rnn_length): 201 | inp_dict = {} 202 | for g in range(args.num_gpus): 203 | # add seq_len samples as input for truncated BPTT 204 | inp_dict[nb_input_batch[g]] = \ 205 | nb_inputs_list[g][:, idx_begin:idx_begin+seq_len, :] 206 | inp_dict[wb_input_batch[g]] = \ 207 | wb_inputs_list[g][:, idx_begin:idx_begin+seq_len, :] 208 | inp_dict[train_big_frame_state[g]] = final_big_s[g] 209 | inp_dict[train_frame_state[g]] = final_s[g] 210 | idx_begin += seq_len - args.big_frame_size 211 | 212 | # forward pass 213 | summary, loss_gpus, _, final_big_s, final_s = \ 214 | sess.run(output_list, feed_dict=inp_dict) 215 | writer.add_summary(summary, step) 216 | for g in range(args.num_gpus): 217 | loss_gpu = loss_gpus[g] / stateful_rnn_length 218 | loss_sum += loss_gpu / args.num_gpus 219 | duration = time.time() - start_time 220 | print('Step {:d}: loss = {:.3f}, ({:.3f} sec/step)'.format( 221 | step, loss_sum, duration)) 222 | 223 | if step % args.ckpt_every == 0: 224 | save(saver, sess, logdir, step) 225 | last_saved_step = step 226 | 227 | # validation 228 | if step % 20 == 0: 229 | print('Testing...') 230 | test_nb_inputs, test_wb_inputs = get_test_batches( 231 | test_files, args.batch_size, args.sample_rate) 232 | test_output_list = [summaries, 233 | losses, 234 | final_big_frame_state, 235 | final_frame_state] 236 | loss_sum = 0 237 | idx_begin = 0 238 | audio_length = args.sample_size - args.big_frame_size 239 | bptt_length = seq_len - args.big_frame_size 240 | stateful_rnn_length = audio_length / bptt_length 241 | for i in range(stateful_rnn_length): 242 | inp_dict = {} 243 | for g in range(args.num_gpus): 244 | inp_dict[nb_input_batch[g]] = \ 245 | test_nb_inputs[:, idx_begin:idx_begin+seq_len, :] 246 | inp_dict[wb_input_batch[g]] = \ 247 | test_wb_inputs[:, idx_begin:idx_begin+seq_len, :] 248 | inp_dict[train_big_frame_state[g]] = \ 249 | sess.run(net.big_initial_state) 250 | inp_dict[train_frame_state[g]] = \ 251 | sess.run(net.initial_state) 252 | idx_begin += seq_len - args.big_frame_size 253 | # forward pass 254 | summary, test_loss, final_big_s, final_s = \ 255 | sess.run(test_output_list, feed_dict=inp_dict) 256 | test_writer.add_summary(summary, step) 257 | for g in range(args.num_gpus): 258 | loss_gpu = test_loss[g] / stateful_rnn_length 259 | loss_sum += loss_gpu / args.num_gpus 260 | print('Step {:d}: val loss = {:.3f}'.format(step, loss_sum)) 261 | 262 | # done training 263 | coord.request_stop() 264 | coord.join(threads) 265 | return 266 | 267 | train() -------------------------------------------------------------------------------- /models/gan.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops import math_ops 3 | from tensorflow.python.ops import embedding_ops 4 | from .ops import mu_law_encode, mu_law_decode 5 | 6 | class HRNN_GAN(object): 7 | def __init__(self, 8 | batch_size, 9 | big_frame_size, 10 | frame_size, 11 | q_levels, 12 | rnn_type, 13 | dim, 14 | n_rnn, 15 | seq_len, 16 | emb_size): 17 | self.batch_size = batch_size 18 | self.big_frame_size = big_frame_size 19 | self.frame_size = frame_size 20 | self.q_levels =
q_levels 21 | self.rnn_type = rnn_type 22 | self.dim = dim 23 | self.n_rnn = n_rnn 24 | self.seq_len = seq_len 25 | self.emb_size = emb_size 26 | 27 | # Configure cell type 28 | def single_cell(): 29 | return tf.contrib.rnn.GRUCell(self.dim) 30 | if self.rnn_type == 'LSTM': 31 | def single_cell(): 32 | return tf.contrib.rnn.BasicLSTMCell(self.dim) 33 | self.cell = single_cell() 34 | self.big_cell = single_cell() 35 | 36 | # Configure size of stacked RNN 37 | if self.n_rnn > 1: 38 | self.cell = tf.contrib.rnn.MultiRNNCell( 39 | [single_cell() for _ in range(self.n_rnn)]) 40 | self.big_cell = tf.contrib.rnn.MultiRNNCell( 41 | [single_cell() for _ in range(self.n_rnn)]) 42 | 43 | # Initial states 44 | self.initial_state = self.cell.zero_state( 45 | self.batch_size, tf.float32) 46 | self.big_initial_state = self.big_cell.zero_state( 47 | self.batch_size, tf.float32) 48 | 49 | def _create_network_BigFrame(self, 50 | num_steps, 51 | big_frame_state, 52 | big_input_sequences): 53 | with tf.variable_scope('SampleRNN'): 54 | with tf.variable_scope('big_frame'): 55 | big_input_frames_shape = [tf.shape(big_input_sequences)[0], 56 | tf.shape(big_input_sequences)[1] / self.big_frame_size, 57 | self.big_frame_size] 58 | big_input_frames = tf.reshape(big_input_sequences, big_input_frames_shape) 59 | big_input_frames = (big_input_frames / self.q_levels/2.0) - 1.0 60 | big_input_frames *= 2.0 61 | 62 | big_frame_outputs = [] 63 | 64 | # Create weights variable 65 | big_frame_proj_weights = tf.get_variable( 66 | 'big_frame_proj_weights', 67 | [self.dim, self.dim * self.big_frame_size/self.frame_size], 68 | dtype=tf.float32) 69 | 70 | with tf.variable_scope('big_frame_rnn'): 71 | for time_step in range(num_steps): 72 | if time_step > 0: tf.get_variable_scope().reuse_variables() 73 | # Get output and state at this time step 74 | (big_frame_output, big_frame_state) = self.big_cell( 75 | big_input_frames[:, time_step, :], 76 | big_frame_state) 77 | big_frame_outputs.append( 78 | math_ops.matmul(big_frame_output, 79 | big_frame_proj_weights)) 80 | final_big_frame_state = big_frame_state 81 | big_frame_outputs = tf.stack(big_frame_outputs) 82 | big_frame_outputs = tf.transpose(big_frame_outputs, 83 | perm=[1,0,2]) 84 | big_frame_outputs_shape = [tf.shape(big_frame_outputs)[0], 85 | tf.shape(big_frame_outputs)[1] * self.big_frame_size/self.frame_size, -1] 86 | big_frame_outputs = tf.reshape(big_frame_outputs, 87 | big_frame_outputs_shape) 88 | return big_frame_outputs, final_big_frame_state 89 | 90 | def _create_network_Frame(self, 91 | num_steps, 92 | big_frame_outputs, 93 | frame_state, 94 | input_sequences): 95 | with tf.variable_scope('SampleRNN'): 96 | with tf.variable_scope('frame'): 97 | input_frames_shape = [tf.shape(input_sequences)[0], 98 | tf.shape(input_sequences)[1] / self.frame_size, 99 | self.frame_size] 100 | input_frames = tf.reshape(input_sequences, input_frames_shape) 101 | input_frames = (input_frames / self.q_levels/2.0) - 1.0 102 | input_frames *= 2.0 103 | 104 | frame_outputs = [] 105 | 106 | # Create weights variables 107 | frame_proj_weights = tf.get_variable( 108 | 'frame_proj_weights', 109 | [self.dim, self.dim * self.frame_size], 110 | dtype=tf.float32) 111 | frame_cell_proj_weights = tf.get_variable( 112 | 'frame_cell_proj_weights', 113 | [self.frame_size, self.dim], 114 | dtype=tf.float32) 115 | 116 | with tf.variable_scope('frame_rnn'): 117 | for time_step in range(num_steps): 118 | if time_step > 0: tf.get_variable_scope().reuse_variables() 119 | # Get input 120 | cell_input = 
tf.reshape(input_frames[:, time_step, :], 121 | [-1, self.frame_size]) 122 | cell_input = math_ops.matmul(cell_input, 123 | frame_cell_proj_weights) 124 | # Add big frame output to input 125 | cell_input = cell_input + tf.reshape( 126 | big_frame_outputs[:, time_step, :], 127 | [-1, self.dim]) 128 | # Get output, state 129 | (frame_cell_output, frame_state) = self.cell(cell_input, 130 | frame_state) 131 | frame_outputs.append(math_ops.matmul(frame_cell_output, 132 | frame_proj_weights)) 133 | final_frame_state = frame_state 134 | frame_outputs = tf.stack(frame_outputs) 135 | frame_outputs = tf.transpose(frame_outputs, perm=[1,0,2]) 136 | frame_outputs_shape = [tf.shape(frame_outputs)[0], 137 | tf.shape(frame_outputs)[1] * self.frame_size, 138 | -1] 139 | frame_outputs = tf.reshape(frame_outputs, frame_outputs_shape) 140 | return frame_outputs, final_frame_state 141 | 142 | def _create_network_Sample(self, 143 | frame_outputs, 144 | sample_input_sequences): 145 | with tf.variable_scope('SampleRNN'): 146 | with tf.variable_scope('sample'): 147 | sample_shape = [tf.shape(sample_input_sequences)[0], 148 | tf.shape(sample_input_sequences)[1] * self.emb_size, 149 | 1] 150 | embedding = tf.get_variable('embedding', [self.q_levels, self.emb_size]) 151 | sample_input_sequences = embedding_ops.embedding_lookup( 152 | embedding, tf.reshape(sample_input_sequences, [-1])) 153 | sample_input_sequences = tf.reshape( 154 | sample_input_sequences, sample_shape) 155 | 156 | # Create a convolution filter variable 157 | filter_initializer = tf.contrib.layers.xavier_initializer_conv2d() 158 | sample_filter_shape = [self.emb_size*2, 1, self.dim] 159 | sample_filter = tf.get_variable( 160 | 'sample_filter', sample_filter_shape, 161 | initializer=filter_initializer) 162 | 163 | # Apply convolution to samples and add frame-level outputs 164 | out = tf.nn.conv1d(sample_input_sequences, 165 | sample_filter, 166 | stride=self.emb_size, 167 | padding='VALID', 168 | name='sample_conv') 169 | out = out + frame_outputs 170 | 171 | # Create weight variables 172 | sample_mlp1_weights = tf.get_variable( 173 | 'sample_mlp1', [self.dim, self.dim], dtype=tf.float32) 174 | sample_mlp2_weights = tf.get_variable( 175 | 'sample_mlp2', [self.dim, self.dim], dtype=tf.float32) 176 | sample_mlp3_weights = tf.get_variable( 177 | 'sample_mlp3', [self.dim, 1], dtype=tf.float32) 178 | 179 | # Get output 180 | out = tf.reshape(out, [-1, self.dim]) 181 | out = math_ops.matmul(out, sample_mlp1_weights) 182 | out = tf.nn.relu(out) 183 | out = math_ops.matmul(out, sample_mlp2_weights) 184 | out = tf.nn.relu(out) 185 | out = math_ops.matmul(out, sample_mlp3_weights) 186 | out = tf.reshape( 187 | out, 188 | [-1, sample_shape[1]/self.emb_size - 1, 1]) 189 | out = tf.multiply(tf.sigmoid(out), (self.q_levels - 1)) 190 | return out 191 | 192 | def _create_network_SampleRNN(self, 193 | train_big_frame_state, 194 | train_frame_state): 195 | with tf.name_scope('SampleRNN'): 196 | # Big frame 197 | big_input_sequences = self.encoded_nb_input_rnn[:, :-self.big_frame_size, :] 198 | big_input_sequences = tf.cast(big_input_sequences, tf.float32) 199 | big_frame_num_steps = \ 200 | (self.seq_len - self.big_frame_size) / self.big_frame_size 201 | # Run big-frame network 202 | big_frame_outputs, final_big_frame_state = self._create_network_BigFrame( 203 | num_steps = big_frame_num_steps, 204 | big_frame_state = train_big_frame_state, 205 | big_input_sequences = big_input_sequences) 206 | 207 | # Frame 208 | input_sequences = self.encoded_nb_input_rnn[:, 209 | 
self.big_frame_size-self.frame_size:-self.frame_size, :] 210 | input_sequences = tf.cast(input_sequences, tf.float32) 211 | frame_num_steps = (self.seq_len - self.big_frame_size) / self.frame_size 212 | # Run frame network 213 | frame_outputs, final_frame_state = self._create_network_Frame( 214 | num_steps = frame_num_steps, 215 | big_frame_outputs = big_frame_outputs, 216 | frame_state = train_frame_state, 217 | input_sequences = input_sequences) 218 | 219 | # Sample 220 | sample_input_sequences = self.encoded_nb_input_rnn[:, 221 | self.big_frame_size-self.frame_size:-1, :] 222 | sample_output = self._create_network_Sample( 223 | frame_outputs, 224 | sample_input_sequences = sample_input_sequences) 225 | return sample_output, final_big_frame_state, final_frame_state 226 | 227 | def loss_SampleRNN(self, 228 | nb_input_batch_rnn, 229 | wb_input_batch_rnn, 230 | train_big_frame_state, 231 | train_frame_state, 232 | l2_reg_strength=None, 233 | name='sample'): 234 | with tf.name_scope(name): 235 | self.encoded_nb_input_rnn = mu_law_encode( 236 | nb_input_batch_rnn, self.q_levels) 237 | self.encoded_wb_input_rnn = mu_law_encode( 238 | wb_input_batch_rnn, self.q_levels) 239 | wb_encoded_rnn = self.encoded_wb_input_rnn 240 | raw_output, final_big_frame_state, final_frame_state = \ 241 | self._create_network_SampleRNN( 242 | train_big_frame_state, train_frame_state) 243 | with tf.name_scope('loss'): 244 | target_output_rnn = wb_encoded_rnn[:, self.big_frame_size:, :] 245 | target_output_rnn = tf.reshape( 246 | target_output_rnn, 247 | [self.batch_size, -1, 1]) 248 | target_output_rnn = tf.cast(target_output_rnn, dtype=tf.float32) 249 | prediction = mu_law_decode(raw_output, self.q_levels) 250 | # L1 loss 251 | loss = tf.losses.absolute_difference( 252 | target_output_rnn, raw_output) 253 | reduced_loss = tf.reduce_mean(loss) 254 | tf.summary.scalar('loss', reduced_loss) 255 | if l2_reg_strength is None: 256 | return reduced_loss, final_big_frame_state, final_frame_state, target_output_rnn, prediction 257 | else: 258 | l2_loss = tf.add_n([tf.nn.l2_loss(v) 259 | for v in tf.trainable_variables() 260 | if not('bias' in v.name)]) 261 | total_loss = reduced_loss + l2_reg_strength*l2_loss 262 | tf.summary.scalar('l2_loss', l2_loss) 263 | tf.summary.scalar('total_loss', total_loss) 264 | 265 | return total_loss, final_big_frame_state, final_frame_state, target_output_rnn, prediction 266 | -------------------------------------------------------------------------------- /models/hrnn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops import math_ops 3 | from tensorflow.python.ops import embedding_ops 4 | from .ops import mu_law_encode, mu_law_decode, log_mel_spectrograms 5 | 6 | class HRNN(object): 7 | def __init__(self, args): 8 | self.batch_size = args.batch_size 9 | self.big_frame_size = args.big_frame_size 10 | self.frame_size = args.frame_size 11 | self.q_levels = args.q_levels 12 | self.rnn_type = args.rnn_type 13 | self.dim = args.dim 14 | self.n_rnn = args.n_rnn 15 | self.seq_len = args.seq_len 16 | self.emb_size = args.emb_size 17 | self.spec_loss_weight = args.spec_loss_weight 18 | self.l1_reg_strength = args.l1_reg_strength 19 | self.sample_rate = args.sample_rate 20 | 21 | # configure cells 22 | def single_cell(): 23 | return tf.contrib.rnn.GRUCell(self.dim) 24 | if self.rnn_type == 'LSTM': 25 | def single_cell(): 26 | return tf.contrib.rnn.BasicLSTMCell(self.dim) 27 | self.cell = single_cell() 28 | 
self.big_cell = single_cell() 29 | if self.n_rnn > 1: 30 | self.cell = tf.contrib.rnn.MultiRNNCell( 31 | [single_cell() for _ in range(self.n_rnn)]) 32 | self.big_cell = tf.contrib.rnn.MultiRNNCell( 33 | [single_cell() for _ in range(self.n_rnn)]) 34 | self.initial_state = self.cell.zero_state( 35 | self.batch_size, tf.float32) 36 | self.big_initial_state = self.big_cell.zero_state( 37 | self.batch_size, tf.float32) 38 | 39 | # l1 regularizer 40 | if args.l1_reg_strength is not None: 41 | self.l1_regularizer = tf.contrib.layers.l1_regularizer( 42 | scale=args.l1_reg_strength) 43 | 44 | return 45 | 46 | def _create_network_BigFrame(self, num_steps, big_frame_state, 47 | big_input_sequences): 48 | bfs = self.big_frame_size 49 | fs = self.frame_size 50 | with tf.variable_scope('SampleRNN'): 51 | with tf.variable_scope('big_frame'): 52 | big_input_frames_shape = \ 53 | [tf.shape(big_input_sequences)[0], 54 | tf.shape(big_input_sequences)[1] / bfs, 55 | bfs] 56 | big_input_frames = tf.reshape( 57 | big_input_sequences, big_input_frames_shape) 58 | big_input_frames = \ 59 | (big_input_frames / self.q_levels/2.0) - 1.0 60 | big_input_frames *= 2.0 61 | big_frame_outputs = [] 62 | 63 | # weights 64 | big_frame_proj_weights = tf.get_variable( 65 | 'big_frame_proj_weights', 66 | [self.dim, self.dim * bfs/fs], 67 | dtype=tf.float32) 68 | 69 | with tf.variable_scope('big_frame_rnn'): 70 | for time_step in range(num_steps): 71 | if time_step > 0: 72 | tf.get_variable_scope().reuse_variables() 73 | # get output and state at this time step 74 | (big_frame_output, big_frame_state) = self.big_cell( 75 | big_input_frames[:, time_step, :], 76 | big_frame_state) 77 | big_frame_outputs.append( 78 | math_ops.matmul( 79 | big_frame_output, 80 | big_frame_proj_weights)) 81 | final_big_frame_state = big_frame_state 82 | big_frame_outputs = tf.stack(big_frame_outputs) 83 | big_frame_outputs = tf.transpose( 84 | big_frame_outputs, perm=[1,0,2]) 85 | big_frame_outputs_shape = \ 86 | [tf.shape(big_frame_outputs)[0], 87 | tf.shape(big_frame_outputs)[1] * bfs/fs, 88 | -1] 89 | big_frame_outputs = tf.reshape( 90 | big_frame_outputs, big_frame_outputs_shape) 91 | return big_frame_outputs, final_big_frame_state 92 | 93 | def _create_network_Frame(self, num_steps, big_frame_outputs, 94 | frame_state, input_sequences): 95 | fs = self.frame_size 96 | with tf.variable_scope('SampleRNN'): 97 | with tf.variable_scope('frame'): 98 | input_frames_shape = \ 99 | [tf.shape(input_sequences)[0], 100 | tf.shape(input_sequences)[1] / fs, 101 | fs] 102 | input_frames = tf.reshape( 103 | input_sequences, input_frames_shape) 104 | input_frames = (input_frames / self.q_levels/2.0) - 1.0 105 | input_frames *= 2.0 106 | frame_outputs = [] 107 | 108 | # weights 109 | frame_proj_weights = tf.get_variable( 110 | 'frame_proj_weights', 111 | [self.dim, self.dim * fs], 112 | dtype=tf.float32) 113 | frame_cell_proj_weights = tf.get_variable( 114 | 'frame_cell_proj_weights', 115 | [fs, self.dim], 116 | dtype=tf.float32) 117 | 118 | with tf.variable_scope('frame_rnn'): 119 | for time_step in range(num_steps): 120 | if time_step > 0: 121 | tf.get_variable_scope().reuse_variables() 122 | # get input 123 | cell_input = tf.reshape( 124 | input_frames[:, time_step, :], 125 | [-1, self.frame_size]) 126 | cell_input = math_ops.matmul( 127 | cell_input, frame_cell_proj_weights) 128 | # add big frame output to input 129 | bf_output = tf.reshape( 130 | big_frame_outputs[:, time_step, :], 131 | [-1, self.dim]) 132 | cell_input = tf.add(cell_input, bf_output) 133 
| # get outputs 134 | (frame_cell_output, frame_state) = self.cell( 135 | cell_input, frame_state) 136 | frame_outputs.append( 137 | math_ops.matmul( 138 | frame_cell_output, frame_proj_weights)) 139 | final_frame_state = frame_state 140 | frame_outputs = tf.stack(frame_outputs) 141 | frame_outputs = tf.transpose(frame_outputs, perm=[1,0,2]) 142 | frame_outputs_shape = \ 143 | [tf.shape(frame_outputs)[0], 144 | tf.shape(frame_outputs)[1] * fs, 145 | -1] 146 | frame_outputs = tf.reshape(frame_outputs, frame_outputs_shape) 147 | return frame_outputs, final_frame_state 148 | 149 | def _create_network_Sample(self, frame_outputs, sample_input_sequences): 150 | with tf.variable_scope('SampleRNN'): 151 | with tf.variable_scope('sample'): 152 | sample_shape = \ 153 | [tf.shape(sample_input_sequences)[0], 154 | tf.shape(sample_input_sequences)[1] * self.emb_size, 155 | 1] 156 | 157 | # embedding layer 158 | embedding = tf.get_variable( 159 | 'embedding', [self.q_levels, self.emb_size]) 160 | sample_input_sequences = embedding_ops.embedding_lookup( 161 | embedding, tf.reshape(sample_input_sequences, [-1])) 162 | sample_input_sequences = tf.reshape( 163 | sample_input_sequences, sample_shape) 164 | 165 | # convolution 166 | filter_initializer = tf.contrib.layers.xavier_initializer_conv2d() 167 | sample_filter_shape = [self.emb_size*2, 1, self.dim] 168 | sample_filter = tf.get_variable( 169 | 'sample_filter', sample_filter_shape, 170 | initializer=filter_initializer) 171 | out = tf.nn.conv1d(sample_input_sequences, 172 | sample_filter, 173 | stride=self.emb_size, 174 | padding='VALID', 175 | name='sample_conv') 176 | out = tf.add(out, frame_outputs) 177 | 178 | # multilayer perceptron 179 | sample_mlp1_weights = tf.get_variable( 180 | 'sample_mlp1', [self.dim, self.dim], dtype=tf.float32) 181 | sample_mlp2_weights = tf.get_variable( 182 | 'sample_mlp2', [self.dim, self.dim], dtype=tf.float32) 183 | sample_mlp3_weights = tf.get_variable( 184 | 'sample_mlp3', [self.dim, 1], dtype=tf.float32) 185 | out = tf.reshape(out, [-1, self.dim]) 186 | out = math_ops.matmul(out, sample_mlp1_weights) 187 | out = tf.nn.relu(out) 188 | out = math_ops.matmul(out, sample_mlp2_weights) 189 | out = tf.nn.relu(out) 190 | out = math_ops.matmul(out, sample_mlp3_weights) 191 | out = tf.reshape( 192 | out, [-1, sample_shape[1]/self.emb_size - 1, 1]) 193 | out = tf.multiply(tf.sigmoid(out), (self.q_levels - 1)) 194 | return out 195 | 196 | def _create_network_SampleRNN(self, train_big_frame_state, 197 | train_frame_state): 198 | bfs = self.big_frame_size 199 | fs = self.frame_size 200 | with tf.name_scope('SampleRNN'): 201 | # big frame network 202 | big_input_sequences = self.encoded_nb_input_rnn[:, :-bfs, :] 203 | big_input_sequences = tf.cast(big_input_sequences, tf.float32) 204 | big_frame_num_steps = (self.seq_len - bfs) / bfs 205 | big_frame_outputs, final_big_frame_state = \ 206 | self._create_network_BigFrame( 207 | num_steps=big_frame_num_steps, 208 | big_frame_state=train_big_frame_state, 209 | big_input_sequences=big_input_sequences) 210 | 211 | # frame network 212 | input_sequences = self.encoded_nb_input_rnn[:, bfs-fs:-fs, :] 213 | input_sequences = tf.cast(input_sequences, tf.float32) 214 | frame_num_steps = (self.seq_len - bfs) / fs 215 | frame_outputs, final_frame_state = \ 216 | self._create_network_Frame( 217 | num_steps=frame_num_steps, 218 | big_frame_outputs=big_frame_outputs, 219 | frame_state=train_frame_state, 220 | input_sequences=input_sequences) 221 | 222 | # sample 223 | sample_input_sequences = 
self.encoded_nb_input_rnn[:, bfs-fs:-1, :] 224 | sample_output = self._create_network_Sample( 225 | frame_outputs, sample_input_sequences=sample_input_sequences) 226 | 227 | return sample_output, final_big_frame_state, final_frame_state 228 | 229 | def forward(self, nb_input_batch_rnn, wb_input_batch_rnn, 230 | train_big_frame_state, train_frame_state, 231 | l1_reg_strength=None, inference_only=False): 232 | bfs = self.big_frame_size 233 | with tf.name_scope('forward'): 234 | self.encoded_nb_input_rnn = mu_law_encode( 235 | nb_input_batch_rnn, self.q_levels) 236 | self.encoded_wb_input_rnn = mu_law_encode( 237 | wb_input_batch_rnn, self.q_levels) 238 | raw_output, final_big_frame_state, final_frame_state = \ 239 | self._create_network_SampleRNN( 240 | train_big_frame_state, train_frame_state) 241 | with tf.name_scope('total_loss'): 242 | # --------------------------- 243 | # L1 loss 244 | # --------------------------- 245 | target_output_rnn = \ 246 | self.encoded_wb_input_rnn[:, bfs:, :] 247 | target_output_rnn = tf.reshape( 248 | target_output_rnn, [self.batch_size, -1, 1]) 249 | target_output_rnn = tf.cast(target_output_rnn, dtype=tf.float32) 250 | prediction = raw_output 251 | reg_penalty = 0 252 | if l1_reg_strength is not None: 253 | print('Applying L1 regularization') 254 | reg_penalty = tf.contrib.layers.apply_regularization( 255 | self.l1_regularizer, tf.trainable_variables()) 256 | l1_loss = tf.losses.absolute_difference( 257 | target_output_rnn, prediction) 258 | l1_loss = tf.reduce_mean(l1_loss) + reg_penalty 259 | tf.summary.scalar('l1_loss', l1_loss) 260 | 261 | # --------------------------- 262 | # spectral loss 263 | # --------------------------- 264 | # mu-law decode prediction 265 | pred_signals = tf.squeeze(raw_output, axis=2) 266 | pred_signals = mu_law_decode(pred_signals, self.q_levels) 267 | gt_signals = tf.squeeze(wb_input_batch_rnn, axis=2) 268 | 269 | # get log-mel spectrograms of prediction and ground truth 270 | prediction_spec = log_mel_spectrograms( 271 | pred_signals, self.sample_rate) 272 | gt_spec = log_mel_spectrograms( 273 | gt_signals[:, bfs:], self.sample_rate) 274 | 275 | # compute L2 loss between log-mel spectrograms 276 | spec_loss = tf.squared_difference(gt_spec, prediction_spec) 277 | spec_loss = tf.reduce_mean(spec_loss) 278 | tf.summary.scalar('spectral_loss', spec_loss) 279 | 280 | # --------------------------- 281 | # total loss 282 | # --------------------------- 283 | total_loss = l1_loss + (self.spec_loss_weight * spec_loss) 284 | tf.summary.scalar('total_loss', total_loss) 285 | 286 | if inference_only: 287 | return total_loss, pred_signals, final_big_frame_state, final_frame_state 288 | else: 289 | return total_loss, final_big_frame_state, final_frame_state -------------------------------------------------------------------------------- /train_gan.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | from datetime import datetime 4 | import json 5 | import os 6 | import sys 7 | import time 8 | import numpy as np 9 | import scipy 10 | import tensorflow as tf 11 | import librosa 12 | import fnmatch 13 | from models import HRNN_GAN, Discriminator, AudioReader 14 | from models import find_files, get_test_batches, average_gradients, load, save, optimizer_factory, scalar_summary 15 | from tensorflow import AggregationMethod as aggreg 16 | from tensorflow.python.client import timeline 17 | 18 | def get_args(): 19 | parser = argparse.ArgumentParser() 
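# --- [Editor's note: illustrative only, not part of the repository source] ----
# Every flag added below mirrors either a hyper-parameter of the hierarchical
# RNN / GAN built from models/ (seq_len, big_frame_size, frame_size, q_levels,
# dim, n_rnn, emb_size, rnn_type) or a knob of the adversarial training
# schedule (d_learning_rate, pretrain_num_steps, update_d_every).  A
# hypothetical invocation covering the required arguments might look like the
# following; the paths and numeric values are placeholders, not settings taken
# from the repository:
#
#   python train_gan.py \
#       --batch_size 8 --learning_rate 1e-3 --d_learning_rate 1e-4 \
#       --nb_data_dir path/to/train_nb --wb_data_dir path/to/train_wb \
#       --test_nb_data_dir path/to/val_nb --test_wb_data_dir path/to/val_wb \
#       --logdir_root path/to/logdir --seq_len 1024 --big_frame_size 64 \
#       --frame_size 16 --q_levels 256 --dim 1024 --n_rnn 1 --emb_size 256 \
#       --rnn_type GRU --pretrain_num_steps 1000 --update_d_every 5
# ------------------------------------------------------------------------------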
20 | parser.add_argument('--num_gpus', type=int, default=1) 21 | parser.add_argument('--batch_size', type=int, required=True) 22 | parser.add_argument('--nb_data_dir', type=str, required=True) 23 | parser.add_argument('--wb_data_dir', type=str, required=True) 24 | parser.add_argument('--test_nb_data_dir', type=str, required=True) 25 | parser.add_argument('--test_wb_data_dir', type=str, required=True) 26 | parser.add_argument('--logdir_root', type=str, required=True) 27 | parser.add_argument('--ckpt_every', type=int, default=20) 28 | parser.add_argument('--num_steps', type=int, default=10000) 29 | parser.add_argument('--learning_rate', type=float, required=True) 30 | parser.add_argument('--sample_size', type=int, default=48000) 31 | parser.add_argument('--sample_rate', type=int, default=16000) 32 | parser.add_argument('--l2_reg_strength', type=float, default=0.0) 33 | parser.add_argument('--silence_threshold', type=float, default=0.0) 34 | parser.add_argument('--optimizer', type=str, default='adam', 35 | choices=optimizer_factory.keys()) 36 | parser.add_argument('--momentum', type=float, default=0.9) 37 | parser.add_argument('--seq_len', type=int, required=True) 38 | parser.add_argument('--big_frame_size', type=int, required=True) 39 | parser.add_argument('--frame_size', type=int, required=True) 40 | parser.add_argument('--q_levels', type=int, required=True) 41 | parser.add_argument('--dim', type=int, required=True) 42 | parser.add_argument('--n_rnn', type=int, choices=xrange(1,6), required=True) 43 | parser.add_argument('--emb_size', type=int, required=True) 44 | parser.add_argument('--rnn_type', choices=['LSTM', 'GRU'], required=True) 45 | parser.add_argument('--max_checkpoints', type=int, default=5) 46 | parser.add_argument('--d_learning_rate', type=float, required=True) 47 | parser.add_argument('--bias_D_conv', type=bool, default=True) 48 | parser.add_argument('--pretrain_num_steps', type=int, required=True) 49 | parser.add_argument('--update_d_every', type=int, required=True) 50 | return parser.parse_args() 51 | 52 | def create_model(args): 53 | net = HRNN_GAN(batch_size=args.batch_size, 54 | big_frame_size=args.big_frame_size, 55 | frame_size=args.frame_size, 56 | q_levels=args.q_levels, 57 | rnn_type=args.rnn_type, 58 | dim=args.dim, 59 | n_rnn=args.n_rnn, 60 | seq_len=args.seq_len, 61 | emb_size=args.emb_size) 62 | return net 63 | 64 | def create_discriminator(args, name): 65 | discr = Discriminator(bias_D_conv=args.bias_D_conv, 66 | name=name) 67 | return discr 68 | 69 | def train(): 70 | args = get_args() 71 | if args.l2_reg_strength == 0: 72 | args.l2_reg_strength = None 73 | logdir = os.path.join(args.logdir_root, 'train') 74 | logdir_test = os.path.join(args.logdir_root, 'test') 75 | logdir_d = os.path.join(args.logdir_root, 'discriminator') 76 | coord = tf.train.Coordinator() 77 | 78 | # Number of steps to train HRNN only 79 | pretrain_num_steps = args.pretrain_num_steps 80 | # Number of steps to train HRNN before updating discriminator 81 | update_d_every = args.update_d_every 82 | 83 | # Get testing files 84 | test_files = find_files(args.test_nb_data_dir) 85 | 86 | # Create inputs 87 | with tf.name_scope('create_inputs'): 88 | reader = AudioReader(args.nb_data_dir, 89 | args.wb_data_dir, 90 | coord, 91 | sample_rate=args.sample_rate, 92 | sample_size=args.sample_size, 93 | silence_threshold=args.silence_threshold) 94 | nb_audio_batch, wb_audio_batch = \ 95 | reader.dequeue(args.batch_size) 96 | 97 | # Create model 98 | net = create_model(args) 99 | discr = 
create_discriminator(args, name='discr') 100 | global_step = tf.get_variable( 101 | 'global_step', [], 102 | initializer=tf.constant_initializer(0), 103 | trainable=False) 104 | 105 | # Optimizers 106 | optim = optimizer_factory[args.optimizer]( 107 | learning_rate=args.learning_rate, 108 | momentum=args.momentum) 109 | d_optim = tf.train.AdamOptimizer( 110 | args.d_learning_rate) 111 | 112 | # Set up placeholders and variables on each GPU 113 | tower_net_grads = [] 114 | tower_net_grads_no_adv = [] 115 | tower_d_grads = [] 116 | losses = [] 117 | losses_no_adv = [] 118 | losses_adv = [] 119 | d_losses = [] 120 | wb_input_batch_rnn = [] 121 | nb_input_batch_rnn = [] 122 | samplernn_preds = [] 123 | train_big_frame_state = [] 124 | train_frame_state = [] 125 | final_big_frame_state = [] 126 | final_frame_state = [] 127 | goals = [] 128 | predictions = [] 129 | for i in xrange(args.num_gpus): 130 | with tf.device('/gpu:%d' % (i)): 131 | # Create input placeholders 132 | nb_input_batch_rnn.append( 133 | tf.Variable(tf.zeros([net.batch_size, net.seq_len, 1]), 134 | trainable=False, 135 | name='nb_input_batch_rnn', 136 | dtype=tf.float32)) 137 | wb_input_batch_rnn.append( 138 | tf.Variable(tf.zeros([net.batch_size, net.seq_len, 1]), 139 | trainable=False, 140 | name='wb_input_batch_rnn', 141 | dtype=tf.float32)) 142 | # Create initial states 143 | train_big_frame_state.append( 144 | net.big_cell.zero_state(net.batch_size, tf.float32)) 145 | final_big_frame_state.append( 146 | net.big_cell.zero_state(net.batch_size, tf.float32)) 147 | train_frame_state.append( 148 | net.cell.zero_state(net.batch_size, tf.float32)) 149 | final_frame_state.append( 150 | net.cell.zero_state(net.batch_size, tf.float32)) 151 | # Target/prediction placeholders 152 | goals.append( 153 | tf.Variable(tf.zeros([net.batch_size, net.seq_len, 1]), 154 | trainable=False, 155 | name='targets', 156 | dtype=tf.float32)) 157 | predictions.append( 158 | tf.Variable(tf.zeros([net.batch_size, net.seq_len, 1]), 159 | trainable=False, 160 | name='predictions', 161 | dtype=tf.float32)) 162 | 163 | # Network output variables 164 | with tf.variable_scope(tf.get_variable_scope()): 165 | for i in xrange(args.num_gpus): 166 | with tf.device('/gpu:%d' % (i)): 167 | with tf.name_scope('TOWER_%d' % i) as scope: 168 | print("Creating model on GPU:%d" % i) 169 | 170 | # SampleRNN outputs 171 | loss, final_big_frame_state[i], final_frame_state[i], \ 172 | goals[i], predictions[i] = \ 173 | net.loss_SampleRNN( 174 | nb_input_batch_rnn[i], 175 | wb_input_batch_rnn[i], 176 | train_big_frame_state[i], 177 | train_frame_state[i], 178 | l2_reg_strength=args.l2_reg_strength) 179 | 180 | # Discriminator inputs 181 | bfs = net.big_frame_size 182 | predictions[i] = tf.reshape( 183 | predictions[i], 184 | [net.batch_size, net.seq_len-bfs, 1]) 185 | d_rl_input = tf.concat( 186 | [tf.cast(wb_input_batch_rnn[i][:, :-bfs, :], dtype=tf.int32), 187 | tf.cast(nb_input_batch_rnn[i][:, :-bfs, :], dtype=tf.int32)], 188 | 2) 189 | d_fk_input = tf.concat( 190 | [tf.cast(predictions[i], dtype=tf.int32), 191 | tf.cast(nb_input_batch_rnn[i][:, :-bfs, :], dtype=tf.int32)], 192 | 2) 193 | d_rl_input = tf.cast(d_rl_input, dtype=tf.float32) 194 | d_fk_input = tf.cast(d_fk_input, dtype=tf.float32) 195 | 196 | # Discriminator outputs 197 | d_rl_logits = discr.logits_Discriminator( 198 | d_rl_input, reuse=False) 199 | d_fk_logits = discr.logits_Discriminator( 200 | d_fk_input, reuse=True) 201 | 202 | # Discriminator loss 203 | d_rl_loss = tf.reduce_mean( 204 | 
tf.squared_difference(d_rl_logits, 1.0)) 205 | d_fk_loss = tf.reduce_mean( 206 | tf.squared_difference(d_fk_logits, 0.0)) 207 | d_loss = d_rl_loss + d_fk_loss 208 | d_losses.append(d_loss) 209 | 210 | # SampleRNN loss 211 | net_adv_loss = tf.reduce_mean( 212 | tf.squared_difference(d_fk_logits, 1.0)) 213 | net_loss = net_adv_loss + loss 214 | losses.append(net_loss) 215 | losses_no_adv.append(loss) 216 | losses_adv.append(net_adv_loss) 217 | 218 | # Scalar summaries 219 | d_rl_loss_sum = scalar_summary('d_rl_loss', d_rl_loss) 220 | d_fk_loss_sum = scalar_summary('d_fk_loss', d_fk_loss) 221 | d_loss_sum = scalar_summary('d_loss', d_loss) 222 | net_loss_sum = scalar_summary( 223 | 'samplernn_loss', net_loss) 224 | net_loss_adv_sum = scalar_summary( 225 | 'samplernn_adv_loss', net_adv_loss) 226 | 227 | # Get trainable vars 228 | net_trainable = tf.trainable_variables( 229 | scope='SampleRNN') 230 | d_trainable = tf.trainable_variables( 231 | scope='Discriminator') 232 | 233 | # Gradients 234 | gradients = optim.compute_gradients( 235 | net_loss, net_trainable, 236 | aggregation_method=aggreg.EXPERIMENTAL_ACCUMULATE_N) 237 | gradients_no_adv = optim.compute_gradients( 238 | loss, net_trainable, 239 | aggregation_method=aggreg.EXPERIMENTAL_ACCUMULATE_N) 240 | d_gradients = d_optim.compute_gradients( 241 | d_loss, d_trainable, 242 | aggregation_method=aggreg.EXPERIMENTAL_ACCUMULATE_N) 243 | 244 | tower_net_grads.append(gradients) 245 | tower_net_grads_no_adv.append(gradients_no_adv) 246 | tower_d_grads.append(d_gradients) 247 | 248 | tf.get_variable_scope().reuse_variables() 249 | 250 | # Gradients 251 | net_grad_vars = average_gradients(tower_net_grads) 252 | net_grad_vars_no_adv = average_gradients(tower_net_grads_no_adv) 253 | d_grad_vars = average_gradients(tower_d_grads) 254 | 255 | # Clip gradients 256 | grads, vars = zip(*net_grad_vars) 257 | grads_no_adv, vars = zip(*net_grad_vars_no_adv) 258 | grads_clipped, _ = tf.clip_by_global_norm(grads, 5.0) 259 | grads_clipped_no_adv, _ = tf.clip_by_global_norm(grads_no_adv, 5.0) 260 | net_grad_vars = zip(grads_clipped, vars) 261 | net_grad_vars_no_adv = zip(grads_clipped_no_adv, vars) 262 | 263 | # Apply gradient ops 264 | apply_gradient_op = optim.apply_gradients( 265 | net_grad_vars, global_step=global_step) 266 | apply_gradient_op_no_adv = optim.apply_gradients( 267 | net_grad_vars_no_adv, global_step=global_step) 268 | d_apply_gradient_op = d_optim.apply_gradients( 269 | d_grad_vars, global_step=global_step) 270 | 271 | # --------------------------------------------------------------- 272 | # Start/continue training 273 | # --------------------------------------------------------------- 274 | writer = tf.summary.FileWriter(logdir) 275 | test_writer = tf.summary.FileWriter(logdir_test) 276 | writer.add_graph(tf.get_default_graph()) 277 | test_writer.add_graph(tf.get_default_graph()) 278 | summaries = tf.summary.merge_all() 279 | 280 | # Configure session 281 | tf_config = tf.ConfigProto(allow_soft_placement=True) 282 | tf_config.gpu_options.allow_growth = True 283 | sess = tf.Session(config=tf_config) 284 | init = tf.global_variables_initializer() 285 | sess.run(init) 286 | 287 | # Load checkpoint 288 | saver = tf.train.Saver(var_list=net_trainable, 289 | max_to_keep=args.max_checkpoints) 290 | d_saver = tf.train.Saver(var_list=d_trainable, 291 | max_to_keep=args.max_checkpoints) 292 | try: 293 | saved_global_step = load(saver, sess, logdir) 294 | load(d_saver, sess, logdir_d) 295 | if saved_global_step is None: saved_global_step = -1 296 | 
except: 297 | print("Something went wrong while restoring checkpoint.") 298 | raise 299 | 300 | # Start queue runners 301 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 302 | reader.start_threads(sess) 303 | 304 | # Train 305 | step = None 306 | last_saved_step = saved_global_step 307 | try: 308 | for step in range(saved_global_step + 1, args.num_steps + 1): 309 | final_big_s = [] 310 | final_s = [] 311 | for g in xrange(args.num_gpus): 312 | # Initialize cells 313 | final_big_s.append(sess.run(net.big_initial_state)) 314 | final_s.append(sess.run(net.initial_state)) 315 | start_time = time.time() 316 | 317 | nb_inputs_list = [] 318 | wb_inputs_list = [] 319 | for _ in xrange(args.num_gpus): 320 | # Get input batches 321 | nb_inputs, wb_inputs = sess.run( 322 | [nb_audio_batch, wb_audio_batch]) 323 | nb_inputs_list.append(nb_inputs) 324 | wb_inputs_list.append(wb_inputs) 325 | 326 | loss_sum = 0 327 | d_loss_sum = 0 328 | loss_adv_sum = 0 329 | 330 | idx_begin = 0 331 | audio_length = args.sample_size - args.big_frame_size 332 | bptt_length = args.seq_len - args.big_frame_size 333 | stateful_rnn_length = audio_length / bptt_length 334 | output_list = [summaries, 335 | losses, 336 | losses_adv, 337 | d_losses, 338 | apply_gradient_op, 339 | final_big_frame_state, 340 | final_frame_state] 341 | output_list_no_adv = [summaries, 342 | losses_no_adv, 343 | losses_adv, 344 | d_losses, 345 | apply_gradient_op_no_adv, 346 | final_big_frame_state, 347 | final_frame_state] 348 | discr_output_list = [d_apply_gradient_op] 349 | 350 | for i in range(0, stateful_rnn_length): 351 | inp_dict = {} 352 | for g in xrange(args.num_gpus): 353 | # Add seq_len samples as input for truncated BPTT 354 | inp_dict[nb_input_batch_rnn[g]] = \ 355 | nb_inputs_list[g][:, idx_begin:idx_begin+args.seq_len, :] 356 | inp_dict[wb_input_batch_rnn[g]] = \ 357 | wb_inputs_list[g][:, idx_begin:idx_begin+args.seq_len, :] 358 | inp_dict[train_big_frame_state[g]] = final_big_s[g] 359 | inp_dict[train_frame_state[g]] = final_s[g] 360 | idx_begin += args.seq_len - args.big_frame_size 361 | 362 | # Forward pass 363 | if (step < pretrain_num_steps): 364 | # Train with L1 365 | summary, loss_gpus, loss_adv_gpus, d_loss_gpus, _, final_big_s, final_s = \ 366 | sess.run(output_list_no_adv, feed_dict=inp_dict) 367 | else: 368 | # Train with L1 + adversarial loss 369 | summary, loss_gpus, loss_adv_gpus, d_loss_gpus, _, final_big_s, final_s = \ 370 | sess.run(output_list, feed_dict=inp_dict) 371 | 372 | writer.add_summary(summary, step) 373 | for g in xrange(args.num_gpus): 374 | loss_gpu = loss_gpus[g] / stateful_rnn_length 375 | d_loss_gpu = d_loss_gpus[g] / stateful_rnn_length 376 | loss_adv_gpu = loss_adv_gpus[g] / stateful_rnn_length 377 | 378 | loss_sum += loss_gpu / args.num_gpus 379 | d_loss_sum += d_loss_gpu / args.num_gpus 380 | loss_adv_sum += loss_adv_gpu / args.num_gpus 381 | duration = time.time() - start_time 382 | 383 | print('****** STEP {:d} ({:.3f} sec/step) ******'.format(step, duration)) 384 | if (step < pretrain_num_steps): 385 | print('[SampleRNN] L1 loss = {:.3f}'.format(loss_sum)) 386 | else: 387 | print('[SampleRNN] L1 + adv loss = {:.3f}'.format(loss_sum)) 388 | print('[SampleRNN] adv loss = {:.3f}'.format(loss_adv_sum)) 389 | print('[Discriminator] L2 loss = {:.3f}'.format(d_loss_sum)) 390 | 391 | if (step >= pretrain_num_steps) and (step % update_d_every == 0): 392 | # Update discriminator parameters 393 | print('Updating discriminator parameters...') 394 | _ = sess.run(discr_output_list, 
feed_dict=inp_dict) 395 | 396 | if step % args.ckpt_every == 0: 397 | # Save models 398 | save(saver, sess, logdir, step) 399 | save(d_saver, sess, logdir_d, step) 400 | last_saved_step = step 401 | 402 | if step % 20 == 0: 403 | # Test 404 | test_nb_inputs, test_wb_inputs = get_test_batches( 405 | test_files, args.batch_size, args.sample_rate) 406 | test_output_list = [summaries, 407 | losses, 408 | final_big_frame_state, 409 | final_frame_state] 410 | test_output_list_no_adv = [summaries, 411 | losses_no_adv, 412 | final_big_frame_state, 413 | final_frame_state] 414 | 415 | loss_sum = 0 416 | idx_begin = 0 417 | audio_length = args.sample_size - args.big_frame_size 418 | bptt_length = args.seq_len - args.big_frame_size 419 | stateful_rnn_length = audio_length / bptt_length 420 | 421 | for i in range(0, stateful_rnn_length): 422 | inp_dict = {} 423 | for g in xrange(args.num_gpus): 424 | # Add seq_len samples as input for truncated BPTT 425 | inp_dict[nb_input_batch_rnn[g]] = \ 426 | nb_inputs_list[g][:, idx_begin:idx_begin+args.seq_len, :] 427 | inp_dict[wb_input_batch_rnn[g]] = \ 428 | wb_inputs_list[g][:, idx_begin:idx_begin+args.seq_len, :] 429 | inp_dict[train_big_frame_state[g]] = \ 430 | sess.run(net.big_initial_state) 431 | inp_dict[train_frame_state[g]] = \ 432 | sess.run(net.initial_state) 433 | idx_begin += args.seq_len - args.big_frame_size 434 | 435 | # Forward pass 436 | if (step < pretrain_num_steps): 437 | summary, test_loss_gpus, final_big_s, final_s = \ 438 | sess.run(test_output_list_no_adv, feed_dict=inp_dict) 439 | else: 440 | summary, test_loss_gpus, final_big_s, final_s = \ 441 | sess.run(test_output_list, feed_dict=inp_dict) 442 | test_writer.add_summary(summary, step) 443 | 444 | for g in xrange(args.num_gpus): 445 | loss_gpu = test_loss_gpus[g] / stateful_rnn_length 446 | loss_sum += loss_gpu / args.num_gpus 447 | print('Testing loss: {}'.format(loss_sum)) 448 | 449 | except KeyboardInterrupt: 450 | print() 451 | finally: 452 | if step > last_saved_step: 453 | print('Saving HRNN model...') 454 | save(saver, sess, logdir, step) 455 | print('Saving discriminator model...') 456 | save(d_saver, sess, logdir_d, step) 457 | coord.request_stop() 458 | coord.join(threads) 459 | 460 | train() --------------------------------------------------------------------------------
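Editor's note: the tower loop in train_gan.py pairs the SampleRNN reconstruction loss with a least-squares GAN objective. The discriminator scores wideband/narrowband pairs: it is pushed toward 1.0 on the real wideband signal paired with its narrowband input and toward 0.0 on the generated wideband paired with the same narrowband input, while the generator's adversarial term pushes the discriminator's score on generated pairs toward 1.0. The sketch below only restates those objectives in NumPy to make the formulas explicit; lsgan_losses, d_real, d_fake, and recon_loss are hypothetical stand-ins for the graph tensors, not repository code.

import numpy as np

def lsgan_losses(d_real, d_fake, recon_loss):
    """Least-squares GAN objectives as combined in the tower loop.

    d_real, d_fake: discriminator logits on real and generated pairs.
    recon_loss:     the SampleRNN reconstruction loss (the `loss` tensor
                    returned by net.loss_SampleRNN in the script above).
    """
    d_rl_loss = np.mean((d_real - 1.0) ** 2)     # push real-pair logits toward 1
    d_fk_loss = np.mean((d_fake - 0.0) ** 2)     # push fake-pair logits toward 0
    d_loss = d_rl_loss + d_fk_loss               # discriminator objective

    net_adv_loss = np.mean((d_fake - 1.0) ** 2)  # generator wants fakes scored 1
    net_loss = recon_loss + net_adv_loss         # total generator objective
    return d_loss, net_loss

# Example with arbitrary logits (illustrative values only):
d_loss, net_loss = lsgan_losses(np.random.randn(4, 1),
                                np.random.randn(4, 1),
                                recon_loss=0.25)

During the first pretrain_num_steps steps only the reconstruction term is optimized (apply_gradient_op_no_adv); after that the full generator objective is used, and the discriminator is stepped once every update_d_every steps via d_apply_gradient_op.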