├── README.md
├── config.py
├── imgs
│   ├── archi.png
│   └── loss_curve.png
├── logger.py
├── requirements.txt
├── train.py
├── utils.py
└── wavegan.py
/README.md:
--------------------------------------------------------------------------------
# WaveGAN-pytorch
PyTorch implementation of [Synthesizing Audio with Generative Adversarial Networks (Chris Donahue, Feb 2018)](https://arxiv.org/abs/1802.04208).

Before running, make sure you have the `sc09` dataset and place it under your current working directory.

## Quick Start:
1. Installation
```
sudo apt-get install libav-tools
```

2. Download dataset
* `sc09`: [sc09 raw WAV files](http://deepyeti.ucsd.edu/cdonahue/sc09.tar.gz), utterances of the spoken English words '0'-'9'
* `piano`: [Piano raw WAV files](http://deepyeti.ucsd.edu/cdonahue/mancini_piano.tar.gz)

3. Run

For the `sc09` task, **make sure the `sc09` dataset is under your current project path before running the code.**
```
$ python train.py
```

#### Training time
* For the `sc09` dataset, 4 x Tesla P40 GPUs take nearly 2 days to produce a reasonable result.
* For the `piano` dataset, 2 x Tesla P40 GPUs take 3-6 hours to produce a reasonable result.
* Increasing `BATCH_SIZE` from 10 to 32 or 64 shortens per-epoch time on multiple GPUs, but larger batches mean fewer gradient updates per epoch, so learning can progress more slowly.

## Results
Generated "0-9": https://soundcloud.com/mazzzystar/sets/dcgan-sc09

Generated piano: https://soundcloud.com/mazzzystar/sets/wavegan-piano

Loss curve:

![](imgs/loss_curve.png)

## Architecture
![](imgs/archi.png)

## TODO
* [ ] Add some evaluation experiments, e.g. Inception Score.

## Contributions
This repo is based on [chrisdonahue's](https://github.com/chrisdonahue/wavegan) and [jtcramer's](https://github.com/jtcramer/wavegan) implementations.
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# Training
EPOCHS = 180
BATCH_SIZE = 10

SAMPLE_EVERY = 1  # Generate audio samples every 1 epoch.
SAMPLE_NUM = 10  # Generate 10 samples every sample generation.
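# Note: these values are only defaults. train.py reads its settings from the
# command-line flags defined in parse_arguments() (utils.py), which fall back
# to the constants in this file.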

# Data
DATASET_NAME = 'sc09/'
OUTPUT_PATH = "output/"
--------------------------------------------------------------------------------
/imgs/archi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazzzystar/WaveGAN-pytorch/f56965e32c454880857e3566711da62a2457648f/imgs/archi.png
--------------------------------------------------------------------------------
/imgs/loss_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazzzystar/WaveGAN-pytorch/f56965e32c454880857e3566711da62a2457648f/imgs/loss_curve.png
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
import logging


def init_console_logger(logger, verbose=False):
    stream_handler = logging.StreamHandler()
    if verbose:
        stream_handler.setLevel(logging.DEBUG)
    else:
        stream_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('[%(levelname)s] %(message)s')
    stream_handler.setFormatter(formatter)
    file_handler = logging.FileHandler("model.log")
    file_handler.setFormatter(formatter)  # Use the same format in the log file.
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
librosa
pescador
matplotlib
torch
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
from torch import autograd
from torch import optim
import json
from utils import save_samples
import numpy as np
import pprint
import pickle
import datetime
from wavegan import *
from utils import *
from logger import *
cuda = True if torch.cuda.is_available() else False


# =============Logger===============
LOGGER = logging.getLogger('wavegan')
LOGGER.setLevel(logging.DEBUG)

init_console_logger(LOGGER)
LOGGER.info('Initialized logger.')

# =============Parameters===============
args = parse_arguments()
epochs = args['num_epochs']
batch_size = args['batch_size']
latent_dim = args['latent_dim']
ngpus = args['ngpus']
model_size = args['model_size']
model_dir = make_path(os.path.join(args['output_dir'],
                                   datetime.datetime.now().strftime("%Y%m%d%H%M%S")))
args['model_dir'] = model_dir
# Save samples every N epochs.
epochs_per_sample = args['epochs_per_sample']
# Gradient penalty regularization factor.
lmbda = args['lmbda']

# Dirs
audio_dir = args['audio_dir']
output_dir = args['output_dir']

# =============Network===============
netG = WaveGANGenerator(model_size=model_size, ngpus=ngpus, latent_dim=latent_dim, upsample=True)
netD = WaveGANDiscriminator(model_size=model_size, ngpus=ngpus)

if cuda:
    netG = torch.nn.DataParallel(netG).cuda()
    netD = torch.nn.DataParallel(netD).cuda()

# WGAN-GP training schedule: the discriminator (critic) is updated 5 times per
# generator update (see the inner loop below); both nets share the same Adam
# hyperparameters.
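# Note: the WGAN-GP paper (Gulrajani et al., 2017) uses Adam with beta1=0,
# beta2=0.9; the defaults here (lr=1e-4, beta1=0.5, beta2=0.9, set in
# parse_arguments) follow the WaveGAN paper instead.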
optimizerG = optim.Adam(netG.parameters(), lr=args['learning_rate'], betas=(args['beta1'], args['beta2']))
optimizerD = optim.Adam(netD.parameters(), lr=args['learning_rate'], betas=(args['beta1'], args['beta2']))

# Fixed noise used for the generated inspection samples.
sample_noise = torch.randn(args['sample_size'], latent_dim)
if cuda:
    sample_noise = sample_noise.cuda()
sample_noise_Var = autograd.Variable(sample_noise, requires_grad=False)

# Save config.
LOGGER.info('Saving configurations...')
config_path = os.path.join(model_dir, 'config.json')
with open(config_path, 'w') as f:
    json.dump(args, f)

# Load data.
LOGGER.info('Loading audio data...')
audio_paths = get_all_audio_filepaths(audio_dir)
train_data, valid_data, test_data, train_size = split_data(audio_paths, args['valid_ratio'],
                                                           args['test_ratio'], batch_size)
TOTAL_TRAIN_SAMPLES = train_size
BATCH_NUM = TOTAL_TRAIN_SAMPLES // batch_size

train_iter = iter(train_data)
valid_iter = iter(valid_data)
test_iter = iter(test_data)


# =============Train===============
history = []
D_costs_train = []
D_wasses_train = []
D_costs_valid = []
D_wasses_valid = []
G_costs = []

start = time.time()
LOGGER.info('Starting training...EPOCHS={}, BATCH_SIZE={}, BATCH_NUM={}'.format(epochs, batch_size, BATCH_NUM))
for epoch in range(1, epochs+1):
    LOGGER.info("{} Epoch: {}/{}".format(time_since(start), epoch, epochs))

    D_cost_train_epoch = []
    D_wass_train_epoch = []
    D_cost_valid_epoch = []
    D_wass_valid_epoch = []
    G_cost_epoch = []
    for i in range(1, BATCH_NUM+1):
        # Set discriminator parameters to require gradients.
        for p in netD.parameters():
            p.requires_grad = True

        one = torch.tensor(1, dtype=torch.float)
        neg_one = one * -1
        if cuda:
            one = one.cuda()
            neg_one = neg_one.cuda()
        #############################
        # (1) Train Discriminator
        #############################
        for iter_dis in range(5):
            netD.zero_grad()

            # Noise
            noise = torch.Tensor(batch_size, latent_dim).uniform_(-1, 1)
            if cuda:
                noise = noise.cuda()
            noise_Var = Variable(noise, requires_grad=False)

            real_data_Var = numpy_to_var(next(train_iter)['X'], cuda)

            # a) Compute loss contribution from real training data.
            D_real = netD(real_data_Var)
            D_real = D_real.mean()  # avg loss
            D_real.backward(neg_one)  # loss * -1

            # b) Compute loss contribution from generated data, then backprop.
            fake = autograd.Variable(netG(noise_Var).data)
            D_fake = netD(fake)
            D_fake = D_fake.mean()
            D_fake.backward(one)

            # c) Compute gradient penalty and backprop.
            gradient_penalty = calc_gradient_penalty(netD, real_data_Var.data,
                                                     fake.data, batch_size, lmbda,
                                                     use_cuda=cuda)
            gradient_penalty.backward(one)

            # Compute the critic cost and the Wasserstein distance estimate.
            D_cost_train = D_fake - D_real + gradient_penalty
            D_wass_train = D_real - D_fake

            # Apply the accumulated gradients to the discriminator parameters.
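            # The three backward() calls above accumulate the gradient of the
            # WGAN-GP critic loss (Gulrajani et al., 2017):
            #   L_D = E[D(fake)] - E[D(real)] + lmbda * E[(||grad D(x_hat)|| - 1)^2]
            # so this single optimizer step minimizes L_D.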
            optimizerD.step()

        #############################
        # (2) Compute validation loss
        #############################
        netD.zero_grad()

        valid_data_Var = numpy_to_var(next(valid_iter)['X'], cuda)
        D_real_valid = netD(valid_data_Var)
        D_real_valid = D_real_valid.mean()  # avg loss

        # b) Compute loss contribution from generated data (no backprop here).
        fake_valid = netG(noise_Var)
        D_fake_valid = netD(fake_valid)
        D_fake_valid = D_fake_valid.mean()

        # c) Compute gradient penalty (no backprop here).
        gradient_penalty_valid = calc_gradient_penalty(netD, valid_data_Var.data,
                                                       fake_valid.data, batch_size, lmbda,
                                                       use_cuda=cuda)
        # Compute metrics and record in batch history.
        D_cost_valid = D_fake_valid - D_real_valid + gradient_penalty_valid
        D_wass_valid = D_real_valid - D_fake_valid

        if cuda:
            D_cost_train = D_cost_train.cpu()
            D_wass_train = D_wass_train.cpu()
            D_cost_valid = D_cost_valid.cpu()
            D_wass_valid = D_wass_valid.cpu()

        # Record costs
        D_cost_train_epoch.append(D_cost_train.data.numpy())
        D_wass_train_epoch.append(D_wass_train.data.numpy())
        D_cost_valid_epoch.append(D_cost_valid.data.numpy())
        D_wass_valid_epoch.append(D_wass_valid.data.numpy())

        #############################
        # (3) Train Generator
        #############################
        # Prevent discriminator update.
        for p in netD.parameters():
            p.requires_grad = False

        # Reset generator gradients.
        netG.zero_grad()

        # Noise
        noise = torch.Tensor(batch_size, latent_dim).uniform_(-1, 1)
        if cuda:
            noise = noise.cuda()
        noise_Var = Variable(noise, requires_grad=False)

        fake = netG(noise_Var)
        G = netD(fake)
        G = G.mean()

        # Update gradients.
        G.backward(neg_one)
        G_cost = -G

        optimizerG.step()

        # Record costs
        if cuda:
            G_cost = G_cost.cpu()
        G_cost_epoch.append(G_cost.data.numpy())

        if i % (BATCH_NUM // 5) == 0:
            LOGGER.info("{} Epoch={} Batch: {}/{} D_c:{:.4f} | D_w:{:.4f} | G:{:.4f}".format(time_since(start), epoch,
                                                                                             i, BATCH_NUM,
                                                                                             D_cost_train.data.numpy(),
                                                                                             D_wass_train.data.numpy(),
                                                                                             G_cost.data.numpy()))

    # Save the average cost of batches in every epoch.
    D_cost_train_epoch_avg = sum(D_cost_train_epoch) / float(len(D_cost_train_epoch))
    D_wass_train_epoch_avg = sum(D_wass_train_epoch) / float(len(D_wass_train_epoch))
    D_cost_valid_epoch_avg = sum(D_cost_valid_epoch) / float(len(D_cost_valid_epoch))
    D_wass_valid_epoch_avg = sum(D_wass_valid_epoch) / float(len(D_wass_valid_epoch))
    G_cost_epoch_avg = sum(G_cost_epoch) / float(len(G_cost_epoch))

    D_costs_train.append(D_cost_train_epoch_avg)
    D_wasses_train.append(D_wass_train_epoch_avg)
    D_costs_valid.append(D_cost_valid_epoch_avg)
    D_wasses_valid.append(D_wass_valid_epoch_avg)
    G_costs.append(G_cost_epoch_avg)

    LOGGER.info("{} D_cost_train:{:.4f} | D_wass_train:{:.4f} | D_cost_valid:{:.4f} | D_wass_valid:{:.4f} | "
                "G_cost:{:.4f}".format(time_since(start),
                                       D_cost_train_epoch_avg,
                                       D_wass_train_epoch_avg,
                                       D_cost_valid_epoch_avg,
                                       D_wass_valid_epoch_avg,
                                       G_cost_epoch_avg))

    # Generate audio samples.
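    # sample_noise_Var was drawn once before training, so the same latent
    # vectors are rendered at every sampling epoch, making successive audio
    # samples directly comparable as training progresses.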
    if epoch % epochs_per_sample == 0:
        LOGGER.info("Generating samples...")
        sample_out = netG(sample_noise_Var)
        if cuda:
            sample_out = sample_out.cpu()
        sample_out = sample_out.data.numpy()
        save_samples(sample_out, epoch, output_dir)

    # TODO
    # Early stopping by Inception Score (IS)

LOGGER.info('>>>>>>> Training finished! <<<<<<<')

# Save model
LOGGER.info("Saving models...")
netD_path = os.path.join(output_dir, "discriminator.pkl")
netG_path = os.path.join(output_dir, "generator.pkl")
torch.save(netD.state_dict(), netD_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
torch.save(netG.state_dict(), netG_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)

# Plot loss curve.
LOGGER.info("Saving loss curve...")
plot_loss(D_costs_train, D_wasses_train,
          D_costs_valid, D_wasses_valid, G_costs, output_dir)

LOGGER.info("All finished!")
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import time
import math
import torch
import random
import logging
import librosa
import argparse
import pescador
import numpy as np
from config import *
from torch import autograd
from torch.autograd import Variable
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt


LOGGER = logging.getLogger('wavegan')
LOGGER.setLevel(logging.DEBUG)


def make_path(output_path):
    if not os.path.isdir(output_path):
        os.makedirs(output_path)
    return output_path


traindata = DATASET_NAME
output = make_path(OUTPUT_PATH)


def time_since(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def save_samples(epoch_samples, epoch, output_dir, fs=16000):
    """
    Save output samples.
    """
    sample_dir = make_path(os.path.join(output_dir, str(epoch)))

    for idx, sample in enumerate(epoch_samples):
        output_path = os.path.join(sample_dir, "{}.wav".format(idx+1))
        sample = sample[0]
        librosa.output.write_wav(output_path, sample, fs)


# Adapted from @jtcramer https://github.com/jtcramer/wavegan/blob/master/sample.py.
def sample_generator(filepath, window_length=16384, fs=16000):
    """
    Audio sample generator
    """
    try:
        audio_data, _ = librosa.load(filepath, sr=fs)

        # Clip magnitude
        max_mag = np.max(np.abs(audio_data))
        if max_mag > 1:
            audio_data /= max_mag
    except Exception as e:
        LOGGER.error("Could not load {}: {}".format(filepath, str(e)))
        # PEP 479: raising StopIteration inside a generator is an error in
        # Python 3.7+; a bare return ends the generator cleanly instead.
        return

    # Pad audio to >= window_length.
    audio_len = len(audio_data)
    if audio_len < window_length:
        pad_length = window_length - audio_len
        left_pad = pad_length // 2
        right_pad = pad_length - left_pad

        audio_data = np.pad(audio_data, (left_pad, right_pad), mode='constant')
        audio_len = len(audio_data)

    while True:
        if audio_len == window_length:
            # If we only have a single window_length of audio, just yield it.
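            # Note: this generator never terminates on its own; pescador
            # (see batch_generator below) multiplexes and samples from many
            # of these infinite per-file streams.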
            sample = audio_data
        else:
            # Sample a random window from the audio; any valid start index
            # in [0, audio_len - window_length] may be chosen.
            start_idx = np.random.randint(0, audio_len - window_length + 1)
            end_idx = start_idx + window_length
            sample = audio_data[start_idx:end_idx]

        sample = sample.astype('float32')
        assert not np.any(np.isnan(sample))

        yield {'X': sample}


def get_all_audio_filepaths(audio_dir):
    return [os.path.join(root, fname)
            for (root, dir_names, file_names) in os.walk(audio_dir, followlinks=True)
            for fname in file_names
            if (fname.lower().endswith('.wav') or fname.lower().endswith('.mp3'))]


def batch_generator(audio_path_list, batch_size):
    streamers = []
    for audio_path in audio_path_list:
        s = pescador.Streamer(sample_generator, audio_path)
        streamers.append(s)

    mux = pescador.ShuffledMux(streamers)
    batch_gen = pescador.buffer_stream(mux, batch_size)

    return batch_gen


def split_data(audio_path_list, valid_ratio, test_ratio, batch_size):
    num_files = len(audio_path_list)
    num_valid = int(np.ceil(num_files * valid_ratio))
    num_test = int(np.ceil(num_files * test_ratio))
    num_train = num_files - num_valid - num_test

    if not (num_valid > 0 and num_test > 0 and num_train > 0):
        LOGGER.error("Please download DATASET '{}' and put it under current path!".format(DATASET_NAME))
        raise ValueError("Not enough audio files to split into train/valid/test sets.")

    # Randomly shuffle the audio_path_list for splitting.
    random.shuffle(audio_path_list)

    valid_files = audio_path_list[:num_valid]
    test_files = audio_path_list[num_valid:num_valid + num_test]
    train_files = audio_path_list[num_valid + num_test:]
    train_size = len(train_files)

    train_data = batch_generator(train_files, batch_size)
    valid_data = batch_generator(valid_files, batch_size)
    test_data = batch_generator(test_files, batch_size)

    return train_data, valid_data, test_data, train_size


# Adapted from https://github.com/caogang/wgan-gp/blob/master/gan_toy.py
def calc_gradient_penalty(net_dis, real_data, fake_data, batch_size, lmbda, use_cuda=False):
    # Compute interpolation factors.
    alpha = torch.rand(batch_size, 1, 1)
    alpha = alpha.expand(real_data.size())
    alpha = alpha.cuda() if use_cuda else alpha

    # Interpolate between real and fake data.
    interpolates = alpha * real_data + (1 - alpha) * fake_data
    if use_cuda:
        interpolates = interpolates.cuda()
    interpolates = autograd.Variable(interpolates, requires_grad=True)

    # Evaluate discriminator.
    disc_interpolates = net_dis(interpolates)

    # Obtain gradients of the discriminator with respect to the inputs.
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).cuda() if use_cuda else
                              torch.ones(disc_interpolates.size()),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients.view(gradients.size(0), -1)

    # Penalize the squared deviation of the gradient norm from 1, which pushes
    # the discriminator toward being a 1-Lipschitz function.
    gradient_penalty = lmbda * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty


def numpy_to_var(numpy_data, cuda):
    """
    Convert a numpy array to a Variable.
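    Adds a channel axis: (batch, samples) -> (batch, 1, samples).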
    """
    data = numpy_data[:, np.newaxis, :]
    data = torch.Tensor(data)
    if cuda:
        data = data.cuda()
    return Variable(data, requires_grad=False)


def plot_loss(D_cost_train, D_wass_train, D_cost_valid, D_wass_valid,
              G_cost, save_path):
    assert len(D_cost_train) == len(D_wass_train) == len(D_cost_valid) == len(D_wass_valid) == len(G_cost)

    save_path = os.path.join(save_path, "loss_curve.png")

    x = range(len(D_cost_train))

    y1 = D_cost_train
    y2 = D_wass_train
    y3 = D_cost_valid
    y4 = D_wass_valid
    y5 = G_cost

    plt.plot(x, y1, label='D_loss_train')
    plt.plot(x, y2, label='D_wass_train')
    plt.plot(x, y3, label='D_loss_valid')
    plt.plot(x, y4, label='D_wass_valid')
    plt.plot(x, y5, label='G_loss')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    plt.savefig(save_path)


def parse_arguments():
    """
    Get command line arguments
    """
    parser = argparse.ArgumentParser(description='Train a WaveGAN on a given set of audio')

    parser.add_argument('-ms', '--model-size', dest='model_size', type=int, default=64,
                        help='Model size parameter used in WaveGAN')
    parser.add_argument('-pssf', '--phase-shuffle-shift-factor', dest='shift_factor', type=int, default=2,
                        help='Maximum shift used by phase shuffle')
    parser.add_argument('-psb', '--phase-shuffle-batchwise', dest='batch_shuffle', action='store_true',
                        help='If true, apply phase shuffle to entire batches rather than individual samples')
    parser.add_argument('-ppfl', '--post-proc-filt-len', dest='post_proc_filt_len', type=int, default=512,
                        help='Length of post-processing filter used by generator. Set to 0 to disable.')
    parser.add_argument('-lra', '--lrelu-alpha', dest='alpha', type=float, default=0.2,
                        help='Slope of negative part of LReLU used by discriminator')
    parser.add_argument('-vr', '--valid-ratio', dest='valid_ratio', type=float, default=0.1,
                        help='Ratio of audio files used for validation')
    parser.add_argument('-tr', '--test-ratio', dest='test_ratio', type=float, default=0.1,
                        help='Ratio of audio files used for testing')
    parser.add_argument('-bs', '--batch-size', dest='batch_size', type=int, default=BATCH_SIZE,
                        help='Batch size used for training')
    parser.add_argument('-ne', '--num-epochs', dest='num_epochs', type=int, default=EPOCHS, help='Number of epochs')
    parser.add_argument('-ng', '--ngpus', dest='ngpus', type=int, default=4,
                        help='Number of GPUs to use for training')
    parser.add_argument('-ld', '--latent-dim', dest='latent_dim', type=int, default=100,
                        help='Size of latent dimension used by generator')
    parser.add_argument('-eps', '--epochs-per-sample', dest='epochs_per_sample', type=int, default=SAMPLE_EVERY,
                        help='How many epochs between every set of samples generated for inspection')
    parser.add_argument('-ss', '--sample-size', dest='sample_size', type=int, default=SAMPLE_NUM,
                        help='Number of inspection samples generated')
    parser.add_argument('-rf', '--regularization-factor', dest='lmbda', type=float, default=10.0,
                        help='Gradient penalty regularization factor')
    parser.add_argument('-lr', '--learning-rate', dest='learning_rate', type=float, default=1e-4,
                        help='Initial Adam learning rate')
    parser.add_argument('-bo', '--beta-one', dest='beta1', type=float, default=0.5, help='beta_1 Adam parameter')
    parser.add_argument('-bt', '--beta-two', dest='beta2', type=float, default=0.9, help='beta_2 Adam parameter')
    parser.add_argument('-v', '--verbose', dest='verbose', action='store_true')
    parser.add_argument('-audio_dir', '--audio_dir', dest='audio_dir', type=str, default=traindata, help='Path to directory containing audio files')
    parser.add_argument('-output_dir', '--output_dir', dest='output_dir', type=str, default=output, help='Path to directory where model files will be output')
    args = parser.parse_args()
    return vars(args)
--------------------------------------------------------------------------------
/wavegan.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data


class Transpose1dLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=11, upsample=None, output_padding=1):
        super(Transpose1dLayer, self).__init__()
        self.upsample = upsample

        self.upsample_layer = torch.nn.Upsample(scale_factor=upsample)
        reflection_pad = kernel_size // 2
        # Despite the name, this is zero (constant) padding, not reflection padding.
        self.reflection_pad = nn.ConstantPad1d(reflection_pad, value=0)
        self.conv1d = torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride)
        self.Conv1dTrans = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding, output_padding)

    def forward(self, x):
        if self.upsample:
            # Nearest-neighbor upsampling followed by a regular convolution,
            # as an alternative to a strided transposed convolution.
            return self.conv1d(self.reflection_pad(self.upsample_layer(x)))
        else:
            return self.Conv1dTrans(x)


class WaveGANGenerator(nn.Module):
    def __init__(self, model_size=64, ngpus=1, num_channels=1,
                 latent_dim=100, post_proc_filt_len=512,
                 verbose=False, upsample=True):
        super(WaveGANGenerator, self).__init__()
        self.ngpus = ngpus
        self.model_size = model_size  # d
        self.num_channels = num_channels  # c
        self.latent_dim = latent_dim
        self.post_proc_filt_len = post_proc_filt_len
        self.verbose = verbose
        # "Dense" here means a fully connected layer.
        self.fc1 = nn.Linear(latent_dim, 256 * model_size)

        stride = 4
        if upsample:
            stride = 1
            upsample = 4
        self.deconv_1 = Transpose1dLayer(16 * model_size, 8 * model_size, 25, stride, upsample=upsample)
        self.deconv_2 = Transpose1dLayer(8 * model_size, 4 * model_size, 25, stride, upsample=upsample)
        self.deconv_3 = Transpose1dLayer(4 * model_size, 2 * model_size, 25, stride, upsample=upsample)
        self.deconv_4 = Transpose1dLayer(2 * model_size, model_size, 25, stride, upsample=upsample)
        self.deconv_5 = Transpose1dLayer(model_size, num_channels, 25, stride, upsample=upsample)

        if post_proc_filt_len:
            self.ppfilter1 = nn.Conv1d(num_channels, num_channels, post_proc_filt_len)

        for m in self.modules():
            if isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight.data)

    def forward(self, x):
        x = self.fc1(x).view(-1, 16 * self.model_size, 16)
        x = F.relu(x)
        if self.verbose:
            print(x.shape)

        x = F.relu(self.deconv_1(x))
        if self.verbose:
            print(x.shape)

        x = F.relu(self.deconv_2(x))
        if self.verbose:
            print(x.shape)

        x = F.relu(self.deconv_3(x))
        if self.verbose:
            print(x.shape)

        x = F.relu(self.deconv_4(x))
        if self.verbose:
            print(x.shape)

        output = torch.tanh(self.deconv_5(x))

        if self.post_proc_filt_len:
            # Apply the learned post-processing filter with "same" padding so
            # the output length is unchanged. (The original code defined
            # ppfilter1 but never applied it; this follows jtcramer's usage.)
            if self.post_proc_filt_len % 2 == 0:
                pad_left = self.post_proc_filt_len // 2
                pad_right = pad_left - 1
            else:
                pad_left = pad_right = (self.post_proc_filt_len - 1) // 2
            output = self.ppfilter1(F.pad(output, (pad_left, pad_right)))

        return output


class PhaseShuffle(nn.Module):
    """
    Performs phase shuffling, i.e. shifting the feature axis of a 3D tensor
    by a random integer in {-n, ..., n} and performing reflection padding
    where necessary.
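    E.g. with shift_factor n = 2, each example in a batch is shifted by an
    integer k drawn uniformly from {-2, -1, 0, 1, 2}.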
    """
    # Copied from https://github.com/jtcramer/wavegan/blob/master/wavegan.py#L8
    def __init__(self, shift_factor):
        super(PhaseShuffle, self).__init__()
        self.shift_factor = shift_factor

    def forward(self, x):
        if self.shift_factor == 0:
            return x
        # Draw one shift k per example, uniform over {-shift_factor, ..., shift_factor}.
        k_list = torch.Tensor(x.shape[0]).random_(0, 2 * self.shift_factor + 1) - self.shift_factor
        k_list = k_list.numpy().astype(int)

        # Group sample indices by shift value so that fewer shuffle operations
        # need to be performed.
        k_map = {}
        for idx, k in enumerate(k_list):
            k = int(k)
            if k not in k_map:
                k_map[k] = []
            k_map[k].append(idx)

        # Make a copy of x for our output.
        x_shuffle = x.clone()

        # Apply shuffle to each sample.
        for k, idxs in k_map.items():
            if k > 0:
                x_shuffle[idxs] = F.pad(x[idxs][..., :-k], (k, 0), mode='reflect')
            else:
                x_shuffle[idxs] = F.pad(x[idxs][..., -k:], (0, -k), mode='reflect')

        assert x_shuffle.shape == x.shape, "{}, {}".format(x_shuffle.shape,
                                                           x.shape)
        return x_shuffle


class PhaseRemove(nn.Module):
    # Unused stub.
    def __init__(self):
        super(PhaseRemove, self).__init__()

    def forward(self, x):
        pass


class WaveGANDiscriminator(nn.Module):
    def __init__(self, model_size=64, ngpus=1, num_channels=1, shift_factor=2,
                 alpha=0.2, verbose=False):
        super(WaveGANDiscriminator, self).__init__()
        self.model_size = model_size  # d
        self.ngpus = ngpus
        self.num_channels = num_channels  # c
        self.shift_factor = shift_factor  # n
        self.alpha = alpha
        self.verbose = verbose

        self.conv1 = nn.Conv1d(num_channels, model_size, 25, stride=4, padding=11)
        self.conv2 = nn.Conv1d(model_size, 2 * model_size, 25, stride=4, padding=11)
        self.conv3 = nn.Conv1d(2 * model_size, 4 * model_size, 25, stride=4, padding=11)
        self.conv4 = nn.Conv1d(4 * model_size, 8 * model_size, 25, stride=4, padding=11)
        self.conv5 = nn.Conv1d(8 * model_size, 16 * model_size, 25, stride=4, padding=11)

        self.ps1 = PhaseShuffle(shift_factor)
        self.ps2 = PhaseShuffle(shift_factor)
        self.ps3 = PhaseShuffle(shift_factor)
        self.ps4 = PhaseShuffle(shift_factor)

        self.fc1 = nn.Linear(256 * model_size, 1)

        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight.data)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)
        x = self.ps1(x)

        x = F.leaky_relu(self.conv2(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)
        x = self.ps2(x)

        x = F.leaky_relu(self.conv3(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)
        x = self.ps3(x)

        x = F.leaky_relu(self.conv4(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)
        x = self.ps4(x)

        x = F.leaky_relu(self.conv5(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)

        x = x.view(-1, 256 * self.model_size)
        if self.verbose:
            print(x.shape)

        return self.fc1(x)


"""
Smoke test:
from torch.autograd import Variable
x = Variable(torch.randn(10, 100))
G = WaveGANGenerator(verbose=True, upsample=False)
out = G(x)
print(out.shape)
D = WaveGANDiscriminator(verbose=True)
out2 = D(out)
print(out2.shape)
"""
--------------------------------------------------------------------------------
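A minimal inference sketch (not a file in this repo, shown for illustration): it loads the `generator.pkl` that train.py writes and synthesizes one clip. It assumes the default configuration (`model_size=64`, `latent_dim=100`, upsampling generator) and CPU inference; since train.py saves the state_dict of a `DataParallel`-wrapped model when training on GPU, the `module.` key prefix may need stripping.

```python
# Hypothetical usage sketch -- not part of the repo.
import torch
import librosa
from wavegan import WaveGANGenerator

netG = WaveGANGenerator(model_size=64, latent_dim=100, upsample=True)
state = torch.load("output/generator.pkl", map_location="cpu")
# Strip the "module." prefix added by torch.nn.DataParallel, if present.
state = {k.replace("module.", "", 1): v for k, v in state.items()}
netG.load_state_dict(state)
netG.eval()

with torch.no_grad():
    z = torch.randn(1, 100)          # one latent vector
    audio = netG(z)[0, 0].numpy()    # (1, 1, 16384) -> (16384,)
librosa.output.write_wav("sample.wav", audio, 16000)
```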