├── README.md ├── generation.ipynb ├── utils.py ├── train_cwavegan.py ├── wgangp.py └── cwawegan_architecture.py /README.md: -------------------------------------------------------------------------------- 1 | #### Keras implementation of conditional waveGAN. Application to knocking sound effects. 2 | 3 | Original waveGAN architecture: https://github.com/chrisdonahue/wavegan 4 | 5 | #### Requirements 6 | ``` 7 | Tensorflow >= 2.0 8 | Librosa 9 | ``` 10 | 11 | ##### Using the knocking sound effects with emotion dataset or your own sounds 12 | 13 | We focused on knocking sound effects and recorded a dataset to train the model. If you want to train on the knocking sound effects with emotion dataset, you can download it from [here](https://zenodo.org/record/3668503) and place it in an '/audio' subdirectory. 14 | 15 | If you want to use your own sounds, just place your .wav files (organised in one folder per label) in an '/audio' subdirectory. If, for instance, you want to train the conditional waveGAN on footsteps on concrete and grass, put your sounds in '/audio/concrete' and '/audio/grass'. 16 | 17 | ##### Training 18 | 19 | To train the model, run ``` train_cwavegan.py ```. You can see and edit the parameters/hyperparameters of the model directly in the Python file. Depending on your dataset, you will probably want to change (at least) the architecture size and the sampling rate. Once you start training, a date/time folder is created in the ```checkpoints``` directory. Inside it you will find the saved model, a file listing the parameters used, and a dictionary with the labels (for inference). 20 | 21 | 22 | ##### Synthesising audio 23 | 24 | Once the model is trained, just use the trained generator. You can find a full example in the generation notebook (```generation.ipynb```), and a minimal inference sketch below. 
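For quick reference, this is a minimal sketch of that inference loop, following the notebook (the checkpoint paths and file names below are placeholders and depend on your training run and on whether ```override_saved_model``` was enabled; ```z_dim``` and ```sample_rate``` must match the values used for training):

```python
import json
import numpy as np
import tensorflow as tf

# load the trained generator and the label dictionary written during training
generator = tf.keras.models.load_model('checkpoints/<your_run>/generator.h5')
with open('checkpoints/<your_run>/label_names.json') as json_file:
    label_names = json.load(json_file)           # maps label index -> class folder name

z_dim = 100                                      # latent dimension used for training
sample_rate = 22050                              # sampling rate used for training

noise = np.random.normal(0, 1, (1, z_dim))       # random latent vector
label = np.array(0).reshape(-1, 1)               # class index to condition on
synth_audio = generator.predict([noise, label])  # shape: (1, n_samples, 1)
waveform = np.squeeze(synth_audio[0])            # 1-D waveform in [-1, 1]
```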
25 | -------------------------------------------------------------------------------- /generation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Synthesising single samples from a trained model" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import tensorflow as tf\n", 17 | "import numpy as np\n", 18 | "import json\n", 19 | "from IPython.display import display, Audio\n", 20 | "from tqdm import tqdm\n", 21 | "import librosa" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "#### Get the trained model and class labels" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "path_to_generator = 'generator_good.h5'\n", 38 | "path_to_labels = 'label_names.json'\n", 39 | "z_dim = 100\n", 40 | "sample_rate = 22050" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "#load the generator\n", 50 | "generator = tf.keras.models.load_model(path_to_generator)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "#read the labels from the generated dictionary during training\n", 60 | "with open(path_to_labels) as json_file:\n", 61 | " label_names = json.load(json_file)\n", 62 | "label_names" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "#### Generating a single sample (with label)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "#create noise and label\n", 79 | "label = 0\n", 80 | "noise = np.random.normal(0,1, (1, z_dim))\n", 81 | "label_synth = np.array(label).reshape(-1,1)\n", 82 | "\n", 83 | "#synthesise the audio\n", 84 | "%time synth_audio = generator.predict([noise, label_synth])\n", 85 | "\n", 86 | "#listen to the synthesised audio\n", 87 | "display(Audio(np.squeeze(synth_audio[0]), rate = sample_rate))" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "### Batch generation" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "#how many samples per label\n", 104 | "n_samples_label = 100" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "for emotion in tqdm(label_names):\n", 114 | " noise = tf.random.normal(shape=(n_samples_label, z_dim))\n", 115 | " label_synth = tf.constant(int(emotion), shape=(n_samples_label,1))\n", 116 | " synth_audio = generator.predict([noise, label_synth])\n", 117 | " for i in range(n_samples_label):\n", 118 | " librosa.output.write_wav(f'{label_names[emotion]}_{i}.wav', y = np.squeeze(synth_audio[i]), sr = sample_rate, norm=False) " 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [] 127 | } 128 | ], 129 | "metadata": { 130 | "kernelspec": { 131 | "display_name": "Python 3", 132 | "language": "python", 133 | "name": "python3" 134 | }, 135 | "language_info": { 136 | 
"codemirror_mode": { 137 | "name": "ipython", 138 | "version": 3 139 | }, 140 | "file_extension": ".py", 141 | "mimetype": "text/x-python", 142 | "name": "python", 143 | "nbconvert_exporter": "python", 144 | "pygments_lexer": "ipython3", 145 | "version": "3.6.9" 146 | } 147 | }, 148 | "nbformat": 4, 149 | "nbformat_minor": 4 150 | } 151 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import json 4 | import os 5 | from datetime import datetime 6 | 7 | #get the number of classes from the number of folders in the audio dir 8 | def get_n_classes(audio_path): 9 | root, dirs, files = next(os.walk(audio_path)) 10 | n_classes = len(dirs) 11 | print(f'Found {n_classes} different classes in {audio_path}') 12 | return n_classes 13 | 14 | #load the audio. Pad the audio if the file is shorter than the maximum architecture capacity 15 | def load_audio(audio_path, sr, audio_size_samples): 16 | X_audio, _ = librosa.load(audio_path, sr = sr) 17 | if X_audio.size < audio_size_samples: 18 | padding = audio_size_samples - X_audio.size 19 | X_audio = np.pad(X_audio, (0, padding), mode = 'constant') 20 | elif (X_audio.size >= audio_size_samples): 21 | X_audio = X_audio[0:audio_size_samples] 22 | return X_audio 23 | 24 | #save the label names for inference 25 | def save_label_names(audio_path, save_folder): 26 | label_names = {} 27 | for i, folder in enumerate(next(os.walk(audio_path))[1]): 28 | label_names[i] = folder 29 | #save the dictionary to use it later with the standalone generator 30 | with open(os.path.join(save_folder, 'label_names.json'), 'w') as outfile: 31 | json.dump(label_names, outfile) 32 | 33 | #create the dataset from the audio path folder 34 | def create_dataset(audio_path, sample_rate, architecture_size, labels_saving_path): 35 | 36 | if architecture_size == 'large': 37 | audio_size_samples = 65536 38 | elif architecture_size == 'medium': 39 | audio_size_samples = 32768 40 | elif architecture_size == 'small': 41 | audio_size_samples = 16384 42 | 43 | #save the label names in a dict 44 | save_label_names(audio_path, labels_saving_path) 45 | audio = [] 46 | labels_names = [] 47 | for folder in next(os.walk(audio_path))[1]: 48 | for wavfile in os.listdir(audio_path+folder): 49 | audio.append(load_audio(audio_path = f'{audio_path}{folder}/{wavfile}', sr = sample_rate, audio_size_samples = audio_size_samples)) 50 | labels_names.append(folder) 51 | audio_np = np.asarray(audio) 52 | audio_np = np.expand_dims(audio_np, axis = -1) 53 | labels = np.unique(labels_names, return_inverse=True)[1] 54 | labels_np = np.expand_dims(labels, axis = -1) 55 | 56 | return audio_np, labels_np 57 | 58 | #create folder with current date (to avoid overriding the synthesised audio/model when resuming the training) 59 | def create_date_folder(checkpoints_path): 60 | if not os.path.exists(checkpoints_path): 61 | os.mkdir(checkpoints_path) 62 | date = datetime.now() 63 | day = date.strftime('%d-%m-%Y_') 64 | path = f'{checkpoints_path}{day}{str(date.hour)}h' 65 | if not os.path.exists(path): 66 | os.mkdir(path) 67 | if not os.path.exists(f'{path}/synth_audio'): 68 | os.mkdir(f'{path}/synth_audio') 69 | return path 70 | 71 | #save the training arguments used to the checkpoints folder (it make it easier retrieve the hyperparameters afterwards) 72 | def write_parameters(sampling_rate, n_batches, batch_size, audio_path, checkpoints_path, 73 | 
architecture_size, path_to_weights, resume_training, override_saved_model, synth_frequency, 74 | save_frequency, latent_dim, use_batch_norm, discriminator_learning_rate, generator_learning_rate, 75 | discriminator_extra_steps, phaseshuffle_samples): 76 | print(f'Saving the training parameters to disk in {checkpoints_path}/training_parameters.txt') 77 | arguments = open(f'{checkpoints_path}/training_parameters.txt', "w") 78 | arguments.write(f'sampling_rate = {sampling_rate}\n') 79 | arguments.write(f'n_batches = {n_batches}\n') 80 | arguments.write(f'batch_size = {batch_size}\n') 81 | arguments.write(f'audio_path = {audio_path}\n') 82 | arguments.write(f'checkpoints_path = {checkpoints_path}\n') 83 | arguments.write(f'architecture_size = {architecture_size}\n') 84 | arguments.write(f'path_to_weights = {path_to_weights}\n') 85 | arguments.write(f'resume_training = {resume_training}\n') 86 | arguments.write(f'override_saved_model = {override_saved_model}\n') 87 | arguments.write(f'synth_frequency = {synth_frequency}\n') 88 | arguments.write(f'save_frequency = {save_frequency}\n') 89 | arguments.write(f'latent_dim = {latent_dim}\n') 90 | arguments.write(f'use_batch_norm = {use_batch_norm}\n') 91 | arguments.write(f'discriminator_learning_rate = {discriminator_learning_rate}\n') 92 | arguments.write(f'generator_learning_rate = {generator_learning_rate}\n') 93 | arguments.write(f'discriminator_extra_steps = {discriminator_extra_steps}\n') 94 | arguments.write(f'phaseshuffle_samples = {phaseshuffle_samples}') 95 | arguments.close() -------------------------------------------------------------------------------- /train_cwavegan.py: -------------------------------------------------------------------------------- 1 | import cwawegan_architecture 2 | import wgangp 3 | import utils 4 | from tensorflow.keras.optimizers import Adam 5 | 6 | def train_model(sampling_rate = 22050, 7 | n_batches = 10000, 8 | batch_size = 128, 9 | audio_path = 'audio/', 10 | checkpoints_path = 'checkpoints/', 11 | architecture_size = 'large', 12 | resume_training = False, 13 | path_to_weights = 'checkpoints/model_weights.h5', 14 | override_saved_model = False, 15 | synth_frequency = 200, 16 | save_frequency = 200, 17 | latent_dim = 100, 18 | use_batch_norm = False, 19 | discriminator_learning_rate = 0.00004, 20 | generator_learning_rate = 0.00004, 21 | discriminator_extra_steps = 5, 22 | phaseshuffle_samples = 0): 23 | 24 | ''' 25 | Train the conditional WaveGAN architecture. 26 | Args: 27 | sampling_rate (int): Sampling rate of the loaded/synthesised audio. 28 | n_batches (int): Number of batches to train for. 29 | batch_size (int): batch size (for the training process). 30 | audio_path (str): Path where your training data (wav files) are store. 31 | Each class should be in a folder with the class name 32 | checkpoints_path (str): Path to save the model / synth the audio during training 33 | architecture_size (str) = size of the wavegan architecture. Eeach size processes the following number 34 | of audio samples: 'small' = 16384, 'medium' = 32768, 'large' = 65536" 35 | resume_training (bool) = Restore the model weights from a previous session? 36 | path_to_weights (str) = Where the model weights are (when resuming training) 37 | override_saved_model (bool) = save the model overwriting 38 | the previous saved model (in a past epoch)?. Be aware the saved files could be large! 39 | synth_frequency (int): How often do you want to synthesise a sample during training (in batches). 
40 | save_frequency (int): How often do you want to save the model during training (in batches). 41 | latent_dim (int): Dimension of the latent space. 42 | use_batch_norm (bool): Use batch normalization? 43 | discriminator_learning_rate (float): Discriminator learning rate. 44 | generator_learning_rate (float): Generator learning rate. 45 | discriminator_extra_steps (int): How many steps the discriminator is trained per step of the generator. 46 | phaseshuffle_samples (int): Discriminator phase shuffle. 0 for no phases shuffle. 47 | ''' 48 | 49 | #get the number of classes from the audio folder 50 | n_classes = utils.get_n_classes(audio_path) 51 | 52 | #build the discriminator 53 | discriminator = cwawegan_architecture.discriminator(architecture_size=architecture_size, 54 | phaseshuffle_samples = phaseshuffle_samples, 55 | n_classes = n_classes) 56 | #build the generator 57 | generator = cwawegan_architecture.generator(architecture_size=architecture_size, 58 | z_dim = latent_dim, 59 | use_batch_norm = use_batch_norm, 60 | n_classes = n_classes) 61 | #set the optimizers 62 | discriminator_optimizer = Adam(learning_rate = discriminator_learning_rate) 63 | generator_optimizer = Adam(learning_rate = generator_learning_rate) 64 | 65 | #build the gan 66 | gan = wgangp.WGANGP(latent_dim=latent_dim, discriminator=discriminator, generator=generator, 67 | n_classes = n_classes, discriminator_extra_steps = discriminator_extra_steps, 68 | d_optimizer = discriminator_optimizer, g_optimizer = generator_optimizer) 69 | 70 | # Compile the wgan model 71 | gan.compile( 72 | d_optimizer=discriminator_optimizer, 73 | g_optimizer=generator_optimizer) 74 | 75 | #make a folder with the current date to store the current session to 76 | #avoid overriding past synth audio files and checkpoints 77 | checkpoints_path = utils.create_date_folder(checkpoints_path) 78 | 79 | #save the training parameters used to the checkpoints folder, 80 | #it makes it easier to retrieve the parameters/hyperparameters afterwards 81 | utils.write_parameters(sampling_rate, n_batches, batch_size, audio_path, checkpoints_path, architecture_size, 82 | path_to_weights, resume_training, override_saved_model, synth_frequency, save_frequency, 83 | latent_dim, use_batch_norm, discriminator_learning_rate, generator_learning_rate, 84 | discriminator_extra_steps, phaseshuffle_samples) 85 | 86 | #create the dataset from the class folders in '/audio' 87 | audio, labels = utils.create_dataset(audio_path, sampling_rate, architecture_size, checkpoints_path) 88 | 89 | #load the desired weights in path (if resuming training) 90 | if resume_training == True: 91 | print(f'Resuming training. 
Loading weights in {path_to_weights}') 92 | gan.load_weights(path_to_weights) 93 | 94 | #train the gan for the desired number of batches 95 | gan.train(x = audio, y = labels, batch_size = batch_size, batches = n_batches, 96 | synth_frequency = synth_frequency, save_frequency = save_frequency, 97 | checkpoints_path = checkpoints_path, override_saved_model = override_saved_model, 98 | sampling_rate = sampling_rate, n_classes = n_classes) 99 | 100 | 101 | if __name__ == '__main__': 102 | train_model(sampling_rate = 22050, 103 | n_batches = 30000, 104 | batch_size = 128, 105 | audio_path = 'audio/', 106 | checkpoints_path = 'checkpoints/', 107 | architecture_size = 'large', 108 | path_to_weights = 'model_weights.h5', 109 | resume_training = False, 110 | override_saved_model = True, 111 | synth_frequency = 200, 112 | save_frequency = 200, 113 | latent_dim = 100, 114 | use_batch_norm = True, 115 | discriminator_learning_rate = 0.0002, 116 | generator_learning_rate = 0.0002, 117 | discriminator_extra_steps = 5, 118 | phaseshuffle_samples = 0) -------------------------------------------------------------------------------- /wgangp.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | import numpy as np 4 | import librosa 5 | import time 6 | 7 | #Baseline WGANGP model directly from the Keras documentation: https://keras.io/examples/generative/wgan_gp/ 8 | #Original WaveGAN: https://github.com/chrisdonahue/wavegan 9 | 10 | class WGANGP(keras.Model): 11 | def __init__( 12 | self, 13 | latent_dim, 14 | discriminator, 15 | generator, 16 | n_classes, 17 | discriminator_extra_steps=5, 18 | gp_weight=10.0, 19 | d_optimizer=tf.keras.optimizers.Adam(learning_rate = 0.0004), 20 | g_optimizer=tf.keras.optimizers.Adam(learning_rate = 0.0004) 21 | ): 22 | super(WGANGP, self).__init__() 23 | self.latent_dim = latent_dim 24 | self.discriminator = discriminator 25 | self.generator = generator 26 | self.n_classes = n_classes 27 | self.d_steps = discriminator_extra_steps 28 | self.gp_weight = gp_weight 29 | self.d_optimizer=d_optimizer 30 | self.g_optimizer=g_optimizer 31 | 32 | def compile(self, d_optimizer, g_optimizer): 33 | super(WGANGP, self).compile() 34 | self.d_optimizer = d_optimizer 35 | self.g_optimizer = g_optimizer 36 | self.d_loss_fn = self.discriminator_loss 37 | self.g_loss_fn = self.generator_loss 38 | 39 | # Define the loss functions to be used for discriminator 40 | # This should be (fake_loss - real_loss) 41 | # We will add the gradient penalty later to this loss function 42 | def discriminator_loss(self, real_img, fake_img): 43 | real_loss = tf.reduce_mean(real_img) 44 | fake_loss = tf.reduce_mean(fake_img) 45 | return fake_loss - real_loss 46 | 47 | # Define the loss functions to be used for generator 48 | def generator_loss(self, fake_img): 49 | return -tf.reduce_mean(fake_img) 50 | 51 | def gradient_penalty(self, batch_size, real_images, fake_images, labels): 52 | """ Calculates the gradient penalty. 53 | 54 | This loss is calculated on an interpolated image 55 | and added to the discriminator loss. 56 | """ 57 | # get the interplated image 58 | alpha = tf.random.normal([batch_size, 1, 1], 0.0, 1.0) 59 | diff = fake_images - real_images 60 | interpolated = real_images + alpha * diff 61 | with tf.GradientTape() as gp_tape: 62 | gp_tape.watch(interpolated) 63 | # 1. Get the discriminator output for this interpolated image. 
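            # (Note: the penalty computed below is the WGAN-GP regulariser
            #  E[(||grad_x_hat D(x_hat, y)||_2 - 1)^2], where x_hat is the random
            #  interpolation between real and generated audio built above; the
            #  gp_weight factor is applied later in train_batch.)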
64 | pred = self.discriminator([interpolated, labels], training=True) 65 | 66 | # 2. Calculate the gradients w.r.t to this interpolated image. 67 | grads = gp_tape.gradient(pred, [interpolated])[0] 68 | # 3. Calcuate the norm of the gradients 69 | norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2])) 70 | gp = tf.reduce_mean((norm - 1.0) ** 2) 71 | return gp 72 | 73 | def train_batch(self, x, y, batch_size): 74 | #get a random indexes for the batch 75 | idx = np.random.randint(0, x.shape[0], batch_size) 76 | real_images = x[idx] 77 | labels = y[idx] 78 | 79 | # For each batch, we are going to perform the 80 | # following steps as laid out in the original paper. 81 | # 1. Train the generator and get the generator loss 82 | # 2. Train the discriminator and get the discriminator loss 83 | # 3. Calculate the gradient penalty 84 | # 4. Multiply this gradient penalty with a constant weight factor 85 | # 5. Add gradient penalty to the discriminator loss 86 | # 6. Return generator and discriminator losses as a loss dictionary. 87 | 88 | # Train discriminator first. The original paper recommends training 89 | # the discriminator for `x` more steps (typically 5) as compared to 90 | # one step of the generator. 91 | for i in range(self.d_steps): 92 | # Get the latent vector 93 | random_latent_vectors = tf.random.normal( 94 | shape=(batch_size, self.latent_dim) 95 | ) 96 | 97 | with tf.GradientTape() as tape: 98 | # Generate fake images from the latent vector 99 | fake_images = self.generator([random_latent_vectors, labels], training=True) 100 | # Get the logits for the fake images 101 | fake_logits = self.discriminator([fake_images, labels], training=True) 102 | # Get the logits for real images 103 | real_logits = self.discriminator([real_images, labels], training=True) 104 | # Calculate discriminator loss using fake and real logits 105 | d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits) 106 | # Calculate the gradient penalty 107 | gp = self.gradient_penalty(batch_size, real_images, fake_images, labels) 108 | # Add the gradient penalty to the original discriminator loss 109 | d_loss = d_cost + gp * self.gp_weight 110 | 111 | # Get the gradients w.r.t the discriminator loss 112 | d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables) 113 | # Update the weights of the discriminator using the discriminator optimizer 114 | self.d_optimizer.apply_gradients( 115 | zip(d_gradient, self.discriminator.trainable_variables) 116 | ) 117 | 118 | # Train the generator now. 
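        # (The generator is updated once per batch, after the d_steps critic updates.
        #  Its loss is -mean(D(G(z, y), y)): maximising the critic score of the
        #  generated audio, which is the standard WGAN generator objective.)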
119 | # Get the latent vector 120 | random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim)) 121 | 122 | with tf.GradientTape() as tape: 123 | # Generate fake images using the generator 124 | generated_images = self.generator([random_latent_vectors, labels], training=True) 125 | # Get the discriminator logits for fake images 126 | gen_img_logits = self.discriminator([generated_images, labels], training=True) 127 | # Calculate the generator loss 128 | g_loss = self.g_loss_fn(gen_img_logits) 129 | 130 | # Get the gradients w.r.t the generator loss 131 | gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables) 132 | # Update the weights of the generator using the generator optimizer 133 | self.g_optimizer.apply_gradients( 134 | zip(gen_gradient, self.generator.trainable_variables) 135 | ) 136 | return d_loss, g_loss 137 | 138 | def train(self, x, y, batch_size, batches, synth_frequency, save_frequency, 139 | sampling_rate, n_classes, checkpoints_path, override_saved_model): 140 | 141 | for batch in range(batches): 142 | start_time = time.time() 143 | d_loss, g_loss = self.train_batch(x, y, batch_size) 144 | end_time = time.time() 145 | time_batch = (end_time - start_time) 146 | print(f'Batch: {batch} == Batch size: {batch_size} == Time elapsed: {time_batch:.2f} == d_loss: {d_loss:.4f}, g_loss: {g_loss:.4f}') 147 | 148 | #This works as a callback 149 | if batch % synth_frequency == 0 : 150 | print(f'Synthesising audio at batch {batch}. Path: {checkpoints_path}/synth_audio') 151 | random_latent_vectors = tf.random.normal(shape=(1, self.latent_dim)) 152 | for i in range (n_classes): 153 | generated_audio = self.generator([random_latent_vectors, np.array(i).reshape(-1,1)]) 154 | librosa.output.write_wav(f'{checkpoints_path}/synth_audio/{batch}_batch_synth_class_{i}.wav', 155 | y = tf.squeeze(generated_audio).numpy(), sr = sampling_rate, norm=False) 156 | print(f'Done.') 157 | 158 | if batch % save_frequency == 0: 159 | print(f'Saving the model at batch {batch}. 
Path: {checkpoints_path}') 160 | if override_saved_model == False: 161 | self.generator.save(f'{checkpoints_path}/{batch}_batch_generator.h5') 162 | self.discriminator.save(f'{checkpoints_path}/{batch}_batch_discriminator.h5') 163 | self.save_weights(f'{checkpoints_path}/{batch}_batch_weights.h5') 164 | else: 165 | self.generator.save(f'{checkpoints_path}/generator.h5') 166 | self.discriminator.save(f'{checkpoints_path}/discriminator.h5') 167 | self.save_weights(f'{checkpoints_path}/model_weights.h5') 168 | print(f'Model saved.') -------------------------------------------------------------------------------- /cwawegan_architecture.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.layers import Input, Conv1D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, LeakyReLU, ReLU, Embedding, Concatenate, BatchNormalization 2 | from tensorflow.keras.models import Model 3 | from tensorflow.keras import backend as K 4 | from tensorflow import pad, maximum, random, int32 5 | 6 | #Original WaveGAN: https://github.com/chrisdonahue/wavegan 7 | #Label embeding using the method in https://machinelearningmastery.com/how-to-develop-a-conditional-generative-adversarial-network-from-scratch/ 8 | 9 | #phase shuffle [directly from the original waveGAN implementation] 10 | def apply_phaseshuffle(args): 11 | x, rad = args 12 | pad_type = 'reflect' 13 | b, x_len, nch = x.get_shape().as_list() 14 | phase = random.uniform([], minval=-rad, maxval=rad + 1, dtype=int32) 15 | pad_l = maximum(phase, 0) 16 | pad_r = maximum(-phase, 0) 17 | phase_start = pad_r 18 | x = pad(x, [[0, 0], [pad_l, pad_r], [0, 0]], mode=pad_type) 19 | 20 | x = x[:, phase_start:phase_start+x_len] 21 | x.set_shape([b, x_len, nch]) 22 | 23 | return x 24 | 25 | 26 | #TODO: clean/redo this 27 | def Conv1DTranspose(input_tensor, filters, kernel_size, strides=2, padding='same' 28 | , name = '1DTConv', activation = 'relu'): 29 | x = Conv2DTranspose(filters=filters, kernel_size=(1, kernel_size), strides=(1, strides), padding=padding, 30 | name = name, activation = activation)(K.expand_dims(input_tensor, axis=1)) 31 | x = K.squeeze(x, axis=1) 32 | return x 33 | 34 | def generator(z_dim = 100, 35 | architecture_size = 'large', 36 | use_batch_norm = False, 37 | n_classes = 5): 38 | 39 | generator_filters = [1024, 512, 256, 128, 64] 40 | 41 | label_input = Input(shape=(1,), dtype='int32', name='generator_label_input') 42 | label_em = Embedding(n_classes, n_classes * 20, name = 'label_embedding')(label_input) 43 | label_em = Dense(16, name = 'label_dense')(label_em) 44 | label_em = Reshape((16, 1), name = 'label_respahe')(label_em) 45 | 46 | generator_input = Input(shape=(z_dim,), name='generator_input') 47 | x = generator_input 48 | 49 | if architecture_size == 'small': 50 | x = Dense(16384, name='generator_input_dense')(x) 51 | x = Reshape((16, 1024), name='generator_input_reshape')(x) 52 | if use_batch_norm == True: 53 | x = BatchNormalization()(x) 54 | 55 | if architecture_size == 'medium' or architecture_size == 'large': 56 | x = Dense(32768, name='generator_input_dense')(x) 57 | x = Reshape((16, 2048), name='generator_input_reshape')(x) 58 | if use_batch_norm == True: 59 | x = BatchNormalization()(x) 60 | 61 | x = ReLU()(x) 62 | 63 | x = Concatenate()([x, label_em]) 64 | 65 | if architecture_size == 'small': 66 | for i in range(4): 67 | x = Conv1DTranspose( 68 | input_tensor = x 69 | , filters = generator_filters[i+1] 70 | , kernel_size = 25 71 | , strides = 4 72 | , padding='same' 73 | , 
name = f'generator_Tconv_{i}' 74 | , activation = 'relu' 75 | ) 76 | if use_batch_norm == True: 77 | x = BatchNormalization()(x) 78 | 79 | x = Conv1DTranspose( 80 | input_tensor = x 81 | , filters = 1 82 | , kernel_size = 25 83 | , strides = 4 84 | , padding='same' 85 | , name = 'generator_Tconv_4' 86 | , activation = 'tanh' 87 | ) 88 | 89 | if architecture_size == 'medium': 90 | #layer 0 to 4 91 | for i in range(5): 92 | x = Conv1DTranspose( 93 | input_tensor = x 94 | , filters = generator_filters[i] 95 | , kernel_size = 25 96 | , strides = 4 97 | , padding='same' 98 | , name = f'generator_Tconv_{i}' 99 | , activation = 'relu' 100 | ) 101 | if use_batch_norm == True: 102 | x = BatchNormalization()(x) 103 | #layer 5 104 | x = Conv1DTranspose( 105 | input_tensor = x 106 | , filters = 1 107 | , kernel_size = 25 108 | , strides = 2 109 | , padding='same' 110 | , name = 'generator_Tconv_5' 111 | , activation = 'tanh' 112 | ) 113 | 114 | 115 | if architecture_size == 'large': 116 | #layer 0 to 4 117 | for i in range(5): 118 | x = Conv1DTranspose( 119 | input_tensor = x 120 | , filters = generator_filters[i] 121 | , kernel_size = 25 122 | , strides = 4 123 | , padding='same' 124 | , name = f'generator_Tconv_{i}' 125 | , activation = 'relu' 126 | ) 127 | if use_batch_norm == True: 128 | x = BatchNormalization()(x) 129 | 130 | #layer 5 131 | x = Conv1DTranspose( 132 | input_tensor = x 133 | , filters = 1 134 | , kernel_size = 25 135 | , strides = 4 136 | , padding='same' 137 | , name = 'generator_Tconv_5' 138 | , activation = 'tanh' 139 | ) 140 | 141 | generator_output = x 142 | generator = Model([generator_input, label_input], generator_output, name = 'Generator') 143 | return generator 144 | 145 | def discriminator(architecture_size='small', 146 | phaseshuffle_samples = 0, 147 | n_classes = 5): 148 | 149 | discriminator_filters = [64, 128, 256, 512, 1024, 2048] 150 | 151 | if architecture_size == 'large': 152 | audio_input_dim = 65536 153 | elif architecture_size == 'medium': 154 | audio_input_dim = 32768 155 | elif architecture_size == 'small': 156 | audio_input_dim = 16384 157 | 158 | label_input = Input(shape=(1,), dtype='int32', name='discriminator_label_input') 159 | label_em = Embedding(n_classes, n_classes * 20)(label_input) 160 | label_em = Dense(audio_input_dim)(label_em) 161 | label_em = Reshape((audio_input_dim, 1))(label_em) 162 | 163 | discriminator_input = Input(shape=(audio_input_dim,1), name='discriminator_input') 164 | x = Concatenate()([discriminator_input, label_em]) 165 | 166 | if architecture_size == 'small': 167 | # layers 0 to 3 168 | for i in range(4): 169 | x = Conv1D( 170 | filters = discriminator_filters[i] 171 | , kernel_size = 25 172 | , strides = 4 173 | , padding = 'same' 174 | , name = f'discriminator_conv_{i}' 175 | )(x) 176 | 177 | x = LeakyReLU(alpha = 0.2)(x) 178 | if phaseshuffle_samples > 0: 179 | x = Lambda(apply_phaseshuffle)([x, phaseshuffle_samples]) 180 | 181 | #layer 4, no phase shuffle 182 | x = Conv1D( 183 | filters = discriminator_filters[4] 184 | , kernel_size = 25 185 | , strides = 4 186 | , padding = 'same' 187 | , name = f'discriminator_conv_4' 188 | )(x) 189 | 190 | x = Flatten()(x) 191 | 192 | if architecture_size == 'medium': 193 | 194 | # layers 195 | for i in range(4): 196 | x = Conv1D( 197 | filters = discriminator_filters[i] 198 | , kernel_size = 25 199 | , strides = 4 200 | , padding = 'same' 201 | , name = f'discriminator_conv_{i}' 202 | )(x) 203 | 204 | 205 | x = LeakyReLU(alpha = 0.2)(x) 206 | if phaseshuffle_samples > 0: 207 | x = 
Lambda(apply_phaseshuffle)([x, phaseshuffle_samples]) 208 | 209 | 210 | x = Conv1D( 211 | filters = discriminator_filters[4] 212 | , kernel_size = 25 213 | , strides = 4 214 | , padding = 'same' 215 | , name = 'discriminator_conv_4' 216 | )(x) 217 | 218 | x = LeakyReLU(alpha = 0.2)(x) 219 | 220 | x = Conv1D( 221 | filters = discriminator_filters[5] 222 | , kernel_size = 25 223 | , strides = 2 224 | , padding = 'same' 225 | , name = 'discriminator_conv_5' 226 | )(x) 227 | 228 | 229 | x = LeakyReLU(alpha = 0.2)(x) 230 | x = Flatten()(x) 231 | 232 | if architecture_size == 'large': 233 | 234 | # layers 235 | for i in range(4): 236 | x = Conv1D( 237 | filters = discriminator_filters[i] 238 | , kernel_size = 25 239 | , strides = 4 240 | , padding = 'same' 241 | , name = f'discriminator_conv_{i}' 242 | )(x) 243 | x = LeakyReLU(alpha = 0.2)(x) 244 | if phaseshuffle_samples > 0: 245 | x = Lambda(apply_phaseshuffle)([x, phaseshuffle_samples]) 246 | 247 | #last 2 layers without phase shuffle 248 | x = Conv1D( 249 | filters = discriminator_filters[4] 250 | , kernel_size = 25 251 | , strides = 4 252 | , padding = 'same' 253 | , name = 'discriminator_conv_4' 254 | )(x) 255 | x = LeakyReLU(alpha = 0.2)(x) 256 | 257 | x = Conv1D( 258 | filters = discriminator_filters[5] 259 | , kernel_size = 25 260 | , strides = 4 261 | , padding = 'same' 262 | , name = 'discriminator_conv_5' 263 | )(x) 264 | x = LeakyReLU(alpha = 0.2)(x) 265 | x = Flatten()(x) 266 | 267 | discriminator_output = Dense(1)(x) 268 | discriminator = Model([discriminator_input, label_input], discriminator_output, name = 'Discriminator') 269 | return discriminator --------------------------------------------------------------------------------
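As an illustration of how the two conditional models fit together, here is a minimal sketch (not part of the repository) that builds the 'small' architecture and checks that the generator output length matches the discriminator input (16384 samples):

```python
import numpy as np
import cwawegan_architecture

n_classes = 2                                  # e.g. two class folders under 'audio/'
gen = cwawegan_architecture.generator(z_dim=100, architecture_size='small',
                                      use_batch_norm=False, n_classes=n_classes)
disc = cwawegan_architecture.discriminator(architecture_size='small',
                                            phaseshuffle_samples=0, n_classes=n_classes)

z = np.random.normal(0, 1, (1, 100)).astype('float32')   # latent vector
label = np.array([[1]])                                   # class index in [0, n_classes)
fake = gen.predict([z, label])                            # -> (1, 16384, 1)
score = disc.predict([fake, label])                       # -> (1, 1), unbounded critic score
```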