├── static
│   ├── arch.png
│   ├── fiwgan_arch.png
│   ├── github19244c_2_5words.wav
│   └── start_20026_222_ch59_seed436.wav
├── backup.py
├── LICENSE.txt
├── README.md
├── loader.py
├── cinfowavegan.py
├── train_fiwgan.py
└── train_ciwgan.py
/static/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbegus/fiwGAN-ciwGAN/HEAD/static/arch.png
--------------------------------------------------------------------------------
/static/fiwgan_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbegus/fiwGAN-ciwGAN/HEAD/static/fiwgan_arch.png
--------------------------------------------------------------------------------
/static/github19244c_2_5words.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbegus/fiwGAN-ciwGAN/HEAD/static/github19244c_2_5words.wav
--------------------------------------------------------------------------------
/static/start_20026_222_ch59_seed436.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gbegus/fiwGAN-ciwGAN/HEAD/static/start_20026_222_ch59_seed436.wav
--------------------------------------------------------------------------------
/backup.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
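   |
   | # Periodically copies the files of the latest TensorFlow checkpoint in
   | # <train_dir> to <train_dir>/backup. Usage sketch (an illustration):
   | #   python backup.py <train_dir> <interval_in_minutes>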
2 |
3 | if __name__ == '__main__':
4 | import glob
5 | import os
6 | import shutil
7 | import sys
8 | import time
9 |
10 | import tensorflow as tf
11 |
12 | train_dir, nmin = sys.argv[1:3]
13 | nsec = int(float(nmin) * 60.)
14 |
15 | backup_dir = os.path.join(train_dir, 'backup')
16 |
17 | if not os.path.exists(backup_dir):
18 | os.makedirs(backup_dir)
19 |
20 | while tf.train.latest_checkpoint(train_dir) is None:
21 | print('Waiting for first checkpoint')
22 | time.sleep(1)
23 |
24 | while True:
25 | latest_ckpt = tf.train.latest_checkpoint(train_dir)
26 |
27 | # Sleep for two seconds in case files are still being flushed
28 | time.sleep(2)
29 |
30 | for fp in glob.glob(latest_ckpt + '*'):
31 | _, name = os.path.split(fp)
32 | backup_fp = os.path.join(backup_dir, name)
33 | print('{}->{}'.format(fp, backup_fp))
34 | shutil.copyfile(fp, backup_fp)
35 | print('-' * 80)
36 |
37 | # Sleep for the requested backup interval
38 | time.sleep(nsec)
39 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Christopher Donahue
4 | Additions and modifications (fiwgan.py & ciwgan.py) (c) 2020 Gasper Begus
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # fiwGAN (Featural InfoWaveGAN): Lexical Learning in Generative Adversarial Phonology
2 |
3 | Paper: https://www.sciencedirect.com/science/article/pii/S0893608021001052
4 |
5 | Implemented in `train_fiwgan.py`: Featural InfoWaveGAN (fiwGAN), an architecture for modeling lexical learning from raw acoustic inputs that combines the Deep Convolutional GAN architecture for audio data (WaveGAN) with the latent categorical variables of the information-theoretic proposal InfoGAN. Unlike InfoGAN, the latent code is binomially distributed (featural) and training is performed with sigmoid cross-entropy. Based on [WaveGAN](https://github.com/chrisdonahue/wavegan) (Donahue et al. 2019) and InfoGAN (Chen et al. 2016), partially also on code by [Rodionov](https://github.com/singnet/semantic-vision/blob/master/experiments/concept_learning/gans/info-wgan-gp/10_originfo_sepQ_v2_lr1e-3/train.py) (2018).
6 |
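   | A minimal sketch of the latent construction and the Q objective, mirroring `train_fiwgan.py` (`batch_size`, `num_categ`, and `latent_dim` stand for the corresponding flags; `Q_G_z` is the Q-network output on generated audio):
   |
   | ```
   | categ = tf.keras.backend.random_binomial([batch_size, num_categ], 0.5)
   | uniform = tf.random_uniform([batch_size, latent_dim - num_categ], -1., 1.)
   | z = tf.concat([categ, uniform], 1)
   |
   | # Q learns to recover each binary feature with an independent sigmoid
   | q_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
   |     labels=z[:, :num_categ], logits=Q_G_z[:, :num_categ]))
   | ```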
7 |
8 |
9 |
10 | # ciwGAN (Categorical InfoWaveGAN)
11 |
12 | Implemented in `train_ciwgan.py`: Categorical InfoWaveGAN (ciwGAN), an architecture for modeling lexical learning from raw acoustic inputs that combines the Deep Convolutional GAN architecture for audio data (WaveGAN) with the latent categorical variables of the information-theoretic proposal InfoGAN.
13 |
14 | Based on [WaveGAN](https://github.com/chrisdonahue/wavegan) (Donahue et al. 2019) and a WGAN-GP implementation of InfoGAN by [Sergey Rodionov](https://github.com/singnet/semantic-vision/blob/master/experiments/concept_learning/gans/info-wgan-gp/10_originfo_sepQ_v2_lr1e-3/train.py) (2018).
15 |
16 | In addition to the Generator and Discriminator networks, the architecture introduces a network (Q) that learns to classify generated outputs and thereby forces the Generator to encode lexical information in its latent space. Lexical and semantic encoding is represented with a set of binary categorical variables. The network is trained on five lexical items from TIMIT; it learns to generate the lexical items and encodes the identity of each item in the categorical variables of the latent space. By manipulating these categorical variables, the network outputs the five lexical items, suggesting that each item is represented with a unique categorical code. Such a representation can serve as the basis for lexical and semantic learning from raw acoustic input.
17 |
18 |
19 |
20 | After 19,244 training steps on _oily, water, rag, suit_ and _year_ from TIMIT, the network learns to output lexical items based on the latent code. The following outputs are generated with these values of c:
21 |
22 | 1. \[1, 0, 0, 0, 0\]: _suit_
23 | 2. \[0, 1, 0, 0, 0\]: _year_
24 | 3. \[0, 0, 1, 0, 0\]: _water_
25 | 4. \[0, 0, 0, 1, 0\]: _oily_
26 | 5. \[0, 0, 0, 0, 1\]: _rag_
27 |
28 | [Audio sample 1](http://faculty.washington.edu/begus/files/github19244c_2_5words.wav)
29 | [Audio sample 2](static/start_20026_222_ch59_seed436.wav)
30 |
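   | For example, to generate _water_ (code 3 above) from a trained checkpoint, fix c and sample the remaining latent dimensions. A minimal sketch using the exported inference graph (checkpoint paths and the step number are placeholders; assumes training with `--num_categ 5` and the default 100-dimensional latent space):
   |
   | ```
   | import numpy as np
   | import tensorflow as tf
   | from scipy.io.wavfile import write as wavwrite
   |
   | tf.reset_default_graph()
   | saver = tf.train.import_meta_graph('train/infer/infer.meta')
   | graph = tf.get_default_graph()
   | sess = tf.InteractiveSession()
   | saver.restore(sess, 'train/model.ckpt-19244')
   |
   | _z = np.random.uniform(-1., 1., (1, 100)).astype(np.float32)
   | _z[0, :5] = [0., 0., 1., 0., 0.]  # c encoding 'water'
   | _G_z = sess.run(graph.get_tensor_by_name('G_z:0'),
   |                 {graph.get_tensor_by_name('z:0'): _z})
   | wavwrite('water.wav', 16000, _G_z[0, :, 0])
   | ```
   |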
31 | To change the number of categorical latent variables:
32 |
33 | ```
34 | --num_categ n
35 | ```
36 |
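   | A training run might then be launched as follows (a sketch; the flags are defined in `train_fiwgan.py`, and the data directory is a placeholder):
   |
   | ```
   | python train_fiwgan.py train ./train --data_dir ./data --num_categ 5 --data_fast_wav
   | ```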
--------------------------------------------------------------------------------
/loader.py:
--------------------------------------------------------------------------------
1 | from scipy.io.wavfile import read as wavread
2 | import numpy as np
3 |
4 | import tensorflow as tf
5 |
6 | import sys
7 |
8 |
9 | def decode_audio(fp, fs=None, num_channels=1, normalize=False, fast_wav=False):
10 | """Decodes audio file paths into 32-bit floating point vectors.
11 |
12 | Args:
13 | fp: Audio file path.
14 | fs: If specified, resamples decoded audio to this rate.
15 | num_channels: Desired number of channels (1 mixes to mono; 2 duplicates mono).
   | normalize: If true, rescales audio so that the peak amplitude is 1.
16 | fast_wav: Assume fp is a standard WAV file (PCM 16-bit or float 32-bit).
17 |
18 | Returns:
19 | A np.float32 array containing the audio samples at specified sample rate.
20 | """
21 | if fast_wav:
22 | # Read with scipy wavread (fast).
23 | _fs, _wav = wavread(fp)
24 | if fs is not None and fs != _fs:
25 | raise NotImplementedError('Scipy cannot resample audio.')
26 | if _wav.dtype == np.int16:
27 | _wav = _wav.astype(np.float32)
28 | _wav /= 32768.
29 | elif _wav.dtype == np.float32:
30 | _wav = np.copy(_wav)
31 | else:
32 | raise NotImplementedError('Scipy cannot process atypical WAV files.')
33 | else:
34 | # Decode with librosa load (slow but supports file formats like mp3).
35 | import librosa
36 | _wav, _fs = librosa.core.load(fp, sr=fs, mono=False)
37 | if _wav.ndim == 2:
38 | _wav = np.swapaxes(_wav, 0, 1)
39 |
40 | assert _wav.dtype == np.float32
41 |
42 | # At this point, _wav is np.float32 either [nsamps,] or [nsamps, nch].
43 | # We want [nsamps, 1, nch] to mimic 2D shape of spectral feats.
44 | if _wav.ndim == 1:
45 | nsamps = _wav.shape[0]
46 | nch = 1
47 | else:
48 | nsamps, nch = _wav.shape
49 | _wav = np.reshape(_wav, [nsamps, 1, nch])
50 |
51 | # Average (mono) or expand (stereo) channels
52 | if nch != num_channels:
53 | if num_channels == 1:
54 | _wav = np.mean(_wav, 2, keepdims=True)
55 | elif nch == 1 and num_channels == 2:
56 | _wav = np.concatenate([_wav, _wav], axis=2)
57 | else:
58 | raise ValueError('Number of audio channels not equal to num specified')
59 |
60 | if normalize:
61 | factor = np.max(np.abs(_wav))
62 | if factor > 0:
63 | _wav /= factor
64 |
65 | return _wav
66 |
67 |
68 | def decode_extract_and_batch(
69 | fps,
70 | batch_size,
71 | slice_len,
72 | decode_fs,
73 | decode_num_channels,
74 | decode_normalize=True,
75 | decode_fast_wav=False,
76 | decode_parallel_calls=1,
77 | slice_randomize_offset=False,
78 | slice_first_only=False,
79 | slice_overlap_ratio=0,
80 | slice_pad_end=False,
81 | repeat=False,
82 | shuffle=False,
83 | shuffle_buffer_size=None,
84 | prefetch_size=None,
85 | prefetch_gpu_num=None):
86 | """Decodes audio file paths into mini-batches of samples.
87 |
88 | Args:
89 | fps: List of audio file paths.
90 | batch_size: Number of items in the batch.
91 | slice_len: Length of the slices in samples or feature timesteps.
92 | decode_fs: (Re-)sample rate for decoded audio files.
93 | decode_num_channels: Number of channels for decoded audio files.
94 | decode_normalize: If false, do not normalize audio waveforms.
95 | decode_fast_wav: If true, uses scipy to decode standard wav files.
96 | decode_parallel_calls: Number of parallel decoding threads.
97 | slice_randomize_offset: If true, randomize starting position for slice.
98 | slice_first_only: If true, only use first slice from each audio file.
99 | slice_overlap_ratio: Ratio of overlap between adjacent slices.
100 | slice_pad_end: If true, allows zero-padded examples from the end of each audio file.
101 | repeat: If true (for training), continuously iterate through the dataset.
102 | shuffle: If true (for training), buffer and shuffle the slices.
103 | shuffle_buffer_size: Number of examples to queue up before grabbing a batch.
104 | prefetch_size: Number of examples to prefetch from the queue.
105 | prefetch_gpu_num: If specified, prefetch examples to GPU.
106 |
107 | Returns:
108 | A float32 tensor of batched audio waveforms:
109 | audio: [batch_size, slice_len, 1, nch]
110 | """
111 | # Create dataset of filepaths
112 | dataset = tf.data.Dataset.from_tensor_slices(fps)
113 |
114 | # Shuffle all filepaths every epoch
115 | if shuffle:
116 | dataset = dataset.shuffle(buffer_size=len(fps))
117 |
118 | # Repeat
119 | if repeat:
120 | dataset = dataset.repeat()
121 |
122 | def _decode_audio_shaped(fp):
123 | _decode_audio_closure = lambda _fp: decode_audio(
124 | _fp,
125 | fs=decode_fs,
126 | num_channels=decode_num_channels,
127 | normalize=decode_normalize,
128 | fast_wav=decode_fast_wav)
129 |
130 | audio = tf.py_func(
131 | _decode_audio_closure,
132 | [fp],
133 | tf.float32,
134 | stateful=False)
135 | audio.set_shape([None, 1, decode_num_channels])
136 |
137 | return audio
138 |
139 | # Decode audio
140 | dataset = dataset.map(
141 | _decode_audio_shaped,
142 | num_parallel_calls=decode_parallel_calls)
143 |
144 | # Slice each decoded clip into fixed-length examples
145 | def _slice(audio):
146 | # Calculate hop size
147 | if slice_overlap_ratio < 0:
148 | raise ValueError('Overlap ratio must be non-negative')
149 | slice_hop = int(round(slice_len * (1. - slice_overlap_ratio)) + 1e-4)
150 | if slice_hop < 1:
151 | raise ValueError('Overlap ratio too high')
152 |
153 | # Randomize starting phase:
154 | if slice_randomize_offset:
155 | start = tf.random_uniform([], maxval=slice_len, dtype=tf.int32)
156 | audio = audio[start:]
157 |
158 | # Extract slices
159 | audio_slices = tf.contrib.signal.frame(
160 | audio,
161 | slice_len,
162 | slice_hop,
163 | pad_end=slice_pad_end,
164 | pad_value=0,
165 | axis=0)
166 |
167 | # Only use first slice if requested
168 | if slice_first_only:
169 | audio_slices = audio_slices[:1]
170 |
171 | return audio_slices
172 |
173 | def _slice_dataset_wrapper(audio):
174 | audio_slices = _slice(audio)
175 | return tf.data.Dataset.from_tensor_slices(audio_slices)
176 |
177 | # Extract slices from each decoded audio clip
178 | dataset = dataset.flat_map(_slice_dataset_wrapper)
179 |
180 | # Shuffle examples
181 | if shuffle:
182 | dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
183 |
184 | # Make batches
185 | dataset = dataset.batch(batch_size, drop_remainder=True)
186 |
187 | # Prefetch a number of batches
188 | if prefetch_size is not None:
189 | dataset = dataset.prefetch(prefetch_size)
190 | if prefetch_gpu_num is not None and prefetch_gpu_num >= 0:
191 | dataset = dataset.apply(
192 | tf.data.experimental.prefetch_to_device(
193 | '/device:GPU:{}'.format(prefetch_gpu_num)))
194 |
195 | # Get tensors
196 | iterator = dataset.make_one_shot_iterator()
197 |
198 | return iterator.get_next()
199 |
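   |
   | # Minimal usage sketch (an illustration; assumes TF 1.x and a hypothetical
   | # directory of 16 kHz WAV files):
   | if __name__ == '__main__':
   |     import glob
   |     fps = glob.glob('./data/*.wav')
   |     x = decode_extract_and_batch(
   |         fps, batch_size=64, slice_len=16384, decode_fs=16000,
   |         decode_num_channels=1, decode_fast_wav=True, repeat=True,
   |         shuffle=True, shuffle_buffer_size=4096, prefetch_size=256)
   |     with tf.Session() as sess:
   |         print(sess.run(x).shape)  # (64, 16384, 1, 1)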
--------------------------------------------------------------------------------
/cinfowavegan.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def conv1d_transpose(
5 | inputs,
6 | filters,
7 | kernel_width,
8 | stride=4,
9 | padding='same',
10 | upsample='zeros'):
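   | # Upsamples the time axis by `stride`: 'zeros' uses a strided transposed
   | # convolution; 'nn' uses nearest-neighbor upsampling followed by conv1d.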
11 | if upsample == 'zeros':
12 | return tf.layers.conv2d_transpose(
13 | tf.expand_dims(inputs, axis=1),
14 | filters,
15 | (1, kernel_width),
16 | strides=(1, stride),
17 | padding=padding
18 | )[:, 0]
19 | elif upsample == 'nn':
20 | batch_size = tf.shape(inputs)[0]
21 | _, w, nch = inputs.get_shape().as_list()
22 |
23 | x = inputs
24 |
25 | x = tf.expand_dims(x, axis=1)
26 | x = tf.image.resize_nearest_neighbor(x, [1, w * stride])
27 | x = x[:, 0]
28 |
29 | return tf.layers.conv1d(
30 | x,
31 | filters,
32 | kernel_width,
33 | 1,
34 | padding=padding)
35 | else:
36 | raise NotImplementedError
37 |
38 |
39 | """
40 | Input: [None, 100]
41 | Output: [None, slice_len, 1]
42 | """
43 | def WaveGANGenerator(
44 | z,
45 | slice_len=16384,
46 | nch=1,
47 | kernel_len=25,
48 | dim=64,
49 | use_batchnorm=False,
50 | upsample='zeros',
51 | train=False):
52 | assert slice_len in [16384, 32768, 65536]
53 | batch_size = tf.shape(z)[0]
54 |
55 | if use_batchnorm:
56 | batchnorm = lambda x: tf.layers.batch_normalization(x, training=train)
57 | else:
58 | batchnorm = lambda x: x
59 |
60 | # FC and reshape for convolution
61 | # [100] -> [16, 1024]
62 | dim_mul = 16 if slice_len == 16384 else 32
63 | output = z
64 | with tf.variable_scope('z_project'):
65 | output = tf.layers.dense(output, 4 * 4 * dim * dim_mul)
66 | output = tf.reshape(output, [batch_size, 16, dim * dim_mul])
67 | output = batchnorm(output)
68 | output = tf.nn.relu(output)
69 | dim_mul //= 2
70 |
71 | # Layer 0
72 | # [16, 1024] -> [64, 512]
73 | with tf.variable_scope('upconv_0'):
74 | output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample)
75 | output = batchnorm(output)
76 | output = tf.nn.relu(output)
77 | dim_mul //= 2
78 |
79 | # Layer 1
80 | # [64, 512] -> [256, 256]
81 | with tf.variable_scope('upconv_1'):
82 | output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample)
83 | output = batchnorm(output)
84 | output = tf.nn.relu(output)
85 | dim_mul //= 2
86 |
87 | # Layer 2
88 | # [256, 256] -> [1024, 128]
89 | with tf.variable_scope('upconv_2'):
90 | output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample)
91 | output = batchnorm(output)
92 | output = tf.nn.relu(output)
93 | dim_mul //= 2
94 |
95 | # Layer 3
96 | # [1024, 128] -> [4096, 64]
97 | with tf.variable_scope('upconv_3'):
98 | output = conv1d_transpose(output, dim * dim_mul, kernel_len, 4, upsample=upsample)
99 | output = batchnorm(output)
100 | output = tf.nn.relu(output)
101 |
102 | if slice_len == 16384:
103 | # Layer 4
104 | # [4096, 64] -> [16384, nch]
105 | with tf.variable_scope('upconv_4'):
106 | output = conv1d_transpose(output, nch, kernel_len, 4, upsample=upsample)
107 | output = tf.nn.tanh(output)
108 | elif slice_len == 32768:
109 | # Layer 4
110 | # [4096, 128] -> [16384, 64]
111 | with tf.variable_scope('upconv_4'):
112 | output = conv1d_transpose(output, dim, kernel_len, 4, upsample=upsample)
113 | output = batchnorm(output)
114 | output = tf.nn.relu(output)
115 |
116 | # Layer 5
117 | # [16384, 64] -> [32768, nch]
118 | with tf.variable_scope('upconv_5'):
119 | output = conv1d_transpose(output, nch, kernel_len, 2, upsample=upsample)
120 | output = tf.nn.tanh(output)
121 | elif slice_len == 65536:
122 | # Layer 4
123 | # [4096, 128] -> [16384, 64]
124 | with tf.variable_scope('upconv_4'):
125 | output = conv1d_transpose(output, dim, kernel_len, 4, upsample=upsample)
126 | output = batchnorm(output)
127 | output = tf.nn.relu(output)
128 |
129 | # Layer 5
130 | # [16384, 64] -> [65536, nch]
131 | with tf.variable_scope('upconv_5'):
132 | output = conv1d_transpose(output, nch, kernel_len, 4, upsample=upsample)
133 | output = tf.nn.tanh(output)
134 |
135 | # Automatically update batchnorm moving averages every time G is used during training
136 | if train and use_batchnorm:
137 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
138 | if slice_len == 16384:
139 | assert len(update_ops) == 10
140 | else:
141 | assert len(update_ops) == 12
142 | with tf.control_dependencies(update_ops):
143 | output = tf.identity(output)
144 |
145 | return output
146 |
147 |
148 | def lrelu(inputs, alpha=0.2):
149 | return tf.maximum(alpha * inputs, inputs)
150 |
151 |
152 | def apply_phaseshuffle(x, rad, pad_type='reflect'):
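   | # Shifts the batch in time by a random offset in [-rad, rad] (drawn once per
   | # step), padding (reflection by default) and cropping to keep the length.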
153 | b, x_len, nch = x.get_shape().as_list()
154 |
155 | phase = tf.random_uniform([], minval=-rad, maxval=rad + 1, dtype=tf.int32)
156 | pad_l = tf.maximum(phase, 0)
157 | pad_r = tf.maximum(-phase, 0)
158 | phase_start = pad_r
159 | x = tf.pad(x, [[0, 0], [pad_l, pad_r], [0, 0]], mode=pad_type)
160 |
161 | x = x[:, phase_start:phase_start+x_len]
162 | x.set_shape([b, x_len, nch])
163 |
164 | return x
165 |
166 |
167 | """
168 | Input: [None, slice_len, nch]
169 | Output: [None] (linear output)
170 | """
171 | def WaveGANDiscriminator(
172 | x,
173 | kernel_len=25,
174 | dim=64,
175 | use_batchnorm=False,
176 | phaseshuffle_rad=0):
177 | batch_size = tf.shape(x)[0]
178 | slice_len = int(x.get_shape()[1])
179 |
180 | if use_batchnorm:
181 | batchnorm = lambda x: tf.layers.batch_normalization(x, training=True)
182 | else:
183 | batchnorm = lambda x: x
184 |
185 | if phaseshuffle_rad > 0:
186 | phaseshuffle = lambda x: apply_phaseshuffle(x, phaseshuffle_rad)
187 | else:
188 | phaseshuffle = lambda x: x
189 |
190 | # Layer 0
191 | # [16384, 1] -> [4096, 64]
192 | output = x
193 | with tf.variable_scope('downconv_0'):
194 | output = tf.layers.conv1d(output, dim, kernel_len, 4, padding='SAME')
195 | output = lrelu(output)
196 | output = phaseshuffle(output)
197 |
198 | # Layer 1
199 | # [4096, 64] -> [1024, 128]
200 | with tf.variable_scope('downconv_1'):
201 | output = tf.layers.conv1d(output, dim * 2, kernel_len, 4, padding='SAME')
202 | output = batchnorm(output)
203 | output = lrelu(output)
204 | output = phaseshuffle(output)
205 |
206 | # Layer 2
207 | # [1024, 128] -> [256, 256]
208 | with tf.variable_scope('downconv_2'):
209 | output = tf.layers.conv1d(output, dim * 4, kernel_len, 4, padding='SAME')
210 | output = batchnorm(output)
211 | output = lrelu(output)
212 | output = phaseshuffle(output)
213 |
214 | # Layer 3
215 | # [256, 256] -> [64, 512]
216 | with tf.variable_scope('downconv_3'):
217 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 4, padding='SAME')
218 | output = batchnorm(output)
219 | output = lrelu(output)
220 | output = phaseshuffle(output)
221 |
222 | # Layer 4
223 | # [64, 512] -> [16, 1024]
224 | with tf.variable_scope('downconv_4'):
225 | output = tf.layers.conv1d(output, dim * 16, kernel_len, 4, padding='SAME')
226 | output = batchnorm(output)
227 | output = lrelu(output)
228 |
229 | if slice_len == 32768:
230 | # Layer 5
231 | # [32, 1024] -> [16, 2048]
232 | with tf.variable_scope('downconv_5'):
233 | output = tf.layers.conv1d(output, dim * 32, kernel_len, 2, padding='SAME')
234 | output = batchnorm(output)
235 | output = lrelu(output)
236 | elif slice_len == 65536:
237 | # Layer 5
238 | # [64, 1024] -> [16, 2048]
239 | with tf.variable_scope('downconv_5'):
240 | output = tf.layers.conv1d(output, dim * 32, kernel_len, 4, padding='SAME')
241 | output = batchnorm(output)
242 | output = lrelu(output)
243 |
244 | # Flatten
245 | output = tf.reshape(output, [batch_size, -1])
246 |
247 | # Connect to single logit
248 | with tf.variable_scope('output'):
249 | output = tf.layers.dense(output, 1)[:, 0]
250 |
251 | # Don't need to aggregate batchnorm update ops like we do for the generator because we only use the discriminator for training
252 |
253 | return output
254 |
255 |
256 |
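   | """
   | Input: [None, slice_len, nch]
   | Output: [None, num_categ] (Q-network logits for the latent code)
   | """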
257 | def WaveGANQ(
258 | x,
259 | kernel_len=25,
260 | dim=64,
261 | use_batchnorm=False,
262 | phaseshuffle_rad=0,
263 | num_categ=10):
264 | batch_size = tf.shape(x)[0]
265 | slice_len = int(x.get_shape()[1])
266 |
267 | if use_batchnorm:
268 | batchnorm = lambda x: tf.layers.batch_normalization(x, training=True)
269 | else:
270 | batchnorm = lambda x: x
271 |
272 | if phaseshuffle_rad > 0:
273 | phaseshuffle = lambda x: apply_phaseshuffle(x, phaseshuffle_rad)
274 | else:
275 | phaseshuffle = lambda x: x
276 |
277 | # Layer 0
278 | # [16384, 1] -> [4096, 64]
279 | output = x
280 | with tf.variable_scope('Qdownconv_0'):
281 | output = tf.layers.conv1d(output, dim, kernel_len, 4, padding='SAME')
282 | output = lrelu(output)
283 | output = phaseshuffle(output)
284 |
285 | # Layer 1
286 | # [4096, 64] -> [1024, 128]
287 | with tf.variable_scope('Qdownconv_1'):
288 | output = tf.layers.conv1d(output, dim * 2, kernel_len, 4, padding='SAME')
289 | output = batchnorm(output)
290 | output = lrelu(output)
291 | output = phaseshuffle(output)
292 |
293 | # Layer 2
294 | # [1024, 128] -> [256, 256]
295 | with tf.variable_scope('Qdownconv_2'):
296 | output = tf.layers.conv1d(output, dim * 4, kernel_len, 4, padding='SAME')
297 | output = batchnorm(output)
298 | output = lrelu(output)
299 | output = phaseshuffle(output)
300 |
301 | # Layer 3
302 | # [256, 256] -> [64, 512]
303 | with tf.variable_scope('Qdownconv_3'):
304 | output = tf.layers.conv1d(output, dim * 8, kernel_len, 4, padding='SAME')
305 | output = batchnorm(output)
306 | output = lrelu(output)
307 | output = phaseshuffle(output)
308 |
309 | # Layer 4
310 | # [64, 512] -> [16, 1024]
311 | with tf.variable_scope('Qdownconv_4'):
312 | output = tf.layers.conv1d(output, dim * 16, kernel_len, 4, padding='SAME')
313 | output = batchnorm(output)
314 | output = lrelu(output)
315 |
316 | if slice_len == 32768:
317 | # Layer 5
318 | # [32, 1024] -> [16, 2048]
319 | with tf.variable_scope('Qdownconv_5'):
320 | output = tf.layers.conv1d(output, dim * 32, kernel_len, 2, padding='SAME')
321 | output = batchnorm(output)
322 | output = lrelu(output)
323 | elif slice_len == 65536:
324 | # Layer 5
325 | # [64, 1024] -> [16, 2048]
326 | with tf.variable_scope('Qdownconv_5'):
327 | output = tf.layers.conv1d(output, dim * 32, kernel_len, 4, padding='SAME')
328 | output = batchnorm(output)
329 | output = lrelu(output)
330 |
331 | # Flatten
332 | output = tf.reshape(output, [batch_size, -1])
333 |
334 | # Connect to num_categ logits (one per latent code variable)
335 | with tf.variable_scope('Qoutput'):
336 | Qoutput = tf.layers.dense(output, num_categ)
337 |
338 |
339 | # Don't need to aggregate batchnorm update ops like we do for the generator because the Q network is only used during training
340 |
341 | return Qoutput
342 |
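   |
   | # Shape-check sketch (an illustration, not part of the training pipeline):
   | if __name__ == '__main__':
   |     tf.reset_default_graph()
   |     z = tf.random_uniform([4, 100], -1., 1.)
   |     with tf.variable_scope('G'):
   |         G_z = WaveGANGenerator(z)              # [4, 16384, 1]
   |     with tf.variable_scope('D'):
   |         D_G_z = WaveGANDiscriminator(G_z)      # [4]
   |     with tf.variable_scope('Q'):
   |         Q_G_z = WaveGANQ(G_z, num_categ=5)     # [4, 5]
   |     print(G_z.shape, D_G_z.shape, Q_G_z.shape)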
--------------------------------------------------------------------------------
/train_fiwgan.py:
--------------------------------------------------------------------------------
1 | '''
2 | fiwGAN: Featural InfoWaveGAN
3 | Gasper Begus (begus@uw.edu) 2020
4 | Based on WaveGAN (Donahue et al. 2019) and InfoGAN (Chen et al. 2016), partially also on code by Rodionov (2018).
5 | Unlike InfoGAN, the latent code is binomially distributed (features) and training is performed with sigmoid cross-entropy.
6 | '''
7 |
8 | from __future__ import print_function
9 |
10 | try:
11 | import cPickle as pickle
12 | except:
13 | import pickle
14 | from functools import reduce
15 | import os
16 | import time
17 |
18 | import numpy as np
19 | import tensorflow as tf
20 | from six.moves import xrange
21 |
22 | import loader
23 | from cinfowavegan import WaveGANGenerator, WaveGANDiscriminator, WaveGANQ
24 |
25 |
26 | """
27 | Trains a WaveGAN
28 | """
29 | def train(fps, args):
30 | with tf.name_scope('loader'):
31 | x = loader.decode_extract_and_batch(
32 | fps,
33 | batch_size=args.train_batch_size,
34 | slice_len=args.data_slice_len,
35 | decode_fs=args.data_sample_rate,
36 | decode_num_channels=args.data_num_channels,
37 | decode_fast_wav=args.data_fast_wav,
38 | decode_parallel_calls=4,
39 | slice_randomize_offset=False if args.data_first_slice else True,
40 | slice_first_only=args.data_first_slice,
41 | slice_overlap_ratio=0. if args.data_first_slice else args.data_overlap_ratio,
42 | slice_pad_end=True if args.data_first_slice else args.data_pad_end,
43 | repeat=True,
44 | shuffle=True,
45 | shuffle_buffer_size=4096,
46 | prefetch_size=args.train_batch_size * 4,
47 | prefetch_gpu_num=args.data_prefetch_gpu_num)[:, :, 0]
48 |
49 | # Make z vector: binary featural code (first num_categ dims) + uniform noise
50 |
51 | categ = tf.keras.backend.random_binomial([args.train_batch_size,args.num_categ],0.5)
52 | uniform = tf.random_uniform([args.train_batch_size,args.wavegan_latent_dim-args.num_categ],-1.,1.)
53 | z = tf.concat([categ,uniform],1)
54 |
55 | # Make generator
56 | with tf.variable_scope('G'):
57 | G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
58 | if args.wavegan_genr_pp:
59 | with tf.variable_scope('pp_filt'):
60 | G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same')
61 | G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')
62 |
63 | # Print G summary
64 | print('-' * 80)
65 | print('Generator vars')
66 | nparams = 0
67 | for v in G_vars:
68 | v_shape = v.get_shape().as_list()
69 | v_n = reduce(lambda x, y: x * y, v_shape)
70 | nparams += v_n
71 | print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
72 | print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
73 |
74 | # Summarize
75 | tf.summary.audio('x', x, args.data_sample_rate)
76 | tf.summary.audio('G_z', G_z, args.data_sample_rate)
77 | G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
78 | x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
79 | tf.summary.histogram('x_rms_batch', x_rms)
80 | tf.summary.histogram('G_z_rms_batch', G_z_rms)
81 | tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
82 | tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))
83 |
84 | # Make real discriminator
85 | with tf.name_scope('D_x'), tf.variable_scope('D'):
86 | D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
87 | D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')
88 |
89 | # Print D summary
90 | print('-' * 80)
91 | print('Discriminator vars')
92 | nparams = 0
93 | for v in D_vars:
94 | v_shape = v.get_shape().as_list()
95 | v_n = reduce(lambda x, y: x * y, v_shape)
96 | nparams += v_n
97 | print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
98 | print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
99 | print('-' * 80)
100 |
101 |
102 |
103 | # Make fake discriminator
104 | with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
105 | D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)
106 |
107 | # Make Q
108 | with tf.variable_scope('Q'):
109 | Q_G_z = WaveGANQ(G_z, **args.wavegan_q_kwargs)
110 | Q_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Q')
111 |
112 | # Print Q summary
113 | print('Q vars')
114 | nparams = 0
115 | for v in Q_vars:
116 | v_shape = v.get_shape().as_list()
117 | v_n = reduce(lambda x, y: x * y, v_shape)
    | nparams += v_n
118 | print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    | print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
119 | print('-' * 80)
120 |
121 | # Create loss
122 | D_clip_weights = None
123 | if args.wavegan_loss == 'dcgan':
124 | fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
125 | real = tf.ones([args.train_batch_size], dtype=tf.float32)
126 |
127 | G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
128 | logits=D_G_z,
129 | labels=real
130 | ))
131 |
132 | D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
133 | logits=D_G_z,
134 | labels=fake
135 | ))
136 | D_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
137 | logits=D_x,
138 | labels=real
139 | ))
140 |
141 | D_loss /= 2.
142 | elif args.wavegan_loss == 'lsgan':
143 | G_loss = tf.reduce_mean((D_G_z - 1.) ** 2)
144 | D_loss = tf.reduce_mean((D_x - 1.) ** 2)
145 | D_loss += tf.reduce_mean(D_G_z ** 2)
146 | D_loss /= 2.
147 | elif args.wavegan_loss == 'wgan':
148 | G_loss = -tf.reduce_mean(D_G_z)
149 | D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)
150 |
151 | with tf.name_scope('D_clip_weights'):
152 | clip_ops = []
153 | for var in D_vars:
154 | clip_bounds = [-.01, .01]
155 | clip_ops.append(
156 | tf.assign(
157 | var,
158 | tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])
159 | )
160 | )
161 | D_clip_weights = tf.group(*clip_ops)
162 | elif args.wavegan_loss == 'wgan-gp':
163 |
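   | # InfoGAN-style Q loss: recover the binary featural code from the generated
   | # audio with sigmoid cross-entropy (one independent sigmoid per feature).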
164 | z_q_loss = z[:, : args.num_categ]
165 | q_q_loss = Q_G_z[:, : args.num_categ]
166 | q_sigmoid = tf.nn.sigmoid_cross_entropy_with_logits(labels=z_q_loss, logits=q_q_loss)
167 | G_loss = -tf.reduce_mean(D_G_z)
168 | D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)
169 | Q_loss = tf.reduce_mean(q_sigmoid)
170 |
171 | alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1], minval=0., maxval=1.)
172 | differences = G_z - x
173 | interpolates = x + (alpha * differences)
174 | with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
175 | D_interp = WaveGANDiscriminator(interpolates, **args.wavegan_d_kwargs)
176 |
177 | LAMBDA = 10
178 | gradients = tf.gradients(D_interp, [interpolates])[0]
179 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2]))
180 | gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)
181 | D_loss += LAMBDA * gradient_penalty
182 | else:
183 | raise NotImplementedError()
184 |
185 | tf.summary.scalar('G_loss', G_loss)
186 | tf.summary.scalar('D_loss', D_loss)
187 | tf.summary.scalar('Q_loss', Q_loss)  # NB: Q_loss/Q_opt are defined only for the default wgan-gp loss
188 |
189 | # Create (recommended) optimizer
190 | if args.wavegan_loss == 'dcgan':
191 | G_opt = tf.train.AdamOptimizer(
192 | learning_rate=2e-4,
193 | beta1=0.5)
194 | D_opt = tf.train.AdamOptimizer(
195 | learning_rate=2e-4,
196 | beta1=0.5)
197 | elif args.wavegan_loss == 'lsgan':
198 | G_opt = tf.train.RMSPropOptimizer(
199 | learning_rate=1e-4)
200 | D_opt = tf.train.RMSPropOptimizer(
201 | learning_rate=1e-4)
202 | elif args.wavegan_loss == 'wgan':
203 | G_opt = tf.train.RMSPropOptimizer(
204 | learning_rate=5e-5)
205 | D_opt = tf.train.RMSPropOptimizer(
206 | learning_rate=5e-5)
207 | elif args.wavegan_loss == 'wgan-gp':
208 | G_opt = tf.train.AdamOptimizer(
209 | learning_rate=1e-4,
210 | beta1=0.5,
211 | beta2=0.9)
212 | D_opt = tf.train.AdamOptimizer(
213 | learning_rate=1e-4,
214 | beta1=0.5,
215 | beta2=0.9)
216 | Q_opt = tf.train.RMSPropOptimizer(
217 | learning_rate=1e-4)
218 | else:
219 | raise NotImplementedError()
220 |
221 | # Create training ops
222 | G_train_op = G_opt.minimize(G_loss, var_list=G_vars,
223 | global_step=tf.train.get_or_create_global_step())
224 | D_train_op = D_opt.minimize(D_loss, var_list=D_vars)
225 | Q_train_op = Q_opt.minimize(Q_loss, var_list=Q_vars+G_vars)
226 |
227 | # Run training
228 | with tf.train.MonitoredTrainingSession(
229 | checkpoint_dir=args.train_dir,
230 | save_checkpoint_secs=args.train_save_secs,
231 | save_summaries_secs=args.train_summary_secs) as sess:
232 | print('-' * 80)
233 | print('Training has started. Please use \'tensorboard --logdir={}\' to monitor.'.format(args.train_dir))
234 | while True:
235 | # Train discriminator
236 | for i in xrange(args.wavegan_disc_nupdates):
237 | sess.run(D_train_op)
238 |
239 |
240 | # Enforce Lipschitz constraint for WGAN
241 | if D_clip_weights is not None:
242 | sess.run(D_clip_weights)
243 |
244 | # Train generator and Q network
245 | sess.run([G_train_op,Q_train_op])
246 |
247 |
248 | """
249 | Creates and saves a MetaGraphDef for simple inference
250 | Tensors:
251 | 'samp_z_n' int32 []: Sample this many latent vectors
252 | 'samp_z' float32 [samp_z_n, latent_dim]: Resultant latent vectors
253 | 'z:0' float32 [None, latent_dim]: Input latent vectors
254 | 'flat_pad:0' int32 []: Number of padding samples to use when flattening batch to a single audio file
255 | 'G_z:0' float32 [None, slice_len, 1]: Generated outputs
256 | 'G_z_int16:0' int16 [None, slice_len, 1]: Same as above but quantized to 16-bit PCM samples
257 | 'G_z_flat:0' float32 [None, 1]: Outputs flattened into single audio file
258 | 'G_z_flat_int16:0' int16 [None, 1]: Same as above but quantized to 16-bit PCM samples
259 | Example usage:
260 | import tensorflow as tf
261 | tf.reset_default_graph()
262 |
263 | saver = tf.train.import_meta_graph('infer.meta')
264 | graph = tf.get_default_graph()
265 | sess = tf.InteractiveSession()
266 | saver.restore(sess, 'model.ckpt-10000')
267 |
268 | z_n = graph.get_tensor_by_name('samp_z_n:0')
269 | _z = sess.run(graph.get_tensor_by_name('samp_z:0'), {z_n: 10})
270 |
271 | z = graph.get_tensor_by_name('G_z:0')
272 | _G_z = sess.run(graph.get_tensor_by_name('G_z:0'), {z: _z})
273 | """
274 | def infer(args):
275 | infer_dir = os.path.join(args.train_dir, 'infer')
276 | if not os.path.isdir(infer_dir):
277 | os.makedirs(infer_dir)
278 |
279 | # Subgraph that generates latent vectors
280 | samp_z_n = tf.placeholder(tf.int32, [], name='samp_z_n')
281 | samp_z = tf.random_uniform([samp_z_n, args.wavegan_latent_dim], -1.0, 1.0, dtype=tf.float32, name='samp_z')
282 |
283 | # Input z
284 | z = tf.placeholder(tf.float32, [None, args.wavegan_latent_dim], name='z')
285 | flat_pad = tf.placeholder(tf.int32, [], name='flat_pad')
286 |
287 | # Execute generator
288 | with tf.variable_scope('G'):
289 | G_z = WaveGANGenerator(z, train=False, **args.wavegan_g_kwargs)
290 | if args.wavegan_genr_pp:
291 | with tf.variable_scope('pp_filt'):
292 | G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same')
293 | G_z = tf.identity(G_z, name='G_z')
294 |
295 | # Flatten batch
296 | nch = int(G_z.get_shape()[-1])
297 | G_z_padded = tf.pad(G_z, [[0, 0], [0, flat_pad], [0, 0]])
298 | G_z_flat = tf.reshape(G_z_padded, [-1, nch], name='G_z_flat')
299 |
300 | # Encode to int16
301 | def float_to_int16(x, name=None):
302 | x_int16 = x * 32767.
303 | x_int16 = tf.clip_by_value(x_int16, -32767., 32767.)
304 | x_int16 = tf.cast(x_int16, tf.int16, name=name)
305 | return x_int16
306 | G_z_int16 = float_to_int16(G_z, name='G_z_int16')
307 | G_z_flat_int16 = float_to_int16(G_z_flat, name='G_z_flat_int16')
308 |
309 | # Create saver
310 | G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G')
311 | global_step = tf.train.get_or_create_global_step()
312 | saver = tf.train.Saver(G_vars + [global_step])
313 |
314 | # Export graph
315 | tf.train.write_graph(tf.get_default_graph(), infer_dir, 'infer.pbtxt')
316 |
317 | # Export MetaGraph
318 | infer_metagraph_fp = os.path.join(infer_dir, 'infer.meta')
319 | tf.train.export_meta_graph(
320 | filename=infer_metagraph_fp,
321 | clear_devices=True,
322 | saver_def=saver.as_saver_def())
323 |
324 | # Reset graph (in case training afterwards)
325 | tf.reset_default_graph()
326 |
327 |
328 | """
329 | Generates a preview audio file every time a checkpoint is saved
330 | """
331 | def preview(args):
332 | import matplotlib
333 | matplotlib.use('Agg')
334 | import matplotlib.pyplot as plt
335 | from scipy.io.wavfile import write as wavwrite
336 | from scipy.signal import freqz
337 |
338 | preview_dir = os.path.join(args.train_dir, 'preview')
339 | if not os.path.isdir(preview_dir):
340 | os.makedirs(preview_dir)
341 |
342 | # Load graph
343 | infer_metagraph_fp = os.path.join(args.train_dir, 'infer', 'infer.meta')
344 | graph = tf.get_default_graph()
345 | saver = tf.train.import_meta_graph(infer_metagraph_fp)
346 |
347 | # Generate or restore z_i and z_o
348 | z_fp = os.path.join(preview_dir, 'z.pkl')
349 | if os.path.exists(z_fp):
350 | with open(z_fp, 'rb') as f:
351 | _zs = pickle.load(f)
352 | else:
353 | # Sample z
354 | samp_feeds = {}
355 | samp_feeds[graph.get_tensor_by_name('samp_z_n:0')] = args.preview_n
356 | samp_fetches = {}
357 | samp_fetches['zs'] = graph.get_tensor_by_name('samp_z:0')
358 | with tf.Session() as sess:
359 | _samp_fetches = sess.run(samp_fetches, samp_feeds)
360 | _zs = _samp_fetches['zs']
361 |
362 | # Save z
363 | with open(z_fp, 'wb') as f:
364 | pickle.dump(_zs, f)
365 |
366 | # Set up graph for generating preview images
367 | feeds = {}
368 | feeds[graph.get_tensor_by_name('z:0')] = _zs
369 | feeds[graph.get_tensor_by_name('flat_pad:0')] = int(args.data_sample_rate / 2)
370 | fetches = {}
371 | fetches['step'] = tf.train.get_or_create_global_step()
372 | fetches['G_z'] = graph.get_tensor_by_name('G_z:0')
373 | fetches['G_z_flat_int16'] = graph.get_tensor_by_name('G_z_flat_int16:0')
374 | if args.wavegan_genr_pp:
375 | fetches['pp_filter'] = graph.get_tensor_by_name('G/pp_filt/conv1d/kernel:0')[:, 0, 0]
376 |
377 | # Summarize
378 | G_z = graph.get_tensor_by_name('G_z_flat:0')
379 | summaries = [
380 | tf.summary.audio('preview', tf.expand_dims(G_z, axis=0), args.data_sample_rate, max_outputs=1)
381 | ]
382 | fetches['summaries'] = tf.summary.merge(summaries)
383 | summary_writer = tf.summary.FileWriter(preview_dir)
384 |
385 | # PP Summarize
386 | if args.wavegan_genr_pp:
387 | pp_fp = tf.placeholder(tf.string, [])
388 | pp_bin = tf.read_file(pp_fp)
389 | pp_png = tf.image.decode_png(pp_bin)
390 | pp_summary = tf.summary.image('pp_filt', tf.expand_dims(pp_png, axis=0))
391 |
392 | # Loop, waiting for checkpoints
393 | ckpt_fp = None
394 | while True:
395 | latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
396 | if latest_ckpt_fp != ckpt_fp:
397 | print('Preview: {}'.format(latest_ckpt_fp))
398 |
399 | with tf.Session() as sess:
400 | saver.restore(sess, latest_ckpt_fp)
401 |
402 | _fetches = sess.run(fetches, feeds)
403 |
404 | _step = _fetches['step']
405 |
406 | preview_fp = os.path.join(preview_dir, '{}.wav'.format(str(_step).zfill(8)))
407 | wavwrite(preview_fp, args.data_sample_rate, _fetches['G_z_flat_int16'])
408 |
409 | summary_writer.add_summary(_fetches['summaries'], _step)
410 |
411 | if args.wavegan_genr_pp:
412 | w, h = freqz(_fetches['pp_filter'])
413 |
414 | fig = plt.figure()
415 | plt.title('Digital filter frequency response')
416 | ax1 = fig.add_subplot(111)
417 |
418 | plt.plot(w, 20 * np.log10(abs(h)), 'b')
419 | plt.ylabel('Amplitude [dB]', color='b')
420 | plt.xlabel('Frequency [rad/sample]')
421 |
422 | ax2 = ax1.twinx()
423 | angles = np.unwrap(np.angle(h))
424 | plt.plot(w, angles, 'g')
425 | plt.ylabel('Angle (radians)', color='g')
426 | plt.grid()
427 | plt.axis('tight')
428 |
429 | _pp_fp = os.path.join(preview_dir, '{}_ppfilt.png'.format(str(_step).zfill(8)))
430 | plt.savefig(_pp_fp)
431 |
432 | with tf.Session() as sess:
433 | _summary = sess.run(pp_summary, {pp_fp: _pp_fp})
434 | summary_writer.add_summary(_summary, _step)
435 |
436 | print('Done')
437 |
438 | ckpt_fp = latest_ckpt_fp
439 |
440 | time.sleep(1)
441 |
442 |
443 | """
444 | Computes inception score every time a checkpoint is saved
445 | """
446 | def incept(args):
447 | incept_dir = os.path.join(args.train_dir, 'incept')
448 | if not os.path.isdir(incept_dir):
449 | os.makedirs(incept_dir)
450 |
451 | # Load GAN graph
452 | gan_graph = tf.Graph()
453 | with gan_graph.as_default():
454 | infer_metagraph_fp = os.path.join(args.train_dir, 'infer', 'infer.meta')
455 | gan_saver = tf.train.import_meta_graph(infer_metagraph_fp)
456 | score_saver = tf.train.Saver(max_to_keep=1)
457 | gan_z = gan_graph.get_tensor_by_name('z:0')
458 | gan_G_z = gan_graph.get_tensor_by_name('G_z:0')[:, :, 0]
459 | gan_step = gan_graph.get_tensor_by_name('global_step:0')
460 |
461 | # Load or generate latents
462 | z_fp = os.path.join(incept_dir, 'z.pkl')
463 | if os.path.exists(z_fp):
464 | with open(z_fp, 'rb') as f:
465 | _zs = pickle.load(f)
466 | else:
467 | gan_samp_z_n = gan_graph.get_tensor_by_name('samp_z_n:0')
468 | gan_samp_z = gan_graph.get_tensor_by_name('samp_z:0')
469 | with tf.Session(graph=gan_graph) as sess:
470 | _zs = sess.run(gan_samp_z, {gan_samp_z_n: args.incept_n})
471 | with open(z_fp, 'wb') as f:
472 | pickle.dump(_zs, f)
473 |
474 | # Load classifier graph
475 | incept_graph = tf.Graph()
476 | with incept_graph.as_default():
477 | incept_saver = tf.train.import_meta_graph(args.incept_metagraph_fp)
478 | incept_x = incept_graph.get_tensor_by_name('x:0')
479 | incept_preds = incept_graph.get_tensor_by_name('scores:0')
480 | incept_sess = tf.Session(graph=incept_graph)
481 | incept_saver.restore(incept_sess, args.incept_ckpt_fp)
482 |
483 | # Create summaries
484 | summary_graph = tf.Graph()
485 | with summary_graph.as_default():
486 | incept_mean = tf.placeholder(tf.float32, [])
487 | incept_std = tf.placeholder(tf.float32, [])
488 | summaries = [
489 | tf.summary.scalar('incept_mean', incept_mean),
490 | tf.summary.scalar('incept_std', incept_std)
491 | ]
492 | summaries = tf.summary.merge(summaries)
493 | summary_writer = tf.summary.FileWriter(incept_dir)
494 |
495 | # Loop, waiting for checkpoints
496 | ckpt_fp = None
497 | _best_score = 0.
498 | while True:
499 | latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
500 | if latest_ckpt_fp != ckpt_fp:
501 | print('Incept: {}'.format(latest_ckpt_fp))
502 |
503 | sess = tf.Session(graph=gan_graph)
504 |
505 | gan_saver.restore(sess, latest_ckpt_fp)
506 |
507 | _step = sess.run(gan_step)
508 |
509 | _G_zs = []
510 | for i in xrange(0, args.incept_n, 100):
511 | _G_zs.append(sess.run(gan_G_z, {gan_z: _zs[i:i+100]}))
512 | _G_zs = np.concatenate(_G_zs, axis=0)
513 |
514 | _preds = []
515 | for i in xrange(0, args.incept_n, 100):
516 | _preds.append(incept_sess.run(incept_preds, {incept_x: _G_zs[i:i+100]}))
517 | _preds = np.concatenate(_preds, axis=0)
518 |
519 | # Split into k groups
520 | _incept_scores = []
521 | split_size = args.incept_n // args.incept_k
522 | for i in xrange(args.incept_k):
523 | _split = _preds[i * split_size:(i + 1) * split_size]
524 | _kl = _split * (np.log(_split) - np.log(np.expand_dims(np.mean(_split, 0), 0)))
525 | _kl = np.mean(np.sum(_kl, 1))
526 | _incept_scores.append(np.exp(_kl))
527 |
528 | _incept_mean, _incept_std = np.mean(_incept_scores), np.std(_incept_scores)
529 |
530 | # Summarize
531 | with tf.Session(graph=summary_graph) as summary_sess:
532 | _summaries = summary_sess.run(summaries, {incept_mean: _incept_mean, incept_std: _incept_std})
533 | summary_writer.add_summary(_summaries, _step)
534 |
535 | # Save
536 | if _incept_mean > _best_score:
537 | score_saver.save(sess, os.path.join(incept_dir, 'best_score'), _step)
538 | _best_score = _incept_mean
539 |
540 | sess.close()
541 |
542 | print('Done')
543 |
544 | ckpt_fp = latest_ckpt_fp
545 |
546 | time.sleep(1)
547 |
548 | incept_sess.close()
549 |
550 |
551 | if __name__ == '__main__':
552 | import argparse
553 | import glob
554 | import sys
555 |
556 | parser = argparse.ArgumentParser()
557 |
558 | parser.add_argument('mode', type=str, choices=['train', 'preview', 'incept', 'infer'])
559 | parser.add_argument('train_dir', type=str,
560 | help='Training directory')
561 |
562 | data_args = parser.add_argument_group('Data')
563 | data_args.add_argument('--data_dir', type=str,
564 | help='Data directory containing *only* audio files to load')
565 | data_args.add_argument('--data_sample_rate', type=int,
566 | help='Number of audio samples per second')
567 | data_args.add_argument('--data_slice_len', type=int, choices=[16384, 32768, 65536],
568 | help='Number of audio samples per slice (maximum generation length)')
569 | data_args.add_argument('--data_num_channels', type=int,
570 | help='Number of audio channels to generate (for >2, must match that of data)')
571 | data_args.add_argument('--data_overlap_ratio', type=float,
572 | help='Overlap ratio [0, 1) between slices')
573 | data_args.add_argument('--data_first_slice', action='store_true', dest='data_first_slice',
574 | help='If set, only use the first slice from each audio example')
575 | data_args.add_argument('--data_pad_end', action='store_true', dest='data_pad_end',
576 | help='If set, use zero-padded partial slices from the end of each audio file')
577 | data_args.add_argument('--data_normalize', action='store_true', dest='data_normalize',
578 | help='If set, normalize the training examples')
579 | data_args.add_argument('--data_fast_wav', action='store_true', dest='data_fast_wav',
580 | help='If your data is comprised of standard WAV files (16-bit signed PCM or 32-bit float), use this flag to decode audio using scipy (faster) instead of librosa')
581 | data_args.add_argument('--data_prefetch_gpu_num', type=int,
582 | help='If nonnegative, prefetch examples to this GPU (Tensorflow device num)')
583 |
584 | wavegan_args = parser.add_argument_group('WaveGAN')
585 | wavegan_args.add_argument('--wavegan_latent_dim', type=int,
586 | help='Number of dimensions of the latent space')
587 | wavegan_args.add_argument('--wavegan_kernel_len', type=int,
588 | help='Length of 1D filter kernels')
589 | wavegan_args.add_argument('--wavegan_dim', type=int,
590 | help='Dimensionality multiplier for model of G and D')
591 | wavegan_args.add_argument('--num_categ', type=int,
592 | help='Number of categorical variables')
593 | wavegan_args.add_argument('--wavegan_batchnorm', action='store_true', dest='wavegan_batchnorm',
594 | help='Enable batchnorm')
595 | wavegan_args.add_argument('--wavegan_disc_nupdates', type=int,
596 | help='Number of discriminator updates per generator update')
597 | wavegan_args.add_argument('--wavegan_loss', type=str, choices=['dcgan', 'lsgan', 'wgan', 'wgan-gp'],
598 | help='Which GAN loss to use')
599 | wavegan_args.add_argument('--wavegan_genr_upsample', type=str, choices=['zeros', 'nn'],
600 | help='Generator upsample strategy')
601 | wavegan_args.add_argument('--wavegan_genr_pp', action='store_true', dest='wavegan_genr_pp',
602 | help='If set, use post-processing filter')
603 | wavegan_args.add_argument('--wavegan_genr_pp_len', type=int,
604 | help='Length of post-processing filter for DCGAN')
605 | wavegan_args.add_argument('--wavegan_disc_phaseshuffle', type=int,
606 | help='Radius of phase shuffle operation')
607 |
608 | train_args = parser.add_argument_group('Train')
609 | train_args.add_argument('--train_batch_size', type=int,
610 | help='Batch size')
611 | train_args.add_argument('--train_save_secs', type=int,
612 | help='How often to save model')
613 | train_args.add_argument('--train_summary_secs', type=int,
614 | help='How often to report summaries')
615 |
616 | preview_args = parser.add_argument_group('Preview')
617 | preview_args.add_argument('--preview_n', type=int,
618 | help='Number of samples to preview')
619 |
620 | incept_args = parser.add_argument_group('Incept')
621 | incept_args.add_argument('--incept_metagraph_fp', type=str,
622 | help='Inference model for inception score')
623 | incept_args.add_argument('--incept_ckpt_fp', type=str,
624 | help='Checkpoint for inference model')
625 | incept_args.add_argument('--incept_n', type=int,
626 | help='Number of generated examples to test')
627 | incept_args.add_argument('--incept_k', type=int,
628 | help='Number of groups to test')
629 |
630 | parser.set_defaults(
631 | data_dir=None,
632 | data_sample_rate=16000,
633 | data_slice_len=16384,
634 | data_num_channels=1,
635 | data_overlap_ratio=0.,
636 | data_first_slice=False,
637 | data_pad_end=False,
638 | data_normalize=False,
639 | data_fast_wav=False,
640 | data_prefetch_gpu_num=0,
641 | wavegan_latent_dim=100,
642 | wavegan_kernel_len=25,
643 | wavegan_dim=64,
644 | num_categ=3,
645 | wavegan_batchnorm=False,
646 | wavegan_disc_nupdates=5,
647 | wavegan_loss='wgan-gp',
648 | wavegan_genr_upsample='zeros',
649 | wavegan_genr_pp=False,
650 | wavegan_genr_pp_len=512,
651 | wavegan_disc_phaseshuffle=2,
652 | train_batch_size=64,
653 | train_save_secs=300,
654 | train_summary_secs=120,
655 | preview_n=32,
656 | incept_metagraph_fp='./eval/inception/infer.meta',
657 | incept_ckpt_fp='./eval/inception/best_acc-103005',
658 | incept_n=5000,
659 | incept_k=10)
660 |
661 | args = parser.parse_args()
662 |
663 | # Make train dir
664 | if not os.path.isdir(args.train_dir):
665 | os.makedirs(args.train_dir)
666 |
667 | # Save args
668 | with open(os.path.join(args.train_dir, 'args.txt'), 'w') as f:
669 | f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])]))
670 |
671 | # Make model kwarg dicts
672 | setattr(args, 'wavegan_g_kwargs', {
673 | 'slice_len': args.data_slice_len,
674 | 'nch': args.data_num_channels,
675 | 'kernel_len': args.wavegan_kernel_len,
676 | 'dim': args.wavegan_dim,
677 | 'use_batchnorm': args.wavegan_batchnorm,
678 | 'upsample': args.wavegan_genr_upsample
679 | })
680 | setattr(args, 'wavegan_d_kwargs', {
681 | 'kernel_len': args.wavegan_kernel_len,
682 | 'dim': args.wavegan_dim,
683 | 'use_batchnorm': args.wavegan_batchnorm,
684 | 'phaseshuffle_rad': args.wavegan_disc_phaseshuffle
685 | })
686 | setattr(args, 'wavegan_q_kwargs', {
687 | 'kernel_len': args.wavegan_kernel_len,
688 | 'dim': args.wavegan_dim,
689 | 'use_batchnorm': args.wavegan_batchnorm,
690 | 'phaseshuffle_rad': args.wavegan_disc_phaseshuffle,
691 | 'num_categ': args.num_categ
692 | })
693 |
694 | if args.mode == 'train':
695 | fps = glob.glob(os.path.join(args.data_dir, '*'))
696 | if len(fps) == 0:
697 | raise Exception('Did not find any audio files in specified directory')
698 | print('Found {} audio files in specified directory'.format(len(fps)))
699 | infer(args)
700 | train(fps, args)
701 | elif args.mode == 'preview':
702 | preview(args)
703 | elif args.mode == 'incept':
704 | incept(args)
705 | elif args.mode == 'infer':
706 | infer(args)
707 | else:
708 | raise NotImplementedError()
709 |
--------------------------------------------------------------------------------
/train_ciwgan.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | try:
4 | import cPickle as pickle
5 | except:
6 | import pickle
7 | from functools import reduce
8 | import os
9 | import time
10 |
11 | import numpy as np
12 | import tensorflow as tf
13 | from six.moves import xrange
14 |
15 | import loader
16 | from cinfowavegan import WaveGANGenerator, WaveGANDiscriminator, WaveGANQ
17 |
18 |
19 | """
20 | Trains a WaveGAN
21 | """
22 | def train(fps, args):
23 | with tf.name_scope('loader'):
24 | x = loader.decode_extract_and_batch(
25 | fps,
26 | batch_size=args.train_batch_size,
27 | slice_len=args.data_slice_len,
28 | decode_fs=args.data_sample_rate,
29 | decode_num_channels=args.data_num_channels,
30 | decode_fast_wav=args.data_fast_wav,
31 | decode_parallel_calls=4,
32 | slice_randomize_offset=False if args.data_first_slice else True,
33 | slice_first_only=args.data_first_slice,
34 | slice_overlap_ratio=0. if args.data_first_slice else args.data_overlap_ratio,
35 | slice_pad_end=True if args.data_first_slice else args.data_pad_end,
36 | repeat=True,
37 | shuffle=True,
38 | shuffle_buffer_size=4096,
39 | prefetch_size=args.train_batch_size * 4,
40 | prefetch_gpu_num=args.data_prefetch_gpu_num)[:, :, 0]
41 |
42 | # Make z vector: one-hot categorical code c (first num_categ dims) + uniform noise
43 | def random_c():
44 | idxs = np.random.randint(args.num_categ, size=args.train_batch_size)
45 | c = np.zeros((args.train_batch_size, args.num_categ))
46 | c[np.arange(args.train_batch_size), idxs] = 1
47 | return c
48 | def random_z():
49 | rz = np.zeros([args.train_batch_size, args.wavegan_latent_dim])
50 | rz[:, : args.num_categ] = random_c()
51 | rz[:, args.num_categ : ] = np.random.uniform(-1., 1., size=(args.train_batch_size, args.wavegan_latent_dim - args.num_categ))
52 | return rz
53 |
54 | z = tf.placeholder(tf.float32, (args.train_batch_size, args.wavegan_latent_dim))
55 |
56 | # Make generator
57 | with tf.variable_scope('G'):
58 | G_z = WaveGANGenerator(z, train=True, **args.wavegan_g_kwargs)
59 | if args.wavegan_genr_pp:
60 | with tf.variable_scope('pp_filt'):
61 | G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same')
62 | G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='G')
63 |
64 | # Print G summary
65 | print('-' * 80)
66 | print('Generator vars')
67 | nparams = 0
68 | for v in G_vars:
69 | v_shape = v.get_shape().as_list()
70 | v_n = reduce(lambda x, y: x * y, v_shape)
71 | nparams += v_n
72 | print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
73 | print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
74 |
75 | # Summarize
76 | tf.summary.audio('x', x, args.data_sample_rate)
77 | tf.summary.audio('G_z', G_z, args.data_sample_rate)
78 | G_z_rms = tf.sqrt(tf.reduce_mean(tf.square(G_z[:, :, 0]), axis=1))
79 | x_rms = tf.sqrt(tf.reduce_mean(tf.square(x[:, :, 0]), axis=1))
80 | tf.summary.histogram('x_rms_batch', x_rms)
81 | tf.summary.histogram('G_z_rms_batch', G_z_rms)
82 | tf.summary.scalar('x_rms', tf.reduce_mean(x_rms))
83 | tf.summary.scalar('G_z_rms', tf.reduce_mean(G_z_rms))
84 |
85 | # Make real discriminator
86 | with tf.name_scope('D_x'), tf.variable_scope('D'):
87 | D_x = WaveGANDiscriminator(x, **args.wavegan_d_kwargs)
88 | D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='D')
89 |
90 | # Print D summary
91 | print('-' * 80)
92 | print('Discriminator vars')
93 | nparams = 0
94 | for v in D_vars:
95 | v_shape = v.get_shape().as_list()
96 | v_n = reduce(lambda x, y: x * y, v_shape)
97 | nparams += v_n
98 | print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
99 | print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
100 | print('-' * 80)
101 |
102 |
103 |
104 | # Make fake discriminator
105 | with tf.name_scope('D_G_z'), tf.variable_scope('D', reuse=True):
106 | D_G_z = WaveGANDiscriminator(G_z, **args.wavegan_d_kwargs)
107 |
108 | # Make Q
109 | with tf.variable_scope('Q'):
110 | Q_G_z = WaveGANQ(G_z, **args.wavegan_q_kwargs)
111 | Q_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Q')
112 |
113 | # Print Q summary
114 | print('Q vars')
115 | nparams = 0
116 | for v in Q_vars:
117 | v_shape = v.get_shape().as_list()
118 | v_n = reduce(lambda x, y: x * y, v_shape)
    | nparams += v_n
119 | print('{} ({}): {}'.format(v.get_shape().as_list(), v_n, v.name))
    | print('Total params: {} ({:.2f} MB)'.format(nparams, (float(nparams) * 4) / (1024 * 1024)))
120 | print('-' * 80)
121 |
122 | # Create loss
123 | D_clip_weights = None
124 | if args.wavegan_loss == 'dcgan':
125 | fake = tf.zeros([args.train_batch_size], dtype=tf.float32)
126 | real = tf.ones([args.train_batch_size], dtype=tf.float32)
127 |
128 | G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
129 | logits=D_G_z,
130 | labels=real
131 | ))
132 |
133 | D_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
134 | logits=D_G_z,
135 | labels=fake
136 | ))
137 | D_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
138 | logits=D_x,
139 | labels=real
140 | ))
141 |
142 | D_loss /= 2.
143 | elif args.wavegan_loss == 'lsgan':
144 | G_loss = tf.reduce_mean((D_G_z - 1.) ** 2)
145 | D_loss = tf.reduce_mean((D_x - 1.) ** 2)
146 | D_loss += tf.reduce_mean(D_G_z ** 2)
147 | D_loss /= 2.
148 | elif args.wavegan_loss == 'wgan':
149 | G_loss = -tf.reduce_mean(D_G_z)
150 | D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)
151 |
152 | with tf.name_scope('D_clip_weights'):
153 | clip_ops = []
154 | for var in D_vars:
155 | clip_bounds = [-.01, .01]
156 | clip_ops.append(
157 | tf.assign(
158 | var,
159 | tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])
160 | )
161 | )
162 | D_clip_weights = tf.group(*clip_ops)
163 | elif args.wavegan_loss == 'wgan-gp':
164 |
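# Q's cost: softmax cross-entropy between the categorical part of the latent
# code (its first num_categ dimensions) and Q's prediction from the generated
# audio, an InfoGAN-style lower bound on their mutual information (ciwGAN).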
165 | def q_cost_tf(z, q):
166 | z_cat = z[:, : args.num_categ]
167 | q_cat = q[:, : args.num_categ]
168 | lcat = tf.nn.softmax_cross_entropy_with_logits(labels=z_cat, logits=q_cat)
169 | return tf.reduce_mean(lcat)
170 |
171 |
172 | G_loss = -tf.reduce_mean(D_G_z)
173 | D_loss = tf.reduce_mean(D_G_z) - tf.reduce_mean(D_x)
174 | Q_loss = q_cost_tf(z, Q_G_z)
175 |
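# WGAN-GP gradient penalty (Gulrajani et al. 2017): evaluate the critic at
# random points on lines between real and generated examples and penalize
# deviations of its gradient norm from 1. LAMBDA = 10 follows the paper.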
176 | alpha = tf.random_uniform(shape=[args.train_batch_size, 1, 1], minval=0., maxval=1.)
177 | differences = G_z - x
178 | interpolates = x + (alpha * differences)
179 | with tf.name_scope('D_interp'), tf.variable_scope('D', reuse=True):
180 | D_interp = WaveGANDiscriminator(interpolates, **args.wavegan_d_kwargs)
181 |
182 | LAMBDA = 10
183 | gradients = tf.gradients(D_interp, [interpolates])[0]
184 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2]))
185 | gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2.)
186 | D_loss += LAMBDA * gradient_penalty
187 | else:
188 | raise NotImplementedError()
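# Note: Q_loss and Q_opt are only defined when args.wavegan_loss is 'wgan-gp'
# (the default), so the Q summaries and train ops below assume that setting.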
189 |
190 | tf.summary.scalar('G_loss', G_loss)
191 | tf.summary.scalar('D_loss', D_loss)
192 | tf.summary.scalar('Q_loss', Q_loss)
193 |
194 | # Create (recommended) optimizer
195 | if args.wavegan_loss == 'dcgan':
196 | G_opt = tf.train.AdamOptimizer(
197 | learning_rate=2e-4,
198 | beta1=0.5)
199 | D_opt = tf.train.AdamOptimizer(
200 | learning_rate=2e-4,
201 | beta1=0.5)
202 | elif args.wavegan_loss == 'lsgan':
203 | G_opt = tf.train.RMSPropOptimizer(
204 | learning_rate=1e-4)
205 | D_opt = tf.train.RMSPropOptimizer(
206 | learning_rate=1e-4)
207 | elif args.wavegan_loss == 'wgan':
208 | G_opt = tf.train.RMSPropOptimizer(
209 | learning_rate=5e-5)
210 | D_opt = tf.train.RMSPropOptimizer(
211 | learning_rate=5e-5)
212 | elif args.wavegan_loss == 'wgan-gp':
213 | G_opt = tf.train.AdamOptimizer(
214 | learning_rate=1e-4,
215 | beta1=0.5,
216 | beta2=0.9)
217 | D_opt = tf.train.AdamOptimizer(
218 | learning_rate=1e-4,
219 | beta1=0.5,
220 | beta2=0.9)
221 | Q_opt = tf.train.RMSPropOptimizer(
222 | learning_rate=1e-4)
223 | else:
224 | raise NotImplementedError()
225 |
226 | # Create training ops
227 | G_train_op = G_opt.minimize(G_loss, var_list=G_vars,
228 | global_step=tf.train.get_or_create_global_step())
229 | D_train_op = D_opt.minimize(D_loss, var_list=D_vars)
230 | Q_train_op = Q_opt.minimize(Q_loss, var_list=Q_vars + G_vars)
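# Q is optimized jointly with the generator (var_list includes G_vars), so G
# itself learns to encode the latent code in its outputs in a form Q can
# recover.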
231 |
232 | # Run training
233 | with tf.train.MonitoredTrainingSession(
234 | checkpoint_dir=args.train_dir,
235 | save_checkpoint_secs=args.train_save_secs,
236 | save_summaries_secs=args.train_summary_secs) as sess:
237 | print('-' * 80)
238 | print('Training has started. Please use \'tensorboard --logdir={}\' to monitor.'.format(args.train_dir))
239 | while True:
240 | # Train discriminator
241 | for i in xrange(args.wavegan_disc_nupdates):
242 | sess.run([D_loss, D_train_op], feed_dict={z: random_z()})
243 |
244 |
245 | # Enforce Lipschitz constraint for WGAN
246 | if D_clip_weights is not None:
247 | sess.run(D_clip_weights)
248 |
249 | # Train generator
250 | sess.run([G_loss, Q_loss, G_train_op, Q_train_op], feed_dict={z: random_z()})
251 |
252 |
253 | """
254 | Creates and saves a MetaGraphDef for simple inference
255 | Tensors:
256 | 'samp_z_n' int32 []: Sample this many latent vectors
257 | 'samp_z' float32 [samp_z_n, latent_dim]: Resultant latent vectors
258 | 'z:0' float32 [None, latent_dim]: Input latent vectors
259 | 'flat_pad:0' int32 []: Number of padding samples to use when flattening batch to a single audio file
260 | 'G_z:0' float32 [None, slice_len, 1]: Generated outputs
261 | 'G_z_int16:0' int16 [None, slice_len, 1]: Same as above but quantized to 16-bit PCM samples
262 | 'G_z_flat:0' float32 [None, 1]: Outputs flattened into single audio file
263 | 'G_z_flat_int16:0' int16 [None, 1]: Same as above but quantized to 16-bit PCM samples
264 | Example usage:
265 | import tensorflow as tf
266 | tf.reset_default_graph()
267 |
268 | saver = tf.train.import_meta_graph('infer.meta')
269 | graph = tf.get_default_graph()
270 | sess = tf.InteractiveSession()
271 | saver.restore(sess, 'model.ckpt-10000')
272 |
273 | z_n = graph.get_tensor_by_name('samp_z_n:0')
274 | _z = sess.run(graph.get_tensor_by_name('samp_z:0'), {z_n: 10})
275 |
276 | z = graph.get_tensor_by_name('z:0')
277 | _G_z = sess.run(graph.get_tensor_by_name('G_z:0'), {z: _z})
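
A minimal sketch for conditioning generation on the categorical code. It
assumes, as in q_cost_tf above, that the first num_categ dimensions of z
carry the code and the rest are uniform noise (num_categ=10 and
wavegan_latent_dim=100 are this script's defaults):

import numpy as np
_z_c = np.zeros([10, 100], dtype=np.float32)
_z_c[:, 3] = 1.  # switch categorical variable 3 on
_z_c[:, 10:] = np.random.uniform(-1., 1., (10, 90))
_G_z_c = sess.run(graph.get_tensor_by_name('G_z:0'), {z: _z_c})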
278 | """
279 | def infer(args):
280 | infer_dir = os.path.join(args.train_dir, 'infer')
281 | if not os.path.isdir(infer_dir):
282 | os.makedirs(infer_dir)
283 |
284 | # Subgraph that generates latent vectors
285 | samp_z_n = tf.placeholder(tf.int32, [], name='samp_z_n')
286 | samp_z = tf.random_uniform([samp_z_n, args.wavegan_latent_dim], -1.0, 1.0, dtype=tf.float32, name='samp_z')
287 |
288 | # Input z
289 | z = tf.placeholder(tf.float32, [None, args.wavegan_latent_dim], name='z')
290 | flat_pad = tf.placeholder(tf.int32, [], name='flat_pad')
291 |
292 | # Execute generator
293 | with tf.variable_scope('G'):
294 | G_z = WaveGANGenerator(z, train=False, **args.wavegan_g_kwargs)
295 | if args.wavegan_genr_pp:
296 | with tf.variable_scope('pp_filt'):
297 | G_z = tf.layers.conv1d(G_z, 1, args.wavegan_genr_pp_len, use_bias=False, padding='same')
298 | G_z = tf.identity(G_z, name='G_z')
299 |
300 | # Flatten batch
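# Each example is zero-padded with flat_pad samples before the reshape, which
# leaves a gap of silence between consecutive clips in the flattened file.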
301 | nch = int(G_z.get_shape()[-1])
302 | G_z_padded = tf.pad(G_z, [[0, 0], [0, flat_pad], [0, 0]])
303 | G_z_flat = tf.reshape(G_z_padded, [-1, nch], name='G_z_flat')
304 |
305 | # Encode to int16
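# Maps float audio in [-1., 1.] to the 16-bit PCM range by scaling by 32767,
# clipping, and casting.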
306 | def float_to_int16(x, name=None):
307 | x_int16 = x * 32767.
308 | x_int16 = tf.clip_by_value(x_int16, -32767., 32767.)
309 | x_int16 = tf.cast(x_int16, tf.int16, name=name)
310 | return x_int16
311 | G_z_int16 = float_to_int16(G_z, name='G_z_int16')
312 | G_z_flat_int16 = float_to_int16(G_z_flat, name='G_z_flat_int16')
313 |
314 | # Create saver
315 | G_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='G')
316 | global_step = tf.train.get_or_create_global_step()
317 | saver = tf.train.Saver(G_vars + [global_step])
318 |
319 | # Export graph
320 | tf.train.write_graph(tf.get_default_graph(), infer_dir, 'infer.pbtxt')
321 |
322 | # Export MetaGraph
323 | infer_metagraph_fp = os.path.join(infer_dir, 'infer.meta')
324 | tf.train.export_meta_graph(
325 | filename=infer_metagraph_fp,
326 | clear_devices=True,
327 | saver_def=saver.as_saver_def())
328 |
329 | # Reset graph (in case training afterwards)
330 | tf.reset_default_graph()
331 |
332 |
333 | """
334 | Generates a preview audio file every time a checkpoint is saved
335 | """
336 | def preview(args):
337 | import matplotlib
338 | matplotlib.use('Agg')
339 | import matplotlib.pyplot as plt
340 | from scipy.io.wavfile import write as wavwrite
341 | from scipy.signal import freqz
342 |
343 | preview_dir = os.path.join(args.train_dir, 'preview')
344 | if not os.path.isdir(preview_dir):
345 | os.makedirs(preview_dir)
346 |
347 | # Load graph
348 | infer_metagraph_fp = os.path.join(args.train_dir, 'infer', 'infer.meta')
349 | graph = tf.get_default_graph()
350 | saver = tf.train.import_meta_graph(infer_metagraph_fp)
351 |
352 | # Generate or restore latent vectors (kept fixed so previews are comparable across checkpoints)
353 | z_fp = os.path.join(preview_dir, 'z.pkl')
354 | if os.path.exists(z_fp):
355 | with open(z_fp, 'rb') as f:
356 | _zs = pickle.load(f)
357 | else:
358 | # Sample z
359 | samp_feeds = {}
360 | samp_feeds[graph.get_tensor_by_name('samp_z_n:0')] = args.preview_n
361 | samp_fetches = {}
362 | samp_fetches['zs'] = graph.get_tensor_by_name('samp_z:0')
363 | with tf.Session() as sess:
364 | _samp_fetches = sess.run(samp_fetches, samp_feeds)
365 | _zs = _samp_fetches['zs']
366 |
367 | # Save z
368 | with open(z_fp, 'wb') as f:
369 | pickle.dump(_zs, f)
370 |
371 | # Set up graph for generating preview images
372 | feeds = {}
373 | feeds[graph.get_tensor_by_name('z:0')] = _zs
374 | feeds[graph.get_tensor_by_name('flat_pad:0')] = int(args.data_sample_rate / 2)
375 | fetches = {}
376 | fetches['step'] = tf.train.get_or_create_global_step()
377 | fetches['G_z'] = graph.get_tensor_by_name('G_z:0')
378 | fetches['G_z_flat_int16'] = graph.get_tensor_by_name('G_z_flat_int16:0')
379 | if args.wavegan_genr_pp:
380 | fetches['pp_filter'] = graph.get_tensor_by_name('G/pp_filt/conv1d/kernel:0')[:, 0, 0]
381 |
382 | # Summarize
383 | G_z = graph.get_tensor_by_name('G_z_flat:0')
384 | summaries = [
385 | tf.summary.audio('preview', tf.expand_dims(G_z, axis=0), args.data_sample_rate, max_outputs=1)
386 | ]
387 | fetches['summaries'] = tf.summary.merge(summaries)
388 | summary_writer = tf.summary.FileWriter(preview_dir)
389 |
390 | # PP Summarize
391 | if args.wavegan_genr_pp:
392 | pp_fp = tf.placeholder(tf.string, [])
393 | pp_bin = tf.read_file(pp_fp)
394 | pp_png = tf.image.decode_png(pp_bin)
395 | pp_summary = tf.summary.image('pp_filt', tf.expand_dims(pp_png, axis=0))
396 |
397 | # Loop, waiting for checkpoints
398 | ckpt_fp = None
399 | while True:
400 | latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
401 | if latest_ckpt_fp != ckpt_fp:
402 | print('Preview: {}'.format(latest_ckpt_fp))
403 |
404 | with tf.Session() as sess:
405 | saver.restore(sess, latest_ckpt_fp)
406 |
407 | _fetches = sess.run(fetches, feeds)
408 |
409 | _step = _fetches['step']
410 |
411 | preview_fp = os.path.join(preview_dir, '{}.wav'.format(str(_step).zfill(8)))
412 | wavwrite(preview_fp, args.data_sample_rate, _fetches['G_z_flat_int16'])
413 |
414 | summary_writer.add_summary(_fetches['summaries'], _step)
415 |
416 | if args.wavegan_genr_pp:
417 | w, h = freqz(_fetches['pp_filter'])
418 |
419 | fig = plt.figure()
420 | plt.title('Digital filter frequency response')
421 | ax1 = fig.add_subplot(111)
422 |
423 | plt.plot(w, 20 * np.log10(abs(h)), 'b')
424 | plt.ylabel('Amplitude [dB]', color='b')
425 | plt.xlabel('Frequency [rad/sample]')
426 |
427 | ax2 = ax1.twinx()
428 | angles = np.unwrap(np.angle(h))
429 | plt.plot(w, angles, 'g')
430 | plt.ylabel('Angle (radians)', color='g')
431 | plt.grid()
432 | plt.axis('tight')
433 |
434 | _pp_fp = os.path.join(preview_dir, '{}_ppfilt.png'.format(str(_step).zfill(8)))
435 | plt.savefig(_pp_fp)
436 |
437 | with tf.Session() as sess:
438 | _summary = sess.run(pp_summary, {pp_fp: _pp_fp})
439 | summary_writer.add_summary(_summary, _step)
440 |
441 | print('Done')
442 |
443 | ckpt_fp = latest_ckpt_fp
444 |
445 | time.sleep(1)
446 |
447 |
448 | """
449 | Computes inception score every time a checkpoint is saved
450 | """
451 | def incept(args):
452 | incept_dir = os.path.join(args.train_dir, 'incept')
453 | if not os.path.isdir(incept_dir):
454 | os.makedirs(incept_dir)
455 |
456 | # Load GAN graph
457 | gan_graph = tf.Graph()
458 | with gan_graph.as_default():
459 | infer_metagraph_fp = os.path.join(args.train_dir, 'infer', 'infer.meta')
460 | gan_saver = tf.train.import_meta_graph(infer_metagraph_fp)
461 | score_saver = tf.train.Saver(max_to_keep=1)
462 | gan_z = gan_graph.get_tensor_by_name('z:0')
463 | gan_G_z = gan_graph.get_tensor_by_name('G_z:0')[:, :, 0]
464 | gan_step = gan_graph.get_tensor_by_name('global_step:0')
465 |
466 | # Load or generate latents
467 | z_fp = os.path.join(incept_dir, 'z.pkl')
468 | if os.path.exists(z_fp):
469 | with open(z_fp, 'rb') as f:
470 | _zs = pickle.load(f)
471 | else:
472 | gan_samp_z_n = gan_graph.get_tensor_by_name('samp_z_n:0')
473 | gan_samp_z = gan_graph.get_tensor_by_name('samp_z:0')
474 | with tf.Session(graph=gan_graph) as sess:
475 | _zs = sess.run(gan_samp_z, {gan_samp_z_n: args.incept_n})
476 | with open(z_fp, 'wb') as f:
477 | pickle.dump(_zs, f)
478 |
479 | # Load classifier graph
480 | incept_graph = tf.Graph()
481 | with incept_graph.as_default():
482 | incept_saver = tf.train.import_meta_graph(args.incept_metagraph_fp)
483 | incept_x = incept_graph.get_tensor_by_name('x:0')
484 | incept_preds = incept_graph.get_tensor_by_name('scores:0')
485 | incept_sess = tf.Session(graph=incept_graph)
486 | incept_saver.restore(incept_sess, args.incept_ckpt_fp)
487 |
488 | # Create summaries
489 | summary_graph = tf.Graph()
490 | with summary_graph.as_default():
491 | incept_mean = tf.placeholder(tf.float32, [])
492 | incept_std = tf.placeholder(tf.float32, [])
493 | summaries = [
494 | tf.summary.scalar('incept_mean', incept_mean),
495 | tf.summary.scalar('incept_std', incept_std)
496 | ]
497 | summaries = tf.summary.merge(summaries)
498 | summary_writer = tf.summary.FileWriter(incept_dir)
499 |
500 | # Loop, waiting for checkpoints
501 | ckpt_fp = None
502 | _best_score = 0.
503 | while True:
504 | latest_ckpt_fp = tf.train.latest_checkpoint(args.train_dir)
505 | if latest_ckpt_fp != ckpt_fp:
506 | print('Incept: {}'.format(latest_ckpt_fp))
507 |
508 | sess = tf.Session(graph=gan_graph)
509 |
510 | gan_saver.restore(sess, latest_ckpt_fp)
511 |
512 | _step = sess.run(gan_step)
513 |
514 | _G_zs = []
515 | for i in xrange(0, args.incept_n, 100):
516 | _G_zs.append(sess.run(gan_G_z, {gan_z: _zs[i:i+100]}))
517 | _G_zs = np.concatenate(_G_zs, axis=0)
518 |
519 | _preds = []
520 | for i in xrange(0, args.incept_n, 100):
521 | _preds.append(incept_sess.run(incept_preds, {incept_x: _G_zs[i:i+100]}))
522 | _preds = np.concatenate(_preds, axis=0)
523 |
524 | # Split into k groups
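# Per-group inception score: exp(mean over x of KL(p(y|x) || p(y))), where
# p(y) is estimated by the group's mean prediction.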
525 | _incept_scores = []
526 | split_size = args.incept_n // args.incept_k
527 | for i in xrange(args.incept_k):
528 | _split = _preds[i * split_size:(i + 1) * split_size]
529 | _kl = _split * (np.log(_split) - np.log(np.expand_dims(np.mean(_split, 0), 0)))
530 | _kl = np.mean(np.sum(_kl, 1))
531 | _incept_scores.append(np.exp(_kl))
532 |
533 | _incept_mean, _incept_std = np.mean(_incept_scores), np.std(_incept_scores)
534 |
535 | # Summarize
536 | with tf.Session(graph=summary_graph) as summary_sess:
537 | _summaries = summary_sess.run(summaries, {incept_mean: _incept_mean, incept_std: _incept_std})
538 | summary_writer.add_summary(_summaries, _step)
539 |
540 | # Save
541 | if _incept_mean > _best_score:
542 | score_saver.save(sess, os.path.join(incept_dir, 'best_score'), _step)
543 | _best_score = _incept_mean
544 |
545 | sess.close()
546 |
547 | print('Done')
548 |
549 | ckpt_fp = latest_ckpt_fp
550 |
551 | time.sleep(1)
552 |
553 | incept_sess.close()
554 |
555 |
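# Example invocation (a sketch; assumes this file is saved as train_ciwgan.py,
# and every flag is optional thanks to parser.set_defaults below):
#   python train_ciwgan.py train ./train --data_dir ./data --num_categ 10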
556 | if __name__ == '__main__':
557 | import argparse
558 | import glob
559 | import sys
560 |
561 | parser = argparse.ArgumentParser()
562 |
563 | parser.add_argument('mode', type=str, choices=['train', 'preview', 'incept', 'infer'])
564 | parser.add_argument('train_dir', type=str,
565 | help='Training directory')
566 |
567 | data_args = parser.add_argument_group('Data')
568 | data_args.add_argument('--data_dir', type=str,
569 | help='Data directory containing *only* audio files to load')
570 | data_args.add_argument('--data_sample_rate', type=int,
571 | help='Number of audio samples per second')
572 | data_args.add_argument('--data_slice_len', type=int, choices=[16384, 32768, 65536],
573 | help='Number of audio samples per slice (maximum generation length)')
574 | data_args.add_argument('--data_num_channels', type=int,
575 | help='Number of audio channels to generate (for >2, must match that of data)')
576 | data_args.add_argument('--data_overlap_ratio', type=float,
577 | help='Overlap ratio [0, 1) between slices')
578 | data_args.add_argument('--data_first_slice', action='store_true', dest='data_first_slice',
579 | help='If set, only use the first slice of each audio example')
580 | data_args.add_argument('--data_pad_end', action='store_true', dest='data_pad_end',
581 | help='If set, use zero-padded partial slices from the end of each audio file')
582 | data_args.add_argument('--data_normalize', action='store_true', dest='data_normalize',
583 | help='If set, normalize the training examples')
584 | data_args.add_argument('--data_fast_wav', action='store_true', dest='data_fast_wav',
585 | help='If your data consists of standard WAV files (16-bit signed PCM or 32-bit float), use this flag to decode audio using scipy (faster) instead of librosa')
586 | data_args.add_argument('--data_prefetch_gpu_num', type=int,
587 | help='If nonnegative, prefetch examples to this GPU (TensorFlow device num)')
588 |
589 | wavegan_args = parser.add_argument_group('WaveGAN')
590 | wavegan_args.add_argument('--wavegan_latent_dim', type=int,
591 | help='Number of dimensions of the latent space')
592 | wavegan_args.add_argument('--wavegan_kernel_len', type=int,
593 | help='Length of 1D filter kernels')
594 | wavegan_args.add_argument('--wavegan_dim', type=int,
595 | help='Dimensionality multiplier for model of G and D')
596 | wavegan_args.add_argument('--num_categ', type=int,
597 | help='Number of categorical variables')
598 | wavegan_args.add_argument('--wavegan_batchnorm', action='store_true', dest='wavegan_batchnorm',
599 | help='Enable batchnorm')
600 | wavegan_args.add_argument('--wavegan_disc_nupdates', type=int,
601 | help='Number of discriminator updates per generator update')
602 | wavegan_args.add_argument('--wavegan_loss', type=str, choices=['dcgan', 'lsgan', 'wgan', 'wgan-gp'],
603 | help='Which GAN loss to use')
604 | wavegan_args.add_argument('--wavegan_genr_upsample', type=str, choices=['zeros', 'nn'],
605 | help='Generator upsample strategy')
606 | wavegan_args.add_argument('--wavegan_genr_pp', action='store_true', dest='wavegan_genr_pp',
607 | help='If set, use post-processing filter')
608 | wavegan_args.add_argument('--wavegan_genr_pp_len', type=int,
609 | help='Length (in samples) of the post-processing filter')
610 | wavegan_args.add_argument('--wavegan_disc_phaseshuffle', type=int,
611 | help='Radius of phase shuffle operation')
612 |
613 | train_args = parser.add_argument_group('Train')
614 | train_args.add_argument('--train_batch_size', type=int,
615 | help='Batch size')
616 | train_args.add_argument('--train_save_secs', type=int,
617 | help='How often to save model')
618 | train_args.add_argument('--train_summary_secs', type=int,
619 | help='How often to report summaries')
620 |
621 | preview_args = parser.add_argument_group('Preview')
622 | preview_args.add_argument('--preview_n', type=int,
623 | help='Number of samples to preview')
624 |
625 | incept_args = parser.add_argument_group('Incept')
626 | incept_args.add_argument('--incept_metagraph_fp', type=str,
627 | help='Inference model for inception score')
628 | incept_args.add_argument('--incept_ckpt_fp', type=str,
629 | help='Checkpoint for inference model')
630 | incept_args.add_argument('--incept_n', type=int,
631 | help='Number of generated examples to test')
632 | incept_args.add_argument('--incept_k', type=int,
633 | help='Number of groups to test')
634 |
635 | parser.set_defaults(
636 | data_dir=None,
637 | data_sample_rate=16000,
638 | data_slice_len=16384,
639 | data_num_channels=1,
640 | data_overlap_ratio=0.,
641 | data_first_slice=False,
642 | data_pad_end=False,
643 | data_normalize=False,
644 | data_fast_wav=False,
645 | data_prefetch_gpu_num=0,
646 | wavegan_latent_dim=100,
647 | wavegan_kernel_len=25,
648 | wavegan_dim=64,
649 | num_categ=10,
650 | wavegan_batchnorm=False,
651 | wavegan_disc_nupdates=5,
652 | wavegan_loss='wgan-gp',
653 | wavegan_genr_upsample='zeros',
654 | wavegan_genr_pp=False,
655 | wavegan_genr_pp_len=512,
656 | wavegan_disc_phaseshuffle=2,
657 | train_batch_size=64,
658 | train_save_secs=300,
659 | train_summary_secs=120,
660 | preview_n=32,
661 | incept_metagraph_fp='./eval/inception/infer.meta',
662 | incept_ckpt_fp='./eval/inception/best_acc-103005',
663 | incept_n=5000,
664 | incept_k=10)
665 |
666 | args = parser.parse_args()
667 |
668 | # Make train dir
669 | if not os.path.isdir(args.train_dir):
670 | os.makedirs(args.train_dir)
671 |
672 | # Save args
673 | with open(os.path.join(args.train_dir, 'args.txt'), 'w') as f:
674 | f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])]))
675 |
676 | # Make model kwarg dicts
677 | setattr(args, 'wavegan_g_kwargs', {
678 | 'slice_len': args.data_slice_len,
679 | 'nch': args.data_num_channels,
680 | 'kernel_len': args.wavegan_kernel_len,
681 | 'dim': args.wavegan_dim,
682 | 'use_batchnorm': args.wavegan_batchnorm,
683 | 'upsample': args.wavegan_genr_upsample
684 | })
685 | setattr(args, 'wavegan_d_kwargs', {
686 | 'kernel_len': args.wavegan_kernel_len,
687 | 'dim': args.wavegan_dim,
688 | 'use_batchnorm': args.wavegan_batchnorm,
689 | 'phaseshuffle_rad': args.wavegan_disc_phaseshuffle
690 | })
691 | setattr(args, 'wavegan_q_kwargs', {
692 | 'kernel_len': args.wavegan_kernel_len,
693 | 'dim': args.wavegan_dim,
694 | 'use_batchnorm': args.wavegan_batchnorm,
695 | 'phaseshuffle_rad': args.wavegan_disc_phaseshuffle,
696 | 'num_categ': args.num_categ
697 | })
698 |
699 | if args.mode == 'train':
700 | fps = glob.glob(os.path.join(args.data_dir, '*'))
701 | if len(fps) == 0:
702 | raise Exception('Did not find any audio files in specified directory')
703 | print('Found {} audio files in specified directory'.format(len(fps)))
704 | infer(args)
705 | train(fps, args)
706 | elif args.mode == 'preview':
707 | preview(args)
708 | elif args.mode == 'incept':
709 | incept(args)
710 | elif args.mode == 'infer':
711 | infer(args)
712 | else:
713 | raise NotImplementedError()
714 |
--------------------------------------------------------------------------------