├── static
│   ├── arch.png
│   ├── fiwgan_arch.png
│   ├── github19244c_2_5words.wav
│   └── start_20026_222_ch59_seed436.wav
├── backup.py
├── LICENSE.txt
├── README.md
├── loader.py
├── cinfowavegan.py
├── train_fiwgan.py
└── train_ciwgan.py

/static/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbegus/fiwGAN-ciwGAN/HEAD/static/arch.png
--------------------------------------------------------------------------------
/static/fiwgan_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbegus/fiwGAN-ciwGAN/HEAD/static/fiwgan_arch.png
--------------------------------------------------------------------------------
/static/github19244c_2_5words.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbegus/fiwGAN-ciwGAN/HEAD/static/github19244c_2_5words.wav
--------------------------------------------------------------------------------
/static/start_20026_222_ch59_seed436.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbegus/fiwGAN-ciwGAN/HEAD/static/start_20026_222_ch59_seed436.wav
--------------------------------------------------------------------------------
/backup.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | if __name__ == '__main__':
4 |   import glob
5 |   import os
6 |   import shutil
7 |   import sys
8 |   import time
9 | 
10 |   import tensorflow as tf
11 | 
12 |   train_dir, nmin = sys.argv[1:3]
13 |   nsec = int(float(nmin) * 60.)
14 | 
15 |   backup_dir = os.path.join(train_dir, 'backup')
16 | 
17 |   if not os.path.exists(backup_dir):
18 |     os.makedirs(backup_dir)
19 | 
20 |   while tf.train.latest_checkpoint(train_dir) is None:
21 |     print('Waiting for first checkpoint')
22 |     time.sleep(1)
23 | 
24 |   while True:
25 |     latest_ckpt = tf.train.latest_checkpoint(train_dir)
26 | 
27 |     # Sleep for two seconds in case checkpoint files are still being flushed
28 |     time.sleep(2)
29 | 
30 |     for fp in glob.glob(latest_ckpt + '*'):
31 |       _, name = os.path.split(fp)
32 |       backup_fp = os.path.join(backup_dir, name)
33 |       print('{}->{}'.format(fp, backup_fp))
34 |       shutil.copyfile(fp, backup_fp)
35 |     print('-' * 80)
36 | 
37 |     # Sleep for the requested backup interval
38 |     time.sleep(nsec)
39 | 
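40 | 
41 | # Usage (inferred from the sys.argv parsing above):
42 | #   python backup.py <train_dir> <interval_minutes>
43 | # e.g. `python backup.py ./train 60` copies the newest checkpoint files
44 | # from ./train to ./train/backup once every 60 minutes.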
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Christopher Donahue
4 | Additions and modifications (fiwgan.py & ciwgan.py) (c) 2020 Gasper Begus
5 | 
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 | 
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 | 
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # fiwGAN (Featural InfoWaveGAN): Lexical Learning in Generative Adversarial Phonology
2 | 
3 | PAPER HERE: https://www.sciencedirect.com/science/article/pii/S0893608021001052
4 | 
5 | Implemented in `train_fiwgan.py`: Featural InfoWaveGAN (fiwGAN) is an architecture for modeling lexical learning from raw acoustic inputs that combines the deep convolutional GAN architecture for audio data (WaveGAN) with the latent categorical variables of the information-theoretic InfoGAN proposal. Unlike InfoGAN, the latent code is binomially distributed and training is performed with sigmoid cross-entropy. Based on [WaveGAN](https://github.com/chrisdonahue/wavegan) (Donahue et al. 2019) and InfoGAN (Chen et al. 2016), partially also on code by [Rodionov](https://github.com/singnet/semantic-vision/blob/master/experiments/concept_learning/gans/info-wgan-gp/10_originfo_sepQ_v2_lr1e-3/train.py) (2018).
6 | 
7 | 
8 | 
9 | 
10 | # ciwGAN (Categorical InfoWaveGAN)
11 | 
12 | Categorical InfoWaveGAN (ciwGAN) is an architecture for modeling lexical learning from raw acoustic inputs that combines the deep convolutional GAN architecture for audio data (WaveGAN) with the latent categorical variables of InfoGAN.
13 | 
14 | Based on WaveGAN (Donahue et al. 2019) (https://github.com/chrisdonahue/wavegan) and the WGAN-GP implementation of InfoGAN by Sergey Rodionov (https://github.com/singnet/semantic-vision/blob/master/experiments/concept_learning/gans/info-wgan-gp/10_originfo_sepQ_v2_lr1e-3/train.py).
15 | 
16 | In addition to the Generator and the Discriminator networks, the architecture introduces a network that learns to classify generated outputs and forces the Generator to encode lexical information in its latent space. Lexical and semantic encoding is represented with a set of categorical binary variables. Trained on five lexical items from TIMIT, the network learns to generate the lexical items and to encode the identity of each item in the categorical variables of the latent space. By manipulating the categorical variables that encode lexical information, the network outputs the five lexical items, suggesting that each lexical item is represented with a unique categorical code. Such a representation can serve as the basis for lexical and semantic learning from raw acoustic input.
17 | 
18 | 
19 | 
20 | After 19,244 steps of training on _oily, water, rag, suit_ and _year_ from TIMIT, the network learns to output lexical items based on the latent code. The outputs below are generated with the following values of c:
21 | 
22 | 1. \[1, 0, 0, 0, 0\]: _suit_
23 | 2. \[0, 1, 0, 0, 0\]: _year_
24 | 3. \[0, 0, 1, 0, 0\]: _water_
25 | 4. \[0, 0, 0, 1, 0\]: _oily_
26 | 5. \[0, 0, 0, 0, 1\]: _rag_
27 | 
28 | [Audio sample 1](http://faculty.washington.edu/begus/files/github19244c_2_5words.wav)
29 | [Audio sample 2](static/start_20026_222_ch59_seed436.wav)
30 | 
31 | To change the number of categorical latent variables:
32 | 
33 | ```
34 | --num_categ n
35 | ```
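36 | 
37 | Once training has exported `infer.meta` (see `infer()` in `train_fiwgan.py` / `train_ciwgan.py`) and saved a checkpoint, a lexical item can be generated by fixing the first `num_categ` dimensions of z to the code c and sampling the remaining dimensions uniformly. A minimal sketch (the checkpoint path and step number are placeholders for your own training run):
38 | 
39 | ```python
40 | import numpy as np
41 | import tensorflow as tf
42 | 
43 | tf.reset_default_graph()
44 | saver = tf.train.import_meta_graph('train/infer/infer.meta')
45 | graph = tf.get_default_graph()
46 | 
47 | num_categ, latent_dim = 5, 100            # --num_categ 5; default latent size
48 | _z = np.random.uniform(-1., 1., (1, latent_dim)).astype(np.float32)
49 | _z[0, :num_categ] = [0., 0., 1., 0., 0.]  # c encoding 'water' in the list above
50 | 
51 | with tf.Session() as sess:
52 |   saver.restore(sess, 'train/model.ckpt-19244')  # placeholder checkpoint
53 |   _G_z = sess.run(graph.get_tensor_by_name('G_z:0'),
54 |                   {graph.get_tensor_by_name('z:0'): _z})  # [1, slice_len, 1]
55 | ```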
--------------------------------------------------------------------------------
/loader.py:
--------------------------------------------------------------------------------
1 | from scipy.io.wavfile import read as wavread
2 | import numpy as np
3 | 
4 | import tensorflow as tf
5 | 
6 | import sys
7 | 
8 | 
9 | def decode_audio(fp, fs=None, num_channels=1, normalize=False, fast_wav=False):
10 |   """Decodes audio file paths into 32-bit floating point vectors.
11 | 
12 |   Args:
13 |     fp: Audio file path.
14 |     fs: If specified, resamples decoded audio to this rate.
15 |     num_channels: Desired number of channels (averages to mono or duplicates mono to stereo as needed).
16 |     fast_wav: Assume fp is a standard WAV file (PCM 16-bit or float 32-bit).
17 | 
18 |   Returns:
19 |     A np.float32 array containing the audio samples at specified sample rate.
20 |   """
21 |   if fast_wav:
22 |     # Read with scipy wavread (fast).
23 |     _fs, _wav = wavread(fp)
24 |     if fs is not None and fs != _fs:
25 |       raise NotImplementedError('Scipy cannot resample audio.')
26 |     if _wav.dtype == np.int16:
27 |       _wav = _wav.astype(np.float32)
28 |       _wav /= 32768.
29 |     elif _wav.dtype == np.float32:
30 |       _wav = np.copy(_wav)
31 |     else:
32 |       raise NotImplementedError('Scipy cannot process atypical WAV files.')
33 |   else:
34 |     # Decode with librosa load (slow but supports file formats like mp3).
35 |     import librosa
36 |     _wav, _fs = librosa.core.load(fp, sr=fs, mono=False)
37 |     if _wav.ndim == 2:
38 |       _wav = np.swapaxes(_wav, 0, 1)
39 | 
40 |   assert _wav.dtype == np.float32
41 | 
42 |   # At this point, _wav is np.float32 either [nsamps,] or [nsamps, nch].
43 |   # We want [nsamps, 1, nch] to mimic 2D shape of spectral feats.
44 |   if _wav.ndim == 1:
45 |     nsamps = _wav.shape[0]
46 |     nch = 1
47 |   else:
48 |     nsamps, nch = _wav.shape
49 |   _wav = np.reshape(_wav, [nsamps, 1, nch])
50 | 
51 |   # Average (mono) or expand (stereo) channels
52 |   if nch != num_channels:
53 |     if num_channels == 1:
54 |       _wav = np.mean(_wav, 2, keepdims=True)
55 |     elif nch == 1 and num_channels == 2:
56 |       _wav = np.concatenate([_wav, _wav], axis=2)
57 |     else:
58 |       raise ValueError('Number of audio channels not equal to num specified')
59 | 
60 |   if normalize:
61 |     factor = np.max(np.abs(_wav))
62 |     if factor > 0:
63 |       _wav /= factor
64 | 
65 |   return _wav
66 | 
67 | 
68 | def decode_extract_and_batch(
69 |     fps,
70 |     batch_size,
71 |     slice_len,
72 |     decode_fs,
73 |     decode_num_channels,
74 |     decode_normalize=True,
75 |     decode_fast_wav=False,
76 |     decode_parallel_calls=1,
77 |     slice_randomize_offset=False,
78 |     slice_first_only=False,
79 |     slice_overlap_ratio=0,
80 |     slice_pad_end=False,
81 |     repeat=False,
82 |     shuffle=False,
83 |     shuffle_buffer_size=None,
84 |     prefetch_size=None,
85 |     prefetch_gpu_num=None):
86 |   """Decodes audio file paths into mini-batches of samples.
87 | 
88 |   Args:
89 |     fps: List of audio file paths.
90 |     batch_size: Number of items in the batch.
91 |     slice_len: Length of the slices in samples or feature timesteps.
92 |     decode_fs: (Re-)sample rate for decoded audio files.
93 |     decode_num_channels: Number of channels for decoded audio files.
94 |     decode_normalize: If false, do not normalize audio waveforms.
95 |     decode_fast_wav: If true, uses scipy to decode standard wav files.
96 |     decode_parallel_calls: Number of parallel decoding threads.
97 |     slice_randomize_offset: If true, randomize starting position for slice.
98 |     slice_first_only: If true, only use first slice from each audio file.
99 |     slice_overlap_ratio: Ratio of overlap between adjacent slices.
100 |     slice_pad_end: If true, allows zero-padded examples from the end of each audio file.
101 |     repeat: If true (for training), continuously iterate through the dataset.
102 |     shuffle: If true (for training), buffer and shuffle the slices.
103 |     shuffle_buffer_size: Number of examples to queue up before grabbing a batch.
104 |     prefetch_size: Number of examples to prefetch from the queue.
105 |     prefetch_gpu_num: If specified, prefetch examples to GPU.
106 | 
107 |   Returns:
108 |     A tuple of np.float32 tensors representing audio waveforms.
109 |       audio: [batch_size, slice_len, 1, nch]
110 |   """
111 |   # Create dataset of filepaths
112 |   dataset = tf.data.Dataset.from_tensor_slices(fps)
113 | 
114 |   # Shuffle all filepaths every epoch
115 |   if shuffle:
116 |     dataset = dataset.shuffle(buffer_size=len(fps))
117 | 
118 |   # Repeat
119 |   if repeat:
120 |     dataset = dataset.repeat()
121 | 
122 |   def _decode_audio_shaped(fp):
123 |     _decode_audio_closure = lambda _fp: decode_audio(
124 |         _fp,
125 |         fs=decode_fs,
126 |         num_channels=decode_num_channels,
127 |         normalize=decode_normalize,
128 |         fast_wav=decode_fast_wav)
129 | 
130 |     audio = tf.py_func(
131 |         _decode_audio_closure,
132 |         [fp],
133 |         tf.float32,
134 |         stateful=False)
135 |     audio.set_shape([None, 1, decode_num_channels])
136 | 
137 |     return audio
138 | 
139 |   # Decode audio
140 |   dataset = dataset.map(
141 |       _decode_audio_shaped,
142 |       num_parallel_calls=decode_parallel_calls)
143 | 
144 |   # Slice each decoded audio file into fixed-length windows
145 |   def _slice(audio):
146 |     # Calculate hop size
147 |     if slice_overlap_ratio < 0:
148 |       raise ValueError('Overlap ratio must not be negative')
149 |     slice_hop = int(round(slice_len * (1. - slice_overlap_ratio)) + 1e-4)
150 |     if slice_hop < 1:
151 |       raise ValueError('Overlap ratio too high')
152 | 
153 |     # Randomize starting phase:
154 |     if slice_randomize_offset:
155 |       start = tf.random_uniform([], maxval=slice_len, dtype=tf.int32)
156 |       audio = audio[start:]
157 | 
158 |     # Extract slices
159 |     audio_slices = tf.contrib.signal.frame(
160 |         audio,
161 |         slice_len,
162 |         slice_hop,
163 |         pad_end=slice_pad_end,
164 |         pad_value=0,
165 |         axis=0)
166 | 
167 |     # Only use first slice if requested
168 |     if slice_first_only:
169 |       audio_slices = audio_slices[:1]
170 | 
171 |     return audio_slices
172 | 
173 |   def _slice_dataset_wrapper(audio):
174 |     audio_slices = _slice(audio)
175 |     return tf.data.Dataset.from_tensor_slices(audio_slices)
176 | 
177 |   # Flatten slices from all audio files into a single dataset
178 |   dataset = dataset.flat_map(_slice_dataset_wrapper)
179 | 
180 |   # Shuffle examples
181 |   if shuffle:
182 |     dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
183 | 
184 |   # Make batches
185 |   dataset = dataset.batch(batch_size, drop_remainder=True)
186 | 
187 |   # Prefetch a number of batches
188 |   if prefetch_size is not None:
189 |     dataset = dataset.prefetch(prefetch_size)
190 |   if prefetch_gpu_num is not None and prefetch_gpu_num >= 0:
191 |     dataset = dataset.apply(
192 |         tf.data.experimental.prefetch_to_device(
193 |             '/device:GPU:{}'.format(prefetch_gpu_num)))
194 | 
195 |   # Get tensors
196 |   iterator = dataset.make_one_shot_iterator()
197 | 
198 |   return iterator.get_next()
199 | 
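200 | # Example usage (a sketch; the file paths and parameter values are
201 | # illustrative, mirroring how train_fiwgan.py / train_ciwgan.py call this):
202 | #
203 | #   import glob
204 | #   import loader
205 | #
206 | #   fps = glob.glob('timit_words/*.wav')
207 | #   x = loader.decode_extract_and_batch(
208 | #       fps,
209 | #       batch_size=64,
210 | #       slice_len=16384,
211 | #       decode_fs=16000,
212 | #       decode_num_channels=1,
213 | #       decode_fast_wav=True,
214 | #       repeat=True,
215 | #       shuffle=True,
216 | #       shuffle_buffer_size=4096)[:, :, 0]
217 | #
218 | # x is then a float32 tensor of shape [64, 16384, 1] that yields a fresh
219 | # batch of audio slices every time it is evaluated in a session.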
--------------------------------------------------------------------------------
/cinfowavegan.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
4 | def conv1d_transpose(
5 |     inputs,
6 |     filters,
7 |     kernel_width,
8 |     stride=4,
9 |     padding='same',
10 |     upsample='zeros'):
11 |   if upsample == 'zeros':
12 |     return tf.layers.conv2d_transpose(
13 |         tf.expand_dims(inputs, axis=1),
14 |         filters,
15 |         (1, kernel_width),
16 |         strides=(1, stride),
17 |         padding='same'
18 |         )[:, 0]
19 |   elif upsample == 'nn':
20 |     batch_size = tf.shape(inputs)[0]
21 |     _, w, nch = inputs.get_shape().as_list()
22 | 
23 |     x = inputs
24 | 
25 |     x = tf.expand_dims(x, axis=1)
26 |     x = tf.image.resize_nearest_neighbor(x, [1, w * stride])
27 |     x = x[:, 0]
28 | 
29 |     return tf.layers.conv1d(
30 |         x,
31 |         filters,
32 |         kernel_width,
33 |         1,
34 |         padding='same')
35 |   else:
36 |     raise NotImplementedError
37 | 
38 | 
39 | """
40 |   Input: [None, 100]
41 |   Output: [None, slice_len, 1]
42 | """
43 | def WaveGANGenerator(
44 |     z,
45 |     slice_len=16384,
46 |     nch=1,
47 |     kernel_len=25,
48 |     dim=64,
49 |     use_batchnorm=False,
50 |     upsample='zeros',
51 |     train=False):
52 |   assert slice_len in [16384, 32768, 65536]
53 |   batch_size = tf.shape(z)[0]
54 | 
55 |   if use_batchnorm:
56 |     batchnorm = lambda x: tf.layers.batch_normalization(x, training=train)
57 |   else:
58 |     batchnorm = lambda x: x
59 | 
60 |   # FC and reshape for convolution
61 |   # [100] -> [16, 1024]
62 |   dim_mul = 16 if slice_len == 16384 else 32
63 |   output = z
64 |   with tf.variable_scope('z_project'):
65 |     output = tf.layers.dense(output, 4 * 4 * dim * dim_mul)
66 |     output = tf.reshape(output, [batch_size, 16, dim * dim_mul])
67 |     output = batchnorm(output)
68 |   output = tf.nn.relu(output)
69 |   dim_mul //= 2
70 | 
71 |   # Layer 0
72 |   # [16, 1024] -> [64, 512]
73 |   with tf.variable_scope('upconv_0'):
74 |     output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample)
75 |     output = batchnorm(output)
76 
| output = tf.nn.relu(output) 77 | dim_mul //= 2 78 | 79 | # Layer 1 80 | # [64, 512] -> [256, 256] 81 | with tf.variable_scope('upconv_1'): 82 | output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample) 83 | output = batchnorm(output) 84 | output = tf.nn.relu(output) 85 | dim_mul //= 2 86 | 87 | # Layer 2 88 | # [256, 256] -> [1024, 128] 89 | with tf.variable_scope('upconv_2'): 90 | output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample) 91 | output = batchnorm(output) 92 | output = tf.nn.relu(output) 93 | dim_mul //= 2 94 | 95 | # Layer 3 96 | # [1024, 128] -> [4096, 64] 97 | with tf.variable_scope('upconv_3'): 98 | output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample) 99 | output = batchnorm(output) 100 | output = tf.nn.relu(output) 101 | 102 | if slice_len == 16384: 103 | # Layer 4 104 | # [4096, 64] -> [16384, nch] 105 | with tf.variable_scope('upconv_4'): 106 | output = conv1d_transpose(output, nch, kernel_len, 4, upsample=upsample) 107 | output = tf.nn.tanh(output) 108 | elif slice_len == 32768: 109 | # Layer 4 110 | # [4096, 128] -> [16384, 64] 111 | with tf.variable_scope('upconv_4'): 112 | output = conv1d_transpose(output, dim, kernel_len, 4, upsample=upsample) 113 | output = batchnorm(output) 114 | output = tf.nn.relu(output) 115 | 116 | # Layer 5 117 | # [16384, 64] -> [32768, nch] 118 | with tf.variable_scope('upconv_5'): 119 | output = conv1d_transpose(output, nch, kernel_len, 2, upsample=upsample) 120 | output = tf.nn.tanh(output) 121 | elif slice_len == 65536: 122 | # Layer 4 123 | # [4096, 128] -> [16384, 64] 124 | with tf.variable_scope('upconv_4'): 125 | output = conv1d_transpose(output, dim, kernel_len, 4, upsample=upsample) 126 | output = batchnorm(output) 127 | output = tf.nn.relu(output) 128 | 129 | # Layer 5 130 | # [16384, 64] -> [65536, nch] 131 | with tf.variable_scope('upconv_5'): 132 | output = conv1d_transpose(output, nch, kernel_len, 4, upsample=upsample) 133 | output = tf.nn.tanh(output) 134 | 135 | # Automatically update batchnorm moving averages every time G is used during training 136 | if train and use_batchnorm: 137 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name) 138 | if slice_len == 16384: 139 | assert len(update_ops) == 10 140 | else: 141 | assert len(update_ops) == 12 142 | with tf.control_dependencies(update_ops): 143 | output = tf.identity(output) 144 | 145 | return output 146 | 147 | 148 | def lrelu(inputs, alpha=0.2): 149 | return tf.maximum(alpha * inputs, inputs) 150 | 151 | 152 | def apply_phaseshuffle(x, rad, pad_type='reflect'): 153 | b, x_len, nch = x.get_shape().as_list() 154 | 155 | phase = tf.random_uniform([], minval=-rad, maxval=rad + 1, dtype=tf.int32) 156 | pad_l = tf.maximum(phase, 0) 157 | pad_r = tf.maximum(-phase, 0) 158 | phase_start = pad_r 159 | x = tf.pad(x, [[0, 0], [pad_l, pad_r], [0, 0]], mode=pad_type) 160 | 161 | x = x[:, phase_start:phase_start+x_len] 162 | x.set_shape([b, x_len, nch]) 163 | 164 | return x 165 | 166 | 167 | """ 168 | Input: [None, slice_len, nch] 169 | Output: [None] (linear output) 170 | """ 171 | def WaveGANDiscriminator( 172 | x, 173 | kernel_len=25, 174 | dim=64, 175 | use_batchnorm=False, 176 | phaseshuffle_rad=0): 177 | batch_size = tf.shape(x)[0] 178 | slice_len = int(x.get_shape()[1]) 179 | 180 | if use_batchnorm: 181 | batchnorm = lambda x: tf.layers.batch_normalization(x, training=True) 182 | else: 183 | batchnorm = lambda x: x 184 | 185 | if phaseshuffle_rad > 0: 186 
| phaseshuffle = lambda x: apply_phaseshuffle(x, phaseshuffle_rad) 187 | else: 188 | phaseshuffle = lambda x: x 189 | 190 | # Layer 0 191 | # [16384, 1] -> [4096, 64] 192 | output = x 193 | with tf.variable_scope('downconv_0'): 194 | output = tf.layers.conv1d(output, dim, kernel_len, 4, padding='SAME') 195 | output = lrelu(output) 196 | output = phaseshuffle(output) 197 | 198 | # Layer 1 199 | # [4096, 64] -> [1024, 128] 200 | with tf.variable_scope('downconv_1'): 201 | output = tf.layers.conv1d(output, dim * 2, kernel_len, 4, padding='SAME') 202 | output = batchnorm(output) 203 | output = lrelu(output) 204 | output = phaseshuffle(output) 205 | 206 | # Layer 2 207 | # [1024, 128] -> [256, 256] 208 | with tf.variable_scope('downconv_2'): 209 | output = tf.layers.conv1d(output, dim * 4, kernel_len, 4, padding='SAME') 210 | output = batchnorm(output) 211 | output = lrelu(output) 212 | output = phaseshuffle(output) 213 | 214 | # Layer 3 215 | # [256, 256] -> [64, 512] 216 | with tf.variable_scope('downconv_3'): 217 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 4, padding='SAME') 218 | output = batchnorm(output) 219 | output = lrelu(output) 220 | output = phaseshuffle(output) 221 | 222 | # Layer 4 223 | # [64, 512] -> [16, 1024] 224 | with tf.variable_scope('downconv_4'): 225 | output = tf.layers.conv1d(output, dim * 16, kernel_len, 4, padding='SAME') 226 | output = batchnorm(output) 227 | output = lrelu(output) 228 | 229 | if slice_len == 32768: 230 | # Layer 5 231 | # [32, 1024] -> [16, 2048] 232 | with tf.variable_scope('downconv_5'): 233 | output = tf.layers.conv1d(output, dim * 32, kernel_len, 2, padding='SAME') 234 | output = batchnorm(output) 235 | output = lrelu(output) 236 | elif slice_len == 65536: 237 | # Layer 5 238 | # [64, 1024] -> [16, 2048] 239 | with tf.variable_scope('downconv_5'): 240 | output = tf.layers.conv1d(output, dim * 32, kernel_len, 4, padding='SAME') 241 | output = batchnorm(output) 242 | output = lrelu(output) 243 | 244 | # Flatten 245 | output = tf.reshape(output, [batch_size, -1]) 246 | 247 | # Connect to single logit 248 | with tf.variable_scope('output'): 249 | output = tf.layers.dense(output, 1)[:, 0] 250 | 251 | # Don't need to aggregate batchnorm update ops like we do for the generator because we only use the discriminator for training 252 | 253 | return output 254 | 255 | 256 | 257 | def WaveGANQ( 258 | x, 259 | kernel_len=25, 260 | dim=64, 261 | use_batchnorm=False, 262 | phaseshuffle_rad=0, 263 | num_categ=10): 264 | batch_size = tf.shape(x)[0] 265 | slice_len = int(x.get_shape()[1]) 266 | 267 | if use_batchnorm: 268 | batchnorm = lambda x: tf.layers.batch_normalization(x, training=True) 269 | else: 270 | batchnorm = lambda x: x 271 | 272 | if phaseshuffle_rad > 0: 273 | phaseshuffle = lambda x: apply_phaseshuffle(x, phaseshuffle_rad) 274 | else: 275 | phaseshuffle = lambda x: x 276 | 277 | # Layer 0 278 | # [16384, 1] -> [4096, 64] 279 | output = x 280 | with tf.variable_scope('Qdownconv_0'): 281 | output = tf.layers.conv1d(output, dim, kernel_len, 4, padding='SAME') 282 | output = lrelu(output) 283 | output = phaseshuffle(output) 284 | 285 | # Layer 1 286 | # [4096, 64] -> [1024, 128] 287 | with tf.variable_scope('Qdownconv_1'): 288 | output = tf.layers.conv1d(output, dim * 2, kernel_len, 4, padding='SAME') 289 | output = batchnorm(output) 290 | output = lrelu(output) 291 | output = phaseshuffle(output) 292 | 293 | # Layer 2 294 | # [1024, 128] -> [256, 256] 295 | with tf.variable_scope('Qdownconv_2'): 296 | output = tf.layers.conv1d(output, 
dim * 4, kernel_len, 4, padding='SAME')
297 |     output = batchnorm(output)
298 |     output = lrelu(output)
299 |     output = phaseshuffle(output)
300 | 
301 |   # Layer 3
302 |   # [256, 256] -> [64, 512]
303 |   with tf.variable_scope('Qdownconv_3'):
304 |     output = tf.layers.conv1d(output, dim * 8, kernel_len, 4, padding='SAME')
305 |     output = batchnorm(output)
306 |     output = lrelu(output)
307 |     output = phaseshuffle(output)
308 | 
309 |   # Layer 4
310 |   # [64, 512] -> [16, 1024]
311 |   with tf.variable_scope('Qdownconv_4'):
312 |     output = tf.layers.conv1d(output, dim * 16, kernel_len, 4, padding='SAME')
313 |     output = batchnorm(output)
314 |     output = lrelu(output)
315 | 
316 |   if slice_len == 32768:
317 |     # Layer 5
318 |     # [32, 1024] -> [16, 2048]
319 |     with tf.variable_scope('Qdownconv_5'):
320 |       output = tf.layers.conv1d(output, dim * 32, kernel_len, 2, padding='SAME')
321 |       output = batchnorm(output)
322 |       output = lrelu(output)
323 |   elif slice_len == 65536:
324 |     # Layer 5
325 |     # [64, 1024] -> [16, 2048]
326 |     with tf.variable_scope('Qdownconv_5'):
327 |       output = tf.layers.conv1d(output, dim * 32, kernel_len, 4, padding='SAME')
328 |       output = batchnorm(output)
329 |       output = lrelu(output)
330 | 
331 |   # Flatten
332 |   output = tf.reshape(output, [batch_size, -1])
333 | 
334 |   # Connect to num_categ logits (one per latent code variable)
335 |   with tf.variable_scope('Qoutput'):
336 |     Qoutput = tf.layers.dense(output, num_categ)
337 | 
338 | 
339 |   # Don't need to aggregate batchnorm update ops like we do for the generator because Q is only used during training
340 | 
341 |   return Qoutput
342 | 
--------------------------------------------------------------------------------
/train_fiwgan.py:
--------------------------------------------------------------------------------
1 | '''
2 | fiwGAN: Featural InfoWaveGAN
3 | Gasper Begus (begus@uw.edu) 2020
4 | Based on WaveGAN (Donahue et al. 2019) and InfoGAN (Chen et al. 2016), partially also on code by Rodionov (2018).
5 | Unlike InfoGAN, the latent code is binomially distributed (features) and training is performed with sigmoid cross-entropy.
6 | '''
7 | 
8 | from __future__ import print_function
9 | 
10 | try:
11 |   import cPickle as pickle
12 | except:
13 |   import pickle
14 | from functools import reduce
15 | import os
16 | import time
17 | 
18 | import numpy as np
19 | import tensorflow as tf
20 | from six.moves import xrange
21 | 
22 | import loader
23 | from cinfowavegan import WaveGANGenerator, WaveGANDiscriminator, WaveGANQ
24 | 
25 | 
26 | """
27 |   Trains a WaveGAN
28 | """
29 | def train(fps, args):
30 |   with tf.name_scope('loader'):
31 |     x = loader.decode_extract_and_batch(
32 |         fps,
33 |         batch_size=args.train_batch_size,
34 |         slice_len=args.data_slice_len,
35 |         decode_fs=args.data_sample_rate,
36 |         decode_num_channels=args.data_num_channels,
37 |         decode_fast_wav=args.data_fast_wav,
38 |         decode_parallel_calls=4,
39 |         slice_randomize_offset=False if args.data_first_slice else True,
40 |         slice_first_only=args.data_first_slice,
41 |         slice_overlap_ratio=0. if args.data_first_slice else args.data_overlap_ratio,
42 |         slice_pad_end=True if args.data_first_slice else args.data_pad_end,
43 |         repeat=True,
44 |         shuffle=True,
45 |         shuffle_buffer_size=4096,
46 |         prefetch_size=args.train_batch_size * 4,
47 |         prefetch_gpu_num=args.data_prefetch_gpu_num)[:, :, 0]
48 | 
49 |   # Make z vector
50 | 
51 |   categ = tf.keras.backend.random_binomial([args.train_batch_size, args.num_categ], 0.5)
52 |   uniform = tf.random_uniform([args.train_batch_size, args.wavegan_latent_dim - args.num_categ], -1., 1.)
53 | z = tf.concat([categ,uniform],1) 54 | 55 | # Make generator 56 | with tf.variable_scope('G'): 57 | G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs) 58 | if args.wavegan_genr_pp: 59 | with tf.variable_scope('pp_filt'): 60 | G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same') 61 | G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G') 62 | 63 | # Print G summary 64 | print('-' * 80) 65 | print('Generator vars') 66 | nparams = 0 67 | for v in G_vars: 68 | v_shape = v.get_shape().as_list() 69 | v_n = reduce(lambda x, y: x * y, v_shape) 70 | nparams += v_n 71 | print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name)) 72 | print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024))) 73 | 74 | # Summarize 75 | tf.summary.audio('x', x, args.data_sample_rate) 76 | tf.summary.audio('G_z', G_z, args.data_sample_rate) 77 | G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1)) 78 | x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1)) 79 | tf.summary.histogram('x_rms_batch', x_rms) 80 | tf.summary.histogram('G_z_rms_batch', G_z_rms) 81 | tf.summary.scalar('x_rms', tf.reduce_mean(x_rms)) 82 | tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms)) 83 | 84 | # Make real discriminator 85 | with tf.name_scope('D_x'), tf.variable_scope('D'): 86 | D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs) 87 | D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D') 88 | 89 | # Print D summary 90 | print('-' * 80) 91 | print('Discriminator vars') 92 | nparams = 0 93 | for v in D_vars: 94 | v_shape = v.get_shape().as_list() 95 | v_n = reduce(lambda x, y: x * y, v_shape) 96 | nparams += v_n 97 | print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name)) 98 | print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024))) 99 | print('-' * 80) 100 | 101 | 102 | 103 | # Make fake discriminator 104 | with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True): 105 | D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs) 106 | 107 | # Make Q 108 | with tf.variable_scope('Q'): 109 | Q_G_z = WaveGANQ(G_z, **args.wavegan_q_kwargs) 110 | Q_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Q') 111 | 112 | # Print Q summary 113 | print('Q vars') 114 | nparams = 0 115 | for v in Q_vars: 116 | v_shape = v.get_shape().as_list() 117 | v_n = reduce(lambda x, y: x * y, v_shape) 118 | print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name)) 119 | print('-' * 80) 120 | 121 | # Create loss 122 | D_clip_weights = None 123 | if args.wavegan_loss == 'dcgan': 124 | fake = tf.zeros([args.train_batch_size], dtype=tf.float32) 125 | real = tf.ones([args.train_batch_size], dtype=tf.float32) 126 | 127 | G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 128 | logits=D_G_z, 129 | labels=real 130 | )) 131 | 132 | D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 133 | logits=D_G_z, 134 | labels=fake 135 | )) 136 | D_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 137 | logits=D_x, 138 | labels=real 139 | )) 140 | 141 | D_loss /= 2. 142 | elif args.wavegan_loss == 'lsgan': 143 | G_loss = tf.reduce_mean((D_G_z - 1.) ** 2) 144 | D_loss = tf.reduce_mean((D_x - 1.) ** 2) 145 | D_loss += tf.reduce_mean(D_G_z ** 2) 146 | D_loss /= 2. 
147 | elif args.wavegan_loss == 'wgan': 148 | G_loss = -tf.reduce_mean(D_G_z) 149 | D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x) 150 | 151 | with tf.name_scope('D_clip_weights'): 152 | clip_ops = [] 153 | for var in D_vars: 154 | clip_bounds = [-.01, .01] 155 | clip_ops.append( 156 | tf.assign( 157 | var, 158 | tf.clip_by_value(var, clip_bounds[0], clip_bounds[1]) 159 | ) 160 | ) 161 | D_clip_weights = tf.group(*clip_ops) 162 | elif args.wavegan_loss == 'wgan-gp': 163 | 164 | z_q_loss = z[:, : args.num_categ] 165 | q_q_loss = Q_G_z[:, : args.num_categ] 166 | q_sigmoid = tf.nn.sigmoid_cross_entropy_with_logits(labels=z_q_loss, logits=q_q_loss) 167 | G_loss = -tf.reduce_mean(D_G_z) 168 | D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x) 169 | Q_loss = tf.reduce_mean(q_sigmoid) 170 | 171 | alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1], minval=0., maxval=1.) 172 | differences = G_z - x 173 | interpolates = x + (alpha * differences) 174 | with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True): 175 | D_interp = WaveGANDiscriminator(interpolates, **args.wavegan_d_kwargs) 176 | 177 | LAMBDA = 10 178 | gradients = tf.gradients(D_interp, [interpolates])[0] 179 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2])) 180 | gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.) 181 | D_loss += LAMBDA * gradient_penalty 182 | else: 183 | raise NotImplementedError() 184 | 185 | tf.summary.scalar('G_loss', G_loss) 186 | tf.summary.scalar('D_loss', D_loss) 187 | tf.summary.scalar('Q_loss', Q_loss) 188 | 189 | # Create (recommended) optimizer 190 | if args.wavegan_loss == 'dcgan': 191 | G_opt = tf.train.AdamOptimizer( 192 | learning_rate=2e-4, 193 | beta1=0.5) 194 | D_opt = tf.train.AdamOptimizer( 195 | learning_rate=2e-4, 196 | beta1=0.5) 197 | elif args.wavegan_loss == 'lsgan': 198 | G_opt = tf.train.RMSPropOptimizer( 199 | learning_rate=1e-4) 200 | D_opt = tf.train.RMSPropOptimizer( 201 | learning_rate=1e-4) 202 | elif args.wavegan_loss == 'wgan': 203 | G_opt = tf.train.RMSPropOptimizer( 204 | learning_rate=5e-5) 205 | D_opt = tf.train.RMSPropOptimizer( 206 | learning_rate=5e-5) 207 | elif args.wavegan_loss == 'wgan-gp': 208 | G_opt = tf.train.AdamOptimizer( 209 | learning_rate=1e-4, 210 | beta1=0.5, 211 | beta2=0.9) 212 | D_opt = tf.train.AdamOptimizer( 213 | learning_rate=1e-4, 214 | beta1=0.5, 215 | beta2=0.9) 216 | Q_opt = tf.train.RMSPropOptimizer( 217 | learning_rate=1e-4) 218 | else: 219 | raise NotImplementedError() 220 | 221 | # Create training ops 222 | G_train_op = G_opt.minimize(G_loss, var_list=G_vars, 223 | global_step=tf.train.get_or_create_global_step()) 224 | D_train_op = D_opt.minimize(D_loss, var_list=D_vars) 225 | Q_train_op = Q_opt.minimize(Q_loss, var_list=Q_vars+G_vars) 226 | 227 | # Run training 228 | with tf.train.MonitoredTrainingSession( 229 | checkpoint_dir=args.train_dir, 230 | save_checkpoint_secs=args.train_save_secs, 231 | save_summaries_secs=args.train_summary_secs) as sess: 232 | print('-' * 80) 233 | print('Training has started. 
Please use \'tensorboard --logdir={}\' to monitor.'.format(args.train_dir))
234 |     while True:
235 |       # Train discriminator
236 |       for i in xrange(args.wavegan_disc_nupdates):
237 |         sess.run(D_train_op)
238 | 
239 | 
240 |       # Enforce Lipschitz constraint for WGAN
241 |       if D_clip_weights is not None:
242 |         sess.run(D_clip_weights)
243 | 
244 |       # Train generator
245 |       sess.run([G_train_op, Q_train_op])
246 | 
247 | 
248 | """
249 |   Creates and saves a MetaGraphDef for simple inference
250 |   Tensors:
251 |     'samp_z_n' int32 []: Sample this many latent vectors
252 |     'samp_z' float32 [samp_z_n, latent_dim]: Resultant latent vectors
253 |     'z:0' float32 [None, latent_dim]: Input latent vectors
254 |     'flat_pad:0' int32 []: Number of padding samples to use when flattening batch to a single audio file
255 |     'G_z:0' float32 [None, slice_len, 1]: Generated outputs
256 |     'G_z_int16:0' int16 [None, slice_len, 1]: Same as above but quantized to 16-bit PCM samples
257 |     'G_z_flat:0' float32 [None, 1]: Outputs flattened into single audio file
258 |     'G_z_flat_int16:0' int16 [None, 1]: Same as above but quantized to 16-bit PCM samples
259 |   Example usage:
260 |     import tensorflow as tf
261 |     tf.reset_default_graph()
262 | 
263 |     saver = tf.train.import_meta_graph('infer.meta')
264 |     graph = tf.get_default_graph()
265 |     sess = tf.InteractiveSession()
266 |     saver.restore(sess, 'model.ckpt-10000')
267 | 
268 |     z_n = graph.get_tensor_by_name('samp_z_n:0')
269 |     _z = sess.run(graph.get_tensor_by_name('samp_z:0'), {z_n: 10})
270 | 
271 |     z = graph.get_tensor_by_name('z:0')
272 |     _G_z = sess.run(graph.get_tensor_by_name('G_z:0'), {z: _z})
273 | """
274 | def infer(args):
275 |   infer_dir = os.path.join(args.train_dir, 'infer')
276 |   if not os.path.isdir(infer_dir):
277 |     os.makedirs(infer_dir)
278 | 
279 |   # Subgraph that generates latent vectors
280 |   samp_z_n = tf.placeholder(tf.int32, [], name='samp_z_n')
281 |   samp_z = tf.random_uniform([samp_z_n, args.wavegan_latent_dim], -1.0, 1.0, dtype=tf.float32, name='samp_z')
282 | 
283 |   # Input z
284 |   z = tf.placeholder(tf.float32, [None, args.wavegan_latent_dim], name='z')
285 |   flat_pad = tf.placeholder(tf.int32, [], name='flat_pad')
286 | 
287 |   # Execute generator
288 |   with tf.variable_scope('G'):
289 |     G_z = WaveGANGenerator(z, train=False, **args.wavegan_g_kwargs)
290 |     if args.wavegan_genr_pp:
291 |       with tf.variable_scope('pp_filt'):
292 |         G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same')
293 |   G_z = tf.identity(G_z, name='G_z')
294 | 
295 |   # Flatten batch
296 |   nch = int(G_z.get_shape()[-1])
297 |   G_z_padded = tf.pad(G_z, [[0, 0], [0, flat_pad], [0, 0]])
298 |   G_z_flat = tf.reshape(G_z_padded, [-1, nch], name='G_z_flat')
299 | 
300 |   # Encode to int16
301 |   def float_to_int16(x, name=None):
302 |     x_int16 = x * 32767.
303 |     x_int16 = tf.clip_by_value(x_int16, -32767., 32767.)
304 | x_int16 = tf.cast(x_int16, tf.int16, name=name) 305 | return x_int16 306 | G_z_int16 = float_to_int16(G_z, name='G_z_int16') 307 | G_z_flat_int16 = float_to_int16(G_z_flat, name='G_z_flat_int16') 308 | 309 | # Create saver 310 | G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G') 311 | global_step = tf.train.get_or_create_global_step() 312 | saver = tf.train.Saver(G_vars + [global_step]) 313 | 314 | # Export graph 315 | tf.train.write_graph(tf.get_default_graph(), infer_dir, 'infer.pbtxt') 316 | 317 | # Export MetaGraph 318 | infer_metagraph_fp = os.path.join(infer_dir, 'infer.meta') 319 | tf.train.export_meta_graph( 320 | filename=infer_metagraph_fp, 321 | clear_devices=True, 322 | saver_def=saver.as_saver_def()) 323 | 324 | # Reset graph (in case training afterwards) 325 | tf.reset_default_graph() 326 | 327 | 328 | """ 329 | Generates a preview audio file every time a checkpoint is saved 330 | """ 331 | def preview(args): 332 | import matplotlib 333 | matplotlib.use('Agg') 334 | import matplotlib.pyplot as plt 335 | from scipy.io.wavfile import write as wavwrite 336 | from scipy.signal import freqz 337 | 338 | preview_dir = os.path.join(args.train_dir, 'preview') 339 | if not os.path.isdir(preview_dir): 340 | os.makedirs(preview_dir) 341 | 342 | # Load graph 343 | infer_metagraph_fp = os.path.join(args.train_dir, 'infer', 'infer.meta') 344 | graph = tf.get_default_graph() 345 | saver = tf.train.import_meta_graph(infer_metagraph_fp) 346 | 347 | # Generate or restore z_i and z_o 348 | z_fp = os.path.join(preview_dir, 'z.pkl') 349 | if os.path.exists(z_fp): 350 | with open(z_fp, 'rb') as f: 351 | _zs = pickle.load(f) 352 | else: 353 | # Sample z 354 | samp_feeds = {} 355 | samp_feeds[graph.get_tensor_by_name('samp_z_n:0')] = args.preview_n 356 | samp_fetches = {} 357 | samp_fetches['zs'] = graph.get_tensor_by_name('samp_z:0') 358 | with tf.Session() as sess: 359 | _samp_fetches = sess.run(samp_fetches, samp_feeds) 360 | _zs = _samp_fetches['zs'] 361 | 362 | # Save z 363 | with open(z_fp, 'wb') as f: 364 | pickle.dump(_zs, f) 365 | 366 | # Set up graph for generating preview images 367 | feeds = {} 368 | feeds[graph.get_tensor_by_name('z:0')] = _zs 369 | feeds[graph.get_tensor_by_name('flat_pad:0')] = int(args.data_sample_rate / 2) 370 | fetches = {} 371 | fetches['step'] = tf.train.get_or_create_global_step() 372 | fetches['G_z'] = graph.get_tensor_by_name('G_z:0') 373 | fetches['G_z_flat_int16'] = graph.get_tensor_by_name('G_z_flat_int16:0') 374 | if args.wavegan_genr_pp: 375 | fetches['pp_filter'] = graph.get_tensor_by_name('G/pp_filt/conv1d/kernel:0')[:, 0, 0] 376 | 377 | # Summarize 378 | G_z = graph.get_tensor_by_name('G_z_flat:0') 379 | summaries = [ 380 | tf.summary.audio('preview', tf.expand_dims(G_z, axis=0), args.data_sample_rate, max_outputs=1) 381 | ] 382 | fetches['summaries'] = tf.summary.merge(summaries) 383 | summary_writer = tf.summary.FileWriter(preview_dir) 384 | 385 | # PP Summarize 386 | if args.wavegan_genr_pp: 387 | pp_fp = tf.placeholder(tf.string, []) 388 | pp_bin = tf.read_file(pp_fp) 389 | pp_png = tf.image.decode_png(pp_bin) 390 | pp_summary = tf.summary.image('pp_filt', tf.expand_dims(pp_png, axis=0)) 391 | 392 | # Loop, waiting for checkpoints 393 | ckpt_fp = None 394 | while True: 395 | latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir) 396 | if latest_ckpt_fp != ckpt_fp: 397 | print('Preview: {}'.format(latest_ckpt_fp)) 398 | 399 | with tf.Session() as sess: 400 | saver.restore(sess, latest_ckpt_fp) 401 | 402 | _fetches = 
sess.run(fetches, feeds)
403 | 
404 |       _step = _fetches['step']
405 | 
406 |       preview_fp = os.path.join(preview_dir, '{}.wav'.format(str(_step).zfill(8)))
407 |       wavwrite(preview_fp, args.data_sample_rate, _fetches['G_z_flat_int16'])
408 | 
409 |       summary_writer.add_summary(_fetches['summaries'], _step)
410 | 
411 |       if args.wavegan_genr_pp:
412 |         w, h = freqz(_fetches['pp_filter'])
413 | 
414 |         fig = plt.figure()
415 |         plt.title('Digital filter frequency response')
416 |         ax1 = fig.add_subplot(111)
417 | 
418 |         plt.plot(w, 20 * np.log10(abs(h)), 'b')
419 |         plt.ylabel('Amplitude [dB]', color='b')
420 |         plt.xlabel('Frequency [rad/sample]')
421 | 
422 |         ax2 = ax1.twinx()
423 |         angles = np.unwrap(np.angle(h))
424 |         plt.plot(w, angles, 'g')
425 |         plt.ylabel('Angle (radians)', color='g')
426 |         plt.grid()
427 |         plt.axis('tight')
428 | 
429 |         _pp_fp = os.path.join(preview_dir, '{}_ppfilt.png'.format(str(_step).zfill(8)))
430 |         plt.savefig(_pp_fp)
431 | 
432 |         with tf.Session() as sess:
433 |           _summary = sess.run(pp_summary, {pp_fp: _pp_fp})
434 |           summary_writer.add_summary(_summary, _step)
435 | 
436 |       print('Done')
437 | 
438 |       ckpt_fp = latest_ckpt_fp
439 | 
440 |     time.sleep(1)
441 | 
442 | 
443 | """
444 |   Computes inception score every time a checkpoint is saved
445 | """
446 | def incept(args):
447 |   incept_dir = os.path.join(args.train_dir, 'incept')
448 |   if not os.path.isdir(incept_dir):
449 |     os.makedirs(incept_dir)
450 | 
451 |   # Load GAN graph
452 |   gan_graph = tf.Graph()
453 |   with gan_graph.as_default():
454 |     infer_metagraph_fp = os.path.join(args.train_dir, 'infer', 'infer.meta')
455 |     gan_saver = tf.train.import_meta_graph(infer_metagraph_fp)
456 |     score_saver = tf.train.Saver(max_to_keep=1)
457 |   gan_z = gan_graph.get_tensor_by_name('z:0')
458 |   gan_G_z = gan_graph.get_tensor_by_name('G_z:0')[:, :, 0]
459 |   gan_step = gan_graph.get_tensor_by_name('global_step:0')
460 | 
461 |   # Load or generate latents
462 |   z_fp = os.path.join(incept_dir, 'z.pkl')
463 |   if os.path.exists(z_fp):
464 |     with open(z_fp, 'rb') as f:
465 |       _zs = pickle.load(f)
466 |   else:
467 |     gan_samp_z_n = gan_graph.get_tensor_by_name('samp_z_n:0')
468 |     gan_samp_z = gan_graph.get_tensor_by_name('samp_z:0')
469 |     with tf.Session(graph=gan_graph) as sess:
470 |       _zs = sess.run(gan_samp_z, {gan_samp_z_n: args.incept_n})
471 |     with open(z_fp, 'wb') as f:
472 |       pickle.dump(_zs, f)
473 | 
474 |   # Load classifier graph
475 |   incept_graph = tf.Graph()
476 |   with incept_graph.as_default():
477 |     incept_saver = tf.train.import_meta_graph(args.incept_metagraph_fp)
478 |   incept_x = incept_graph.get_tensor_by_name('x:0')
479 |   incept_preds = incept_graph.get_tensor_by_name('scores:0')
480 |   incept_sess = tf.Session(graph=incept_graph)
481 |   incept_saver.restore(incept_sess, args.incept_ckpt_fp)
482 | 
483 |   # Create summaries
484 |   summary_graph = tf.Graph()
485 |   with summary_graph.as_default():
486 |     incept_mean = tf.placeholder(tf.float32, [])
487 |     incept_std = tf.placeholder(tf.float32, [])
488 |     summaries = [
489 |         tf.summary.scalar('incept_mean', incept_mean),
490 |         tf.summary.scalar('incept_std', incept_std)
491 |     ]
492 |     summaries = tf.summary.merge(summaries)
493 |   summary_writer = tf.summary.FileWriter(incept_dir)
494 | 
495 |   # Loop, waiting for checkpoints
496 |   ckpt_fp = None
497 |   _best_score = 0.
498 |   while True:
499 |     latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
500 |     if latest_ckpt_fp != ckpt_fp:
501 |       print('Incept: {}'.format(latest_ckpt_fp))
502 | 
503 |       sess = tf.Session(graph=gan_graph)
504 | 
505 |       gan_saver.restore(sess, latest_ckpt_fp)
506 | 
507 |       _step = sess.run(gan_step)
508 | 
509 |       _G_zs = []
510 |       for i in xrange(0, args.incept_n, 100):
511 |         _G_zs.append(sess.run(gan_G_z, {gan_z: _zs[i:i+100]}))
512 |       _G_zs = np.concatenate(_G_zs, axis=0)
513 | 
514 |       _preds = []
515 |       for i in xrange(0, args.incept_n, 100):
516 |         _preds.append(incept_sess.run(incept_preds, {incept_x: _G_zs[i:i+100]}))
517 |       _preds = np.concatenate(_preds, axis=0)
518 | 
519 |       # Split into k groups
520 |       _incept_scores = []
521 |       split_size = args.incept_n // args.incept_k
522 |       for i in xrange(args.incept_k):
523 |         _split = _preds[i * split_size:(i + 1) * split_size]
524 |         _kl = _split * (np.log(_split) - np.log(np.expand_dims(np.mean(_split, 0), 0)))
525 |         _kl = np.mean(np.sum(_kl, 1))
526 |         _incept_scores.append(np.exp(_kl))
527 | 
528 |       _incept_mean, _incept_std = np.mean(_incept_scores), np.std(_incept_scores)
529 | 
530 |       # Summarize
531 |       with tf.Session(graph=summary_graph) as summary_sess:
532 |         _summaries = summary_sess.run(summaries, {incept_mean: _incept_mean, incept_std: _incept_std})
533 |       summary_writer.add_summary(_summaries, _step)
534 | 
535 |       # Save
536 |       if _incept_mean > _best_score:
537 |         score_saver.save(sess, os.path.join(incept_dir, 'best_score'), _step)
538 |         _best_score = _incept_mean
539 | 
540 |       sess.close()
541 | 
542 |       print('Done')
543 | 
544 |       ckpt_fp = latest_ckpt_fp
545 | 
546 |     time.sleep(1)
547 | 
548 |   incept_sess.close()
549 | 
550 | 
551 | if __name__ == '__main__':
552 |   import argparse
553 |   import glob
554 |   import sys
555 | 
556 |   parser = argparse.ArgumentParser()
557 | 
558 |   parser.add_argument('mode', type=str, choices=['train', 'preview', 'incept', 'infer'])
559 |   parser.add_argument('train_dir', type=str,
560 |       help='Training directory')
561 | 
562 |   data_args = parser.add_argument_group('Data')
563 |   data_args.add_argument('--data_dir', type=str,
564 |       help='Data directory containing *only* audio files to load')
565 |   data_args.add_argument('--data_sample_rate', type=int,
566 |       help='Number of audio samples per second')
567 |   data_args.add_argument('--data_slice_len', type=int, choices=[16384, 32768, 65536],
568 |       help='Number of audio samples per slice (maximum generation length)')
569 |   data_args.add_argument('--data_num_channels', type=int,
570 |       help='Number of audio channels to generate (for >2, must match that of data)')
571 |   data_args.add_argument('--data_overlap_ratio', type=float,
572 |       help='Overlap ratio [0, 1) between slices')
573 |   data_args.add_argument('--data_first_slice', action='store_true', dest='data_first_slice',
574 |       help='If set, only use the first slice of each audio example')
575 |   data_args.add_argument('--data_pad_end', action='store_true', dest='data_pad_end',
576 |       help='If set, use zero-padded partial slices from the end of each audio file')
577 |   data_args.add_argument('--data_normalize', action='store_true', dest='data_normalize',
578 |       help='If set, normalize the training examples')
579 |   data_args.add_argument('--data_fast_wav', action='store_true', dest='data_fast_wav',
580 |       help='If your data is comprised of standard WAV files (16-bit signed PCM or 32-bit float), use this flag to decode audio using scipy (faster) instead of librosa')
581 |   data_args.add_argument('--data_prefetch_gpu_num', type=int,
582 |       help='If 
nonnegative, prefetch examples to this GPU (Tensorflow device num)') 583 | 584 | wavegan_args = parser.add_argument_group('WaveGAN') 585 | wavegan_args.add_argument('--wavegan_latent_dim', type=int, 586 | help='Number of dimensions of the latent space') 587 | wavegan_args.add_argument('--wavegan_kernel_len', type=int, 588 | help='Length of 1D filter kernels') 589 | wavegan_args.add_argument('--wavegan_dim', type=int, 590 | help='Dimensionality multiplier for model of G and D') 591 | wavegan_args.add_argument('--num_categ', type=int, 592 | help='Number of categorical variables') 593 | wavegan_args.add_argument('--wavegan_batchnorm', action='store_true', dest='wavegan_batchnorm', 594 | help='Enable batchnorm') 595 | wavegan_args.add_argument('--wavegan_disc_nupdates', type=int, 596 | help='Number of discriminator updates per generator update') 597 | wavegan_args.add_argument('--wavegan_loss', type=str, choices=['dcgan', 'lsgan', 'wgan', 'wgan-gp'], 598 | help='Which GAN loss to use') 599 | wavegan_args.add_argument('--wavegan_genr_upsample', type=str, choices=['zeros', 'nn'], 600 | help='Generator upsample strategy') 601 | wavegan_args.add_argument('--wavegan_genr_pp', action='store_true', dest='wavegan_genr_pp', 602 | help='If set, use post-processing filter') 603 | wavegan_args.add_argument('--wavegan_genr_pp_len', type=int, 604 | help='Length of post-processing filter for DCGAN') 605 | wavegan_args.add_argument('--wavegan_disc_phaseshuffle', type=int, 606 | help='Radius of phase shuffle operation') 607 | 608 | train_args = parser.add_argument_group('Train') 609 | train_args.add_argument('--train_batch_size', type=int, 610 | help='Batch size') 611 | train_args.add_argument('--train_save_secs', type=int, 612 | help='How often to save model') 613 | train_args.add_argument('--train_summary_secs', type=int, 614 | help='How often to report summaries') 615 | 616 | preview_args = parser.add_argument_group('Preview') 617 | preview_args.add_argument('--preview_n', type=int, 618 | help='Number of samples to preview') 619 | 620 | incept_args = parser.add_argument_group('Incept') 621 | incept_args.add_argument('--incept_metagraph_fp', type=str, 622 | help='Inference model for inception score') 623 | incept_args.add_argument('--incept_ckpt_fp', type=str, 624 | help='Checkpoint for inference model') 625 | incept_args.add_argument('--incept_n', type=int, 626 | help='Number of generated examples to test') 627 | incept_args.add_argument('--incept_k', type=int, 628 | help='Number of groups to test') 629 | 630 | parser.set_defaults( 631 | data_dir=None, 632 | data_sample_rate=16000, 633 | data_slice_len=16384, 634 | data_num_channels=1, 635 | data_overlap_ratio=0., 636 | data_first_slice=False, 637 | data_pad_end=False, 638 | data_normalize=False, 639 | data_fast_wav=False, 640 | data_prefetch_gpu_num=0, 641 | wavegan_latent_dim=100, 642 | wavegan_kernel_len=25, 643 | wavegan_dim=64, 644 | num_categ=3, 645 | wavegan_batchnorm=False, 646 | wavegan_disc_nupdates=5, 647 | wavegan_loss='wgan-gp', 648 | wavegan_genr_upsample='zeros', 649 | wavegan_genr_pp=False, 650 | wavegan_genr_pp_len=512, 651 | wavegan_disc_phaseshuffle=2, 652 | train_batch_size=64, 653 | train_save_secs=300, 654 | train_summary_secs=120, 655 | preview_n=32, 656 | incept_metagraph_fp='./eval/inception/infer.meta', 657 | incept_ckpt_fp='./eval/inception/best_acc-103005', 658 | incept_n=5000, 659 | incept_k=10) 660 | 661 | args = parser.parse_args() 662 | 663 | # Make train dir 664 | if not os.path.isdir(args.train_dir): 665 | 
os.makedirs(args.train_dir) 666 | 667 | # Save args 668 | with open(os.path.join(args.train_dir, 'args.txt'), 'w') as f: 669 | f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])])) 670 | 671 | # Make model kwarg dicts 672 | setattr(args, 'wavegan_g_kwargs', { 673 | 'slice_len': args.data_slice_len, 674 | 'nch': args.data_num_channels, 675 | 'kernel_len': args.wavegan_kernel_len, 676 | 'dim': args.wavegan_dim, 677 | 'use_batchnorm': args.wavegan_batchnorm, 678 | 'upsample': args.wavegan_genr_upsample 679 | }) 680 | setattr(args, 'wavegan_d_kwargs', { 681 | 'kernel_len': args.wavegan_kernel_len, 682 | 'dim': args.wavegan_dim, 683 | 'use_batchnorm': args.wavegan_batchnorm, 684 | 'phaseshuffle_rad': args.wavegan_disc_phaseshuffle 685 | }) 686 | setattr(args, 'wavegan_q_kwargs', { 687 | 'kernel_len': args.wavegan_kernel_len, 688 | 'dim': args.wavegan_dim, 689 | 'use_batchnorm': args.wavegan_batchnorm, 690 | 'phaseshuffle_rad': args.wavegan_disc_phaseshuffle, 691 | 'num_categ': args.num_categ 692 | }) 693 | 694 | if args.mode == 'train': 695 | fps = glob.glob(os.path.join(args.data_dir, '*')) 696 | if len(fps) == 0: 697 | raise Exception('Did not find any audio files in specified directory') 698 | print('Found {} audio files in specified directory'.format(len(fps))) 699 | infer(args) 700 | train(fps, args) 701 | elif args.mode == 'preview': 702 | preview(args) 703 | elif args.mode == 'incept': 704 | incept(args) 705 | elif args.mode == 'infer': 706 | infer(args) 707 | else: 708 | raise NotImplementedError() 709 | -------------------------------------------------------------------------------- /train_ciwgan.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | try: 4 | import cPickle as pickle 5 | except: 6 | import pickle 7 | from functools import reduce 8 | import os 9 | import time 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | from six.moves import xrange 14 | 15 | import loader 16 | from cinfowavegan import WaveGANGenerator, WaveGANDiscriminator, WaveGANQ 17 | 18 | 19 | """ 20 | Trains a WaveGAN 21 | """ 22 | def train(fps, args): 23 | with tf.name_scope('loader'): 24 | x = loader.decode_extract_and_batch( 25 | fps, 26 | batch_size=args.train_batch_size, 27 | slice_len=args.data_slice_len, 28 | decode_fs=args.data_sample_rate, 29 | decode_num_channels=args.data_num_channels, 30 | decode_fast_wav=args.data_fast_wav, 31 | decode_parallel_calls=4, 32 | slice_randomize_offset=False if args.data_first_slice else True, 33 | slice_first_only=args.data_first_slice, 34 | slice_overlap_ratio=0. 
if args.data_first_slice else args.data_overlap_ratio, 35 | slice_pad_end=True if args.data_first_slice else args.data_pad_end, 36 | repeat=True, 37 | shuffle=True, 38 | shuffle_buffer_size=4096, 39 | prefetch_size=args.train_batch_size * 4, 40 | prefetch_gpu_num=args.data_prefetch_gpu_num)[:, :, 0] 41 | 42 | # Make z vector 43 | def random_c(): 44 | idxs = np.random.randint(args.num_categ, size=args.train_batch_size) 45 | c = np.zeros((args.train_batch_size, args.num_categ)) 46 | c[np.arange(args.train_batch_size), idxs] = 1 47 | return c 48 | def random_z(): 49 | rz = np.zeros([args.train_batch_size, args.wavegan_latent_dim]) 50 | rz[:, : args.num_categ] = random_c() 51 | rz[:, args.num_categ : ] = np.random.uniform(-1., 1., size=(args.train_batch_size, args.wavegan_latent_dim - args.num_categ)) 52 | return rz; 53 | 54 | z = tf.placeholder(tf.float32, (args.train_batch_size, args.wavegan_latent_dim)) 55 | 56 | # Make generator 57 | with tf.variable_scope('G'): 58 | G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs) 59 | if args.wavegan_genr_pp: 60 | with tf.variable_scope('pp_filt'): 61 | G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same') 62 | G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G') 63 | 64 | # Print G summary 65 | print('-' * 80) 66 | print('Generator vars') 67 | nparams = 0 68 | for v in G_vars: 69 | v_shape = v.get_shape().as_list() 70 | v_n = reduce(lambda x, y: x * y, v_shape) 71 | nparams += v_n 72 | print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name)) 73 | print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024))) 74 | 75 | # Summarize 76 | tf.summary.audio('x', x, args.data_sample_rate) 77 | tf.summary.audio('G_z', G_z, args.data_sample_rate) 78 | G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1)) 79 | x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1)) 80 | tf.summary.histogram('x_rms_batch', x_rms) 81 | tf.summary.histogram('G_z_rms_batch', G_z_rms) 82 | tf.summary.scalar('x_rms', tf.reduce_mean(x_rms)) 83 | tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms)) 84 | 85 | # Make real discriminator 86 | with tf.name_scope('D_x'), tf.variable_scope('D'): 87 | D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs) 88 | D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D') 89 | 90 | # Print D summary 91 | print('-' * 80) 92 | print('Discriminator vars') 93 | nparams = 0 94 | for v in D_vars: 95 | v_shape = v.get_shape().as_list() 96 | v_n = reduce(lambda x, y: x * y, v_shape) 97 | nparams += v_n 98 | print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name)) 99 | print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024))) 100 | print('-' * 80) 101 | 102 | 103 | 104 | # Make fake discriminator 105 | with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True): 106 | D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs) 107 | 108 | # Make Q 109 | with tf.variable_scope('Q'): 110 | Q_G_z = WaveGANQ(G_z, **args.wavegan_q_kwargs) 111 | Q_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Q') 112 | 113 | # Print Q summary 114 | print('Q vars') 115 | nparams = 0 116 | for v in Q_vars: 117 | v_shape = v.get_shape().as_list() 118 | v_n = reduce(lambda x, y: x * y, v_shape) 119 | print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name)) 120 | print('-' * 80) 121 | 122 | # Create loss 123 | D_clip_weights = None 124 | if args.wavegan_loss == 'dcgan': 125 
| fake = tf.zeros([args.train_batch_size], dtype=tf.float32) 126 | real = tf.ones([args.train_batch_size], dtype=tf.float32) 127 | 128 | G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 129 | logits=D_G_z, 130 | labels=real 131 | )) 132 | 133 | D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 134 | logits=D_G_z, 135 | labels=fake 136 | )) 137 | D_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 138 | logits=D_x, 139 | labels=real 140 | )) 141 | 142 | D_loss /= 2. 143 | elif args.wavegan_loss == 'lsgan': 144 | G_loss = tf.reduce_mean((D_G_z - 1.) ** 2) 145 | D_loss = tf.reduce_mean((D_x - 1.) ** 2) 146 | D_loss += tf.reduce_mean(D_G_z ** 2) 147 | D_loss /= 2. 148 | elif args.wavegan_loss == 'wgan': 149 | G_loss = -tf.reduce_mean(D_G_z) 150 | D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x) 151 | 152 | with tf.name_scope('D_clip_weights'): 153 | clip_ops = [] 154 | for var in D_vars: 155 | clip_bounds = [-.01, .01] 156 | clip_ops.append( 157 | tf.assign( 158 | var, 159 | tf.clip_by_value(var, clip_bounds[0], clip_bounds[1]) 160 | ) 161 | ) 162 | D_clip_weights = tf.group(*clip_ops) 163 | elif args.wavegan_loss == 'wgan-gp': 164 | 165 | def q_cost_tf(z, q): 166 | z_cat = z[:, : args.num_categ] 167 | q_cat = q[:, : args.num_categ] 168 | lcat = tf.nn.softmax_cross_entropy_with_logits(labels=z_cat, logits=q_cat) 169 | return tf.reduce_mean(lcat); 170 | 171 | 172 | G_loss = -tf.reduce_mean(D_G_z) 173 | D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x) 174 | Q_loss = q_cost_tf(z, Q_G_z) 175 | 176 | alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1], minval=0., maxval=1.) 177 | differences = G_z - x 178 | interpolates = x + (alpha * differences) 179 | with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True): 180 | D_interp = WaveGANDiscriminator(interpolates, **args.wavegan_d_kwargs) 181 | 182 | LAMBDA = 10 183 | gradients = tf.gradients(D_interp, [interpolates])[0] 184 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2])) 185 | gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.) 
    D_loss += LAMBDA * gradient_penalty
  else:
    raise NotImplementedError()

  # NB: Q_loss and Q_opt are only defined for the default 'wgan-gp' loss
  tf.summary.scalar('G_loss', G_loss)
  tf.summary.scalar('D_loss', D_loss)
  tf.summary.scalar('Q_loss', Q_loss)

  # Create (recommended) optimizer
  if args.wavegan_loss == 'dcgan':
    G_opt = tf.train.AdamOptimizer(
        learning_rate=2e-4,
        beta1=0.5)
    D_opt = tf.train.AdamOptimizer(
        learning_rate=2e-4,
        beta1=0.5)
  elif args.wavegan_loss == 'lsgan':
    G_opt = tf.train.RMSPropOptimizer(
        learning_rate=1e-4)
    D_opt = tf.train.RMSPropOptimizer(
        learning_rate=1e-4)
  elif args.wavegan_loss == 'wgan':
    G_opt = tf.train.RMSPropOptimizer(
        learning_rate=5e-5)
    D_opt = tf.train.RMSPropOptimizer(
        learning_rate=5e-5)
  elif args.wavegan_loss == 'wgan-gp':
    G_opt = tf.train.AdamOptimizer(
        learning_rate=1e-4,
        beta1=0.5,
        beta2=0.9)
    D_opt = tf.train.AdamOptimizer(
        learning_rate=1e-4,
        beta1=0.5,
        beta2=0.9)
    Q_opt = tf.train.RMSPropOptimizer(
        learning_rate=1e-4)
  else:
    raise NotImplementedError()

  # Create training ops; the Q objective updates both Q and G, so the
  # Generator is pushed to encode the categorical code in its outputs
  G_train_op = G_opt.minimize(G_loss, var_list=G_vars,
      global_step=tf.train.get_or_create_global_step())
  D_train_op = D_opt.minimize(D_loss, var_list=D_vars)
  Q_train_op = Q_opt.minimize(Q_loss, var_list=Q_vars + G_vars)

  # Run training
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir=args.train_dir,
      save_checkpoint_secs=args.train_save_secs,
      save_summaries_secs=args.train_summary_secs) as sess:
    print('-' * 80)
    print('Training has started. Please use \'tensorboard --logdir={}\' to monitor.'.format(args.train_dir))
    while True:
      # Train discriminator
      for i in xrange(args.wavegan_disc_nupdates):
        sess.run([D_loss, D_train_op], feed_dict={z: random_z()})

        # Enforce Lipschitz constraint for WGAN
        if D_clip_weights is not None:
          sess.run(D_clip_weights)

      # Train generator (and Q network)
      sess.run([G_loss, Q_loss, G_train_op, Q_train_op], feed_dict={z: random_z()})


"""
Creates and saves a MetaGraphDef for simple inference
Tensors:
  'samp_z_n' int32 []: Sample this many latent vectors
  'samp_z' float32 [samp_z_n, latent_dim]: Resultant latent vectors
  'z:0' float32 [None, latent_dim]: Input latent vectors
  'flat_pad:0' int32 []: Number of padding samples to use when flattening batch to a single audio file
  'G_z:0' float32 [None, slice_len, 1]: Generated outputs
  'G_z_int16:0' int16 [None, slice_len, 1]: Same as above but quantized to 16-bit PCM samples
  'G_z_flat:0' float32 [None, 1]: Outputs flattened into single audio file
  'G_z_flat_int16:0' int16 [None, 1]: Same as above but quantized to 16-bit PCM samples
Example usage:
  import tensorflow as tf
  tf.reset_default_graph()

  saver = tf.train.import_meta_graph('infer.meta')
  graph = tf.get_default_graph()
  sess = tf.InteractiveSession()
  saver.restore(sess, 'model.ckpt-10000')

  z_n = graph.get_tensor_by_name('samp_z_n:0')
  _z = sess.run(graph.get_tensor_by_name('samp_z:0'), {z_n: 10})

  z = graph.get_tensor_by_name('z:0')
  _G_z = sess.run(graph.get_tensor_by_name('G_z:0'), {z: _z})
"""
def infer(args):
  infer_dir = os.path.join(args.train_dir, 'infer')
  if not os.path.isdir(infer_dir):
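    # (ciwGAN note: 'samp_z' samples every latent dimension uniformly; to
    # probe the categorical code, feed 'z:0' vectors whose first num_categ
    # entries are set to a one-hot code, as random_z() does during training.)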
    os.makedirs(infer_dir)

  # Subgraph that generates latent vectors
  samp_z_n = tf.placeholder(tf.int32, [], name='samp_z_n')
  samp_z = tf.random_uniform([samp_z_n, args.wavegan_latent_dim], -1.0, 1.0, dtype=tf.float32, name='samp_z')

  # Input z
  z = tf.placeholder(tf.float32, [None, args.wavegan_latent_dim], name='z')
  flat_pad = tf.placeholder(tf.int32, [], name='flat_pad')

  # Execute generator
  with tf.variable_scope('G'):
    G_z = WaveGANGenerator(z, train=False, **args.wavegan_g_kwargs)
    if args.wavegan_genr_pp:
      with tf.variable_scope('pp_filt'):
        G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same')
  G_z = tf.identity(G_z, name='G_z')

  # Flatten batch
  nch = int(G_z.get_shape()[-1])
  G_z_padded = tf.pad(G_z, [[0, 0], [0, flat_pad], [0, 0]])
  G_z_flat = tf.reshape(G_z_padded, [-1, nch], name='G_z_flat')

  # Encode to int16
  def float_to_int16(x, name=None):
    x_int16 = x * 32767.
    x_int16 = tf.clip_by_value(x_int16, -32767., 32767.)
    x_int16 = tf.cast(x_int16, tf.int16, name=name)
    return x_int16
  G_z_int16 = float_to_int16(G_z, name='G_z_int16')
  G_z_flat_int16 = float_to_int16(G_z_flat, name='G_z_flat_int16')

  # Create saver
  G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G')
  global_step = tf.train.get_or_create_global_step()
  saver = tf.train.Saver(G_vars + [global_step])

  # Export graph
  tf.train.write_graph(tf.get_default_graph(), infer_dir, 'infer.pbtxt')

  # Export MetaGraph
  infer_metagraph_fp = os.path.join(infer_dir, 'infer.meta')
  tf.train.export_meta_graph(
      filename=infer_metagraph_fp,
      clear_devices=True,
      saver_def=saver.as_saver_def())

  # Reset graph (in case training afterwards)
  tf.reset_default_graph()


"""
Generates a preview audio file every time a checkpoint is saved
"""
def preview(args):
  import matplotlib
  matplotlib.use('Agg')
  import matplotlib.pyplot as plt
  from scipy.io.wavfile import write as wavwrite
  from scipy.signal import freqz

  preview_dir = os.path.join(args.train_dir, 'preview')
  if not os.path.isdir(preview_dir):
    os.makedirs(preview_dir)

  # Load graph
  infer_metagraph_fp = os.path.join(args.train_dir, 'infer', 'infer.meta')
  graph = tf.get_default_graph()
  saver = tf.train.import_meta_graph(infer_metagraph_fp)

  # Generate or restore z
  z_fp = os.path.join(preview_dir, 'z.pkl')
  if os.path.exists(z_fp):
    with open(z_fp, 'rb') as f:
      _zs = pickle.load(f)
  else:
    # Sample z
    samp_feeds = {}
    samp_feeds[graph.get_tensor_by_name('samp_z_n:0')] = args.preview_n
    samp_fetches = {}
    samp_fetches['zs'] = graph.get_tensor_by_name('samp_z:0')
    with tf.Session() as sess:
      _samp_fetches = sess.run(samp_fetches, samp_feeds)
    _zs = _samp_fetches['zs']

    # Save z
    with open(z_fp, 'wb') as f:
      pickle.dump(_zs, f)

  # Set up graph for generating preview images
  feeds = {}
  feeds[graph.get_tensor_by_name('z:0')] = _zs
  feeds[graph.get_tensor_by_name('flat_pad:0')] = int(args.data_sample_rate / 2)
  fetches = {}
  fetches['step'] = tf.train.get_or_create_global_step()
  fetches['G_z'] = graph.get_tensor_by_name('G_z:0')
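  # Fetch the flattened int16 rendering so all previews land in a single .wav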
  fetches['G_z_flat_int16'] = graph.get_tensor_by_name('G_z_flat_int16:0')
  if args.wavegan_genr_pp:
    fetches['pp_filter'] = graph.get_tensor_by_name('G/pp_filt/conv1d/kernel:0')[:, 0, 0]

  # Summarize
  G_z = graph.get_tensor_by_name('G_z_flat:0')
  summaries = [
      tf.summary.audio('preview', tf.expand_dims(G_z, axis=0), args.data_sample_rate, max_outputs=1)
  ]
  fetches['summaries'] = tf.summary.merge(summaries)
  summary_writer = tf.summary.FileWriter(preview_dir)

  # PP Summarize
  if args.wavegan_genr_pp:
    pp_fp = tf.placeholder(tf.string, [])
    pp_bin = tf.read_file(pp_fp)
    pp_png = tf.image.decode_png(pp_bin)
    pp_summary = tf.summary.image('pp_filt', tf.expand_dims(pp_png, axis=0))

  # Loop, waiting for checkpoints
  ckpt_fp = None
  while True:
    latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
    if latest_ckpt_fp != ckpt_fp:
      print('Preview: {}'.format(latest_ckpt_fp))

      with tf.Session() as sess:
        saver.restore(sess, latest_ckpt_fp)

        _fetches = sess.run(fetches, feeds)

      _step = _fetches['step']

      preview_fp = os.path.join(preview_dir, '{}.wav'.format(str(_step).zfill(8)))
      wavwrite(preview_fp, args.data_sample_rate, _fetches['G_z_flat_int16'])

      summary_writer.add_summary(_fetches['summaries'], _step)

      if args.wavegan_genr_pp:
        w, h = freqz(_fetches['pp_filter'])

        fig = plt.figure()
        plt.title('Digital filter frequency response')
        ax1 = fig.add_subplot(111)

        plt.plot(w, 20 * np.log10(abs(h)), 'b')
        plt.ylabel('Amplitude [dB]', color='b')
        plt.xlabel('Frequency [rad/sample]')

        ax2 = ax1.twinx()
        angles = np.unwrap(np.angle(h))
        plt.plot(w, angles, 'g')
        plt.ylabel('Angle (radians)', color='g')
        plt.grid()
        plt.axis('tight')

        _pp_fp = os.path.join(preview_dir, '{}_ppfilt.png'.format(str(_step).zfill(8)))
        plt.savefig(_pp_fp)

        with tf.Session() as sess:
          _summary = sess.run(pp_summary, {pp_fp: _pp_fp})
          summary_writer.add_summary(_summary, _step)

      print('Done')

      ckpt_fp = latest_ckpt_fp

    time.sleep(1)


"""
Computes inception score every time a checkpoint is saved
"""
def incept(args):
  incept_dir = os.path.join(args.train_dir, 'incept')
  if not os.path.isdir(incept_dir):
    os.makedirs(incept_dir)

  # Load GAN graph
  gan_graph = tf.Graph()
  with gan_graph.as_default():
    infer_metagraph_fp = os.path.join(args.train_dir, 'infer', 'infer.meta')
    gan_saver = tf.train.import_meta_graph(infer_metagraph_fp)
    score_saver = tf.train.Saver(max_to_keep=1)
  gan_z = gan_graph.get_tensor_by_name('z:0')
  gan_G_z = gan_graph.get_tensor_by_name('G_z:0')[:, :, 0]
  gan_step = gan_graph.get_tensor_by_name('global_step:0')

  # Load or generate latents
  z_fp = os.path.join(incept_dir, 'z.pkl')
  if os.path.exists(z_fp):
    with open(z_fp, 'rb') as f:
      _zs = pickle.load(f)
  else:
    gan_samp_z_n = gan_graph.get_tensor_by_name('samp_z_n:0')
    gan_samp_z = gan_graph.get_tensor_by_name('samp_z:0')
    with tf.Session(graph=gan_graph) as sess:
      _zs = sess.run(gan_samp_z, {gan_samp_z_n: args.incept_n})
    with open(z_fp, 'wb') as f:
      pickle.dump(_zs, f)

  # Load classifier graph
  incept_graph = tf.Graph()
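  # The pretrained audio classifier lives in its own graph and session so it
  # can score checkpoints without touching the GAN graph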
  with incept_graph.as_default():
    incept_saver = tf.train.import_meta_graph(args.incept_metagraph_fp)
  incept_x = incept_graph.get_tensor_by_name('x:0')
  incept_preds = incept_graph.get_tensor_by_name('scores:0')
  incept_sess = tf.Session(graph=incept_graph)
  incept_saver.restore(incept_sess, args.incept_ckpt_fp)

  # Create summaries
  summary_graph = tf.Graph()
  with summary_graph.as_default():
    incept_mean = tf.placeholder(tf.float32, [])
    incept_std = tf.placeholder(tf.float32, [])
    summaries = [
        tf.summary.scalar('incept_mean', incept_mean),
        tf.summary.scalar('incept_std', incept_std)
    ]
    summaries = tf.summary.merge(summaries)
  summary_writer = tf.summary.FileWriter(incept_dir)

  # Loop, waiting for checkpoints
  ckpt_fp = None
  _best_score = 0.
  while True:
    latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
    if latest_ckpt_fp != ckpt_fp:
      print('Incept: {}'.format(latest_ckpt_fp))

      sess = tf.Session(graph=gan_graph)

      gan_saver.restore(sess, latest_ckpt_fp)

      _step = sess.run(gan_step)

      # Generate args.incept_n examples in batches of 100
      _G_zs = []
      for i in xrange(0, args.incept_n, 100):
        _G_zs.append(sess.run(gan_G_z, {gan_z: _zs[i:i+100]}))
      _G_zs = np.concatenate(_G_zs, axis=0)

      # Classify the generated examples in batches of 100
      _preds = []
      for i in xrange(0, args.incept_n, 100):
        _preds.append(incept_sess.run(incept_preds, {incept_x: _G_zs[i:i+100]}))
      _preds = np.concatenate(_preds, axis=0)

      # Split into k groups
      _incept_scores = []
      split_size = args.incept_n // args.incept_k
      for i in xrange(args.incept_k):
        _split = _preds[i * split_size:(i + 1) * split_size]
        _kl = _split * (np.log(_split) - np.log(np.expand_dims(np.mean(_split, 0), 0)))
        _kl = np.mean(np.sum(_kl, 1))
        _incept_scores.append(np.exp(_kl))

      _incept_mean, _incept_std = np.mean(_incept_scores), np.std(_incept_scores)

      # Summarize
      with tf.Session(graph=summary_graph) as summary_sess:
        _summaries = summary_sess.run(summaries, {incept_mean: _incept_mean, incept_std: _incept_std})
        summary_writer.add_summary(_summaries, _step)

      # Save
      if _incept_mean > _best_score:
        score_saver.save(sess, os.path.join(incept_dir, 'best_score'), _step)
        _best_score = _incept_mean

      sess.close()

      print('Done')

      ckpt_fp = latest_ckpt_fp

    time.sleep(1)

  incept_sess.close()


if __name__ == '__main__':
  import argparse
  import glob
  import sys

  parser = argparse.ArgumentParser()

  parser.add_argument('mode', type=str, choices=['train', 'preview', 'incept', 'infer'])
  parser.add_argument('train_dir', type=str,
      help='Training directory')

  data_args = parser.add_argument_group('Data')
  data_args.add_argument('--data_dir', type=str,
      help='Data directory containing *only* audio files to load')
  data_args.add_argument('--data_sample_rate', type=int,
      help='Number of audio samples per second')
  data_args.add_argument('--data_slice_len', type=int, choices=[16384, 32768, 65536],
      help='Number of audio samples per slice (maximum generation length)')
  data_args.add_argument('--data_num_channels', type=int,
      help='Number of audio channels to generate (for >2, must match that of data)')
  data_args.add_argument('--data_overlap_ratio', type=float,
      help='Overlap ratio [0, 1) between slices')
  data_args.add_argument('--data_first_slice', action='store_true', dest='data_first_slice',
      help='If set, only use the first slice from each audio example')
  data_args.add_argument('--data_pad_end', action='store_true', dest='data_pad_end',
      help='If set, use zero-padded partial slices from the end of each audio file')
  data_args.add_argument('--data_normalize', action='store_true', dest='data_normalize',
      help='If set, normalize the training examples')
  data_args.add_argument('--data_fast_wav', action='store_true', dest='data_fast_wav',
      help='If your data is comprised of standard WAV files (16-bit signed PCM or 32-bit float), use this flag to decode audio using scipy (faster) instead of librosa')
  data_args.add_argument('--data_prefetch_gpu_num', type=int,
      help='If nonnegative, prefetch examples to this GPU (Tensorflow device num)')

  wavegan_args = parser.add_argument_group('WaveGAN')
  wavegan_args.add_argument('--wavegan_latent_dim', type=int,
      help='Number of dimensions of the latent space')
  wavegan_args.add_argument('--wavegan_kernel_len', type=int,
      help='Length of 1D filter kernels')
  wavegan_args.add_argument('--wavegan_dim', type=int,
      help='Dimensionality multiplier for model of G and D')
  wavegan_args.add_argument('--num_categ', type=int,
      help='Number of categorical variables')
  wavegan_args.add_argument('--wavegan_batchnorm', action='store_true', dest='wavegan_batchnorm',
      help='Enable batchnorm')
  wavegan_args.add_argument('--wavegan_disc_nupdates', type=int,
      help='Number of discriminator updates per generator update')
  wavegan_args.add_argument('--wavegan_loss', type=str, choices=['dcgan', 'lsgan', 'wgan', 'wgan-gp'],
      help='Which GAN loss to use')
  wavegan_args.add_argument('--wavegan_genr_upsample', type=str, choices=['zeros', 'nn'],
      help='Generator upsample strategy')
  wavegan_args.add_argument('--wavegan_genr_pp', action='store_true', dest='wavegan_genr_pp',
      help='If set, use post-processing filter')
  wavegan_args.add_argument('--wavegan_genr_pp_len', type=int,
      help='Length of post-processing filter for DCGAN')
  wavegan_args.add_argument('--wavegan_disc_phaseshuffle', type=int,
      help='Radius of phase shuffle operation')

  train_args = parser.add_argument_group('Train')
  train_args.add_argument('--train_batch_size', type=int,
      help='Batch size')
  train_args.add_argument('--train_save_secs', type=int,
      help='How often to save model')
  train_args.add_argument('--train_summary_secs', type=int,
      help='How often to report summaries')

  preview_args = parser.add_argument_group('Preview')
  preview_args.add_argument('--preview_n', type=int,
      help='Number of samples to preview')

  incept_args = parser.add_argument_group('Incept')
  incept_args.add_argument('--incept_metagraph_fp', type=str,
      help='Inference model for inception score')
  incept_args.add_argument('--incept_ckpt_fp', type=str,
      help='Checkpoint for inference model')
  incept_args.add_argument('--incept_n', type=int,
      help='Number of generated examples to test')
  incept_args.add_argument('--incept_k', type=int,
      help='Number of groups to test')

  parser.set_defaults(
      data_dir=None,
      data_sample_rate=16000,
      data_slice_len=16384,
      data_num_channels=1,
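      # num_categ sets the size of the one-hot code; the remaining defaults
      # match the WaveGAN reference implementation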
      data_overlap_ratio=0.,
      data_first_slice=False,
      data_pad_end=False,
      data_normalize=False,
      data_fast_wav=False,
      data_prefetch_gpu_num=0,
      wavegan_latent_dim=100,
      wavegan_kernel_len=25,
      wavegan_dim=64,
      num_categ=10,
      wavegan_batchnorm=False,
      wavegan_disc_nupdates=5,
      wavegan_loss='wgan-gp',
      wavegan_genr_upsample='zeros',
      wavegan_genr_pp=False,
      wavegan_genr_pp_len=512,
      wavegan_disc_phaseshuffle=2,
      train_batch_size=64,
      train_save_secs=300,
      train_summary_secs=120,
      preview_n=32,
      incept_metagraph_fp='./eval/inception/infer.meta',
      incept_ckpt_fp='./eval/inception/best_acc-103005',
      incept_n=5000,
      incept_k=10)

  args = parser.parse_args()

  # Make train dir
  if not os.path.isdir(args.train_dir):
    os.makedirs(args.train_dir)

  # Save args
  with open(os.path.join(args.train_dir, 'args.txt'), 'w') as f:
    f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])]))

  # Make model kwarg dicts
  setattr(args, 'wavegan_g_kwargs', {
      'slice_len': args.data_slice_len,
      'nch': args.data_num_channels,
      'kernel_len': args.wavegan_kernel_len,
      'dim': args.wavegan_dim,
      'use_batchnorm': args.wavegan_batchnorm,
      'upsample': args.wavegan_genr_upsample
  })
  setattr(args, 'wavegan_d_kwargs', {
      'kernel_len': args.wavegan_kernel_len,
      'dim': args.wavegan_dim,
      'use_batchnorm': args.wavegan_batchnorm,
      'phaseshuffle_rad': args.wavegan_disc_phaseshuffle
  })
  setattr(args, 'wavegan_q_kwargs', {
      'kernel_len': args.wavegan_kernel_len,
      'dim': args.wavegan_dim,
      'use_batchnorm': args.wavegan_batchnorm,
      'phaseshuffle_rad': args.wavegan_disc_phaseshuffle,
      'num_categ': args.num_categ
  })

  if args.mode == 'train':
    fps = glob.glob(os.path.join(args.data_dir, '*'))
    if len(fps) == 0:
      raise Exception('Did not find any audio files in specified directory')
    print('Found {} audio files in specified directory'.format(len(fps)))
    # Export the inference graph first so preview/incept can use it
    infer(args)
    train(fps, args)
  elif args.mode == 'preview':
    preview(args)
  elif args.mode == 'incept':
    incept(args)
  elif args.mode == 'infer':
    infer(args)
  else:
    raise NotImplementedError()
--------------------------------------------------------------------------------
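Once `infer.meta` has been exported, generated outputs can be pulled for any hand-set latent code, which is how the one-hot manipulations described above are probed. The following is a minimal sketch, assuming the default `wavegan_latent_dim=100` and `num_categ=10`; the training directory, checkpoint step, and output file name are illustrative, not part of the repository:

import numpy as np
import tensorflow as tf
from scipy.io.wavfile import write as wavwrite

tf.reset_default_graph()
saver = tf.train.import_meta_graph('./train/infer/infer.meta')
graph = tf.get_default_graph()

latent_dim, num_categ = 100, 10  # must match the training flags

# Build one latent vector the same way random_z() does: a one-hot code in the
# first num_categ dimensions, uniform noise in the rest
_z = np.zeros((1, latent_dim), dtype=np.float32)
_z[0, 2] = 1.  # select the third categorical code
_z[0, num_categ:] = np.random.uniform(-1., 1., latent_dim - num_categ)

with tf.Session() as sess:
  saver.restore(sess, './train/model.ckpt-19244')  # illustrative checkpoint
  _G_z = sess.run(graph.get_tensor_by_name('G_z_int16:0'),
                  {graph.get_tensor_by_name('z:0'): _z})

# Write the 16-bit PCM output at the default 16 kHz sample rate
wavwrite('category_2.wav', 16000, _G_z[0, :, 0])

Varying which of the first num_categ entries is set to 1 (with the noise dimensions held fixed) is the manipulation that surfaces the lexical item each code has come to encode.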