├── images
│   └── eets.png
├── models
│   ├── alignment.png
│   ├── __pycache__
│   │   ├── decoder.cpython-36.pyc
│   │   ├── encoder.cpython-36.pyc
│   │   └── modules.cpython-36.pyc
│   ├── modules.py
│   ├── v2_discriminator.py
│   ├── model.py
│   ├── encoder.py
│   ├── decoder.py
│   └── discriminator.py
├── README.md
├── utils
│   ├── plot.py
│   ├── writer.py
│   ├── optimizer.py
│   ├── util.py
│   ├── audio.py
│   ├── dataset.py
│   └── loss.py
├── process.py
├── generate.py
└── train.py

/images/eets.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanggeng1995/EATS/HEAD/images/eets.png
--------------------------------------------------------------------------------
/models/alignment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanggeng1995/EATS/HEAD/models/alignment.png
--------------------------------------------------------------------------------
/models/__pycache__/decoder.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanggeng1995/EATS/HEAD/models/__pycache__/decoder.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/encoder.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanggeng1995/EATS/HEAD/models/__pycache__/encoder.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/modules.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanggeng1995/EATS/HEAD/models/__pycache__/modules.cpython-36.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EATS
2 | A PyTorch implementation of EATS: End-to-End Adversarial Text-to-Speech (https://arxiv.org/pdf/2006.03575.pdf)
3 | 
4 | ![](./images/eets.png)
5 | 
6 | ## Notes
7 | * So far I have only implemented the preliminary framework.
8 | * In my opinion, training EATS is very expensive. It may be worth using ground-truth durations during training and introducing an additional auxiliary loss to accelerate convergence (a minimal sketch of such a loss is appended at the end of this listing).
9 | 
10 | ## Authors
11 | Geng Yang ([@yanggeng1995](https://github.com/yanggeng1995))
12 | Jian Cong ([@npujcong](https://github.com/npujcong))
13 | 
--------------------------------------------------------------------------------
/utils/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use("Agg")
3 | import matplotlib.pylab as plt
4 | import numpy as np
5 | 
6 | 
7 | def save_figure_to_numpy(fig):
8 |     # save it to a numpy array.
9 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 10 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 11 | data = np.transpose(data, (2, 0, 1)) 12 | return data 13 | 14 | 15 | def plot_waveform_to_numpy(waveform): 16 | fig, ax = plt.subplots(figsize=(12, 3)) 17 | ax.plot() 18 | ax.plot(range(len(waveform)), waveform, 19 | linewidth=0.1, alpha=0.7, color='blue') 20 | 21 | plt.xlabel("Samples") 22 | plt.ylabel("Amplitude") 23 | plt.ylim(-1, 1) 24 | plt.tight_layout() 25 | 26 | fig.canvas.draw() 27 | data = save_figure_to_numpy(fig) 28 | plt.close() 29 | 30 | return data 31 | -------------------------------------------------------------------------------- /utils/writer.py: -------------------------------------------------------------------------------- 1 | from tensorboardX import SummaryWriter 2 | from .plot import plot_waveform_to_numpy 3 | 4 | class Writer(SummaryWriter): 5 | def __init__(self, logdir, sample_rate=16000): 6 | super(Writer, self).__init__(logdir) 7 | 8 | self.sample_rate = sample_rate 9 | self.logdir = logdir 10 | 11 | def logging_loss(self, losses, step): 12 | for key in losses: 13 | self.add_scalar('{}'.format(key), losses[key], step) 14 | 15 | def logging_audio(self, target, prediction, step): 16 | self.add_audio('raw_audio_predicted', prediction, step, self.sample_rate) 17 | self.add_image('waveform_predicted', plot_waveform_to_numpy(prediction), step) 18 | self.add_audio('raw_audio_target', target, step, self.sample_rate) 19 | self.add_image('waveform_target', plot_waveform_to_numpy(target), step) 20 | 21 | def logging_histogram(self, model, step): 22 | for tag, value in model.named_parameters(): 23 | self.add_histogram(tag.replace('.', '/'), value.cpu().detach().numpy(), step) 24 | -------------------------------------------------------------------------------- /utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Optimizer(object): 4 | ''' A simple wrapper class for learning rate scheduling ''' 5 | 6 | def __init__(self, 7 | optimizer, 8 | init_lr, 9 | current_step=0, 10 | warmup_steps=50000, 11 | decay_learning_rate=0.5): 12 | 13 | self.optimizer = optimizer 14 | self.init_lr = init_lr 15 | self.current_steps = current_step 16 | self.warmup_steps = warmup_steps 17 | self.decay_learning_rate = decay_learning_rate 18 | 19 | def zero_grad(self): 20 | self.optimizer.zero_grad() 21 | 22 | def step_and_update_lr(self): 23 | self.update_learning_rate() 24 | self.optimizer.step() 25 | 26 | def get_lr_scale(self): 27 | if self.current_steps >= self.warmup_steps: 28 | lr_scale = np.power(self.decay_learning_rate, self.current_steps / self.warmup_steps) 29 | else: 30 | lr_scale = 1 31 | 32 | return lr_scale 33 | 34 | def update_learning_rate(self): 35 | self.current_steps += 1 36 | lr = self.init_lr * self.get_lr_scale() 37 | lr = np.maximum(1e-6, lr) 38 | self.lr = lr 39 | 40 | for param_group in self.optimizer.param_groups: 41 | param_group['lr'] = self.lr 42 | 43 | 44 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class ExponentialMovingAverage(object): 4 | def __init__(self, decay): 5 | self.decay = decay 6 | self.shadow = {} 7 | 8 | def register(self, name, val): 9 | self.shadow[name] = val.clone() 10 | 11 | def update(self, name, x): 12 | assert name in self.shadow 13 | update_delta = 
self.shadow[name] - x 14 | self.shadow[name] -= (1.0 - self.decay) * update_delta 15 | 16 | 17 | def apply_moving_average(model, ema): 18 | for name, param in model.named_parameters(): 19 | if name in ema.shadow: 20 | ema.update(name, param.data) 21 | 22 | def register_model_to_ema(model, ema): 23 | for name, param in model.named_parameters(): 24 | if param.requires_grad: 25 | ema.register(name, param.data) 26 | 27 | def mu_law_encode(signal, quantization_channels=65536): 28 | # Manual mu-law companding and mu-bits quantization 29 | mu = quantization_channels - 1 30 | 31 | magnitude = np.log1p(mu * np.abs(signal)) / np.log1p(mu) 32 | signal = np.sign(signal) * magnitude 33 | 34 | # Map signal from [-1, +1] to [0, mu] 35 | signal = (signal + 1) / 2 * mu + 0.5 36 | quantized_signal = signal.astype(np.int32) 37 | 38 | return quantized_signal 39 | 40 | 41 | def mu_law_decode(signal, quantization_channels=65536): 42 | # Calculate inverse mu-law companding and dequantization 43 | mu = quantization_channels - 1 44 | y = signal.astype(np.float32) 45 | 46 | y = 2 * (y / mu) - 1 47 | x = np.sign(y) * (1.0 / mu) * ((1.0 + mu)**abs(y) - 1.0) 48 | return x 49 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.utils import spectral_norm 3 | 4 | class Conv1d(nn.Module): 5 | 6 | "Conv1d for spectral normalisation and orthogonal initialisation" 7 | 8 | def __init__(self, 9 | in_channels, 10 | out_channels, 11 | kernel_size=1, 12 | stride=1, 13 | dilation=1, 14 | groups=1): 15 | super(Conv1d, self).__init__() 16 | 17 | self.in_channels = in_channels 18 | self.out_channels = out_channels 19 | self.kernel_size = kernel_size 20 | self.stride = stride 21 | self.dilation = dilation 22 | self.groups = groups 23 | pad = dilation * (kernel_size - 1) // 2 24 | 25 | layer = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, 26 | stride=stride, padding=pad, dilation=dilation, groups=groups) 27 | nn.init.orthogonal_(layer.weight) 28 | self.layer = spectral_norm(layer) 29 | 30 | def forward(self, inputs): 31 | return self.layer(inputs) 32 | 33 | class Linear(nn.Module): 34 | 35 | "Linear for spectral normalisation and orthogonal initialisation" 36 | 37 | def __init__(self, 38 | in_features, 39 | out_features, 40 | bias=True): 41 | super(Linear, self).__init__() 42 | 43 | self.in_features = in_features 44 | self.out_features = out_features 45 | 46 | layer = nn.Linear(in_features, out_features, bias) 47 | nn.init.orthogonal_(layer.weight) 48 | self.layer = spectral_norm(layer) 49 | 50 | def forward(self, inputs): 51 | return self.layer(inputs) 52 | -------------------------------------------------------------------------------- /utils/audio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import librosa 3 | import scipy 4 | 5 | sample_rate = 24000 6 | n_fft = 2048 7 | fft_bins = n_fft // 2 + 1 8 | num_mels = 80 9 | #frame_length_ms=50 10 | #frame_shift_ms=12.5 11 | hop_length = 120 #frame_shift_ms * sample_rate / 1000 12 | win_length = 240 #frame_length_ms * sample_rate / 1000 13 | fmin = 40 14 | min_level_db = -100 15 | ref_level_db = 20 16 | 17 | def convert_audio(wav_path): 18 | wav = load_wav(wav_path) 19 | mel = melspectrogram(wav).astype(np.float32) 20 | return mel.transpose(), wav 21 | 22 | def load_wav(filename) : 23 | x = librosa.load(filename, sr=sample_rate)[0] 24 
| return x 25 | 26 | def save_wav(y, filename) : 27 | scipy.io.wavfile.write(filename, sample_rate, y) 28 | 29 | mel_basis = None 30 | 31 | def linear_to_mel(spectrogram): 32 | global mel_basis 33 | if mel_basis is None: 34 | mel_basis = build_mel_basis() 35 | return np.dot(mel_basis, spectrogram) 36 | 37 | def build_mel_basis(): 38 | return librosa.filters.mel(sample_rate, n_fft, n_mels=num_mels, fmin=fmin) 39 | 40 | def normalize(S): 41 | return np.clip((S - min_level_db) / -min_level_db, 0, 1) 42 | 43 | def denormalize(S): 44 | return (np.clip(S, 0, 1) * -min_level_db) + min_level_db 45 | 46 | def amp_to_db(x): 47 | return 20 * np.log10(np.maximum(1e-5, x)) 48 | 49 | def db_to_amp(x): 50 | return np.power(10.0, x * 0.05) 51 | 52 | def spectrogram(y): 53 | D = stft(y) 54 | S = amp_to_db(np.abs(D)) - ref_level_db 55 | return normalize(S) 56 | 57 | def melspectrogram(y): 58 | D = stft(y) 59 | S = amp_to_db(linear_to_mel(np.abs(D))) 60 | return normalize(S) 61 | 62 | def stft(y): 63 | return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length) 64 | -------------------------------------------------------------------------------- /models/v2_discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Discriminator(nn.Module): 5 | def __init__(self): 6 | super(Discriminator, self).__init__() 7 | 8 | self.discriminator = nn.ModuleList([ 9 | nn.Sequential( 10 | nn.ReflectionPad1d(7), 11 | nn.utils.spectral_norm(nn.Conv1d(1, 16, kernel_size=15)), 12 | nn.LeakyReLU(0.2, True), 13 | ), 14 | nn.Sequential( 15 | nn.utils.spectral_norm(nn.Conv1d(16, 64, kernel_size=41, 16 | stride=4, padding=20, groups=4)), 17 | nn.LeakyReLU(0.2, True), 18 | ), 19 | nn.Sequential( 20 | nn.utils.spectral_norm(nn.Conv1d(64, 256, kernel_size=41, 21 | stride=4, padding=20, groups=16)), 22 | nn.LeakyReLU(0.2, True), 23 | ), 24 | nn.Sequential( 25 | nn.utils.spectral_norm(nn.Conv1d(256, 1024, kernel_size=41, 26 | stride=4, padding=20, groups=64)), 27 | nn.LeakyReLU(0.2, True), 28 | ), 29 | nn.Sequential( 30 | nn.utils.spectral_norm(nn.Conv1d(1024, 1024, kernel_size=41, 31 | stride=4, padding=20, groups=256)), 32 | nn.LeakyReLU(0.2, True), 33 | ), 34 | nn.Sequential( 35 | nn.utils.spectral_norm(nn.Conv1d(1024, 1024, kernel_size=5, 36 | stride=1, padding=2)), 37 | nn.LeakyReLU(0.2, True), 38 | ), 39 | nn.utils.spectral_norm(nn.Conv1d(1024, 1, kernel_size=3, 40 | stride=1, padding=1)), 41 | ]) 42 | 43 | def forward(self, x): 44 | for layer in self.discriminator: 45 | x = layer(x) 46 | 47 | return x 48 | 49 | if __name__ == '__main__': 50 | 51 | model = Discriminator() 52 | x = torch.randn(3, 1, 24000) 53 | 54 | score = model(x) 55 | print(score.shape) 56 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pickle 4 | import torch 5 | from torch.utils.data import Dataset 6 | from utils.util import mu_law_encode, mu_law_decode 7 | 8 | class CustomerDataset(Dataset): 9 | def __init__(self, 10 | path, 11 | upsample_factor=120, 12 | local_condition=True, 13 | global_condition=False): 14 | 15 | self.path = path 16 | self.metadata = self.get_metadata(path) 17 | 18 | self.upsample_factor = upsample_factor 19 | 20 | self.local_condition = local_condition 21 | self.global_condition = global_condition 22 | 23 | def __getitem__(self, index): 24 | 
25 | sample = np.load(os.path.join(self.path, 'audio', self.metadata[index])) 26 | condition = np.load(os.path.join(self.path, 'mel', self.metadata[index])) 27 | 28 | length = min([len(sample), len(condition) * self.upsample_factor]) 29 | 30 | sample = sample[: length] 31 | condition = condition[: length // self.upsample_factor , :] 32 | 33 | sample = sample.reshape(-1, 1) 34 | 35 | if self.local_condition: 36 | return sample, condition 37 | else: 38 | return sample 39 | 40 | def __len__(self): 41 | return len(self.metadata) 42 | 43 | def get_metadata(self, path): 44 | with open(os.path.join(path, 'names.pkl'), 'rb') as f: 45 | metadata = pickle.load(f) 46 | 47 | return metadata 48 | 49 | class CustomerCollate(object): 50 | 51 | def __init__(self, 52 | upsample_factor=120, 53 | condition_window=200, 54 | local_condition=True, 55 | global_condition=False): 56 | 57 | self.upsample_factor = upsample_factor 58 | self.condition_window = condition_window 59 | self.sample_window = condition_window * upsample_factor 60 | self.local_condition = local_condition 61 | self.global_condition = global_condition 62 | 63 | def __call__(self, batch): 64 | return self._collate_fn(batch) 65 | 66 | def _collate_fn(self, batch): 67 | 68 | sample_batch = [] 69 | condition_batch = [] 70 | for (i, x) in enumerate(batch): 71 | if len(x[1]) < self.condition_window: 72 | sample = np.pad(x[0], [[0, self.sample_window - len(x[0])], [0, 0]], 'constant') 73 | condition = np.pad(x[1], [[0, self.condition_window - len(x[1])], [0, 0]], 'edge') 74 | else: 75 | lc_index = np.random.randint(0, len(x[1]) - self.condition_window) 76 | sample = x[0][lc_index * self.upsample_factor : 77 | (lc_index + self.condition_window) * self.upsample_factor] 78 | condition = x[1][lc_index : (lc_index + self.condition_window)] 79 | sample_batch.append(sample) 80 | condition_batch.append(condition) 81 | 82 | sample_batch = np.stack(sample_batch) 83 | condition_batch = np.stack(condition_batch) 84 | sample_batch = mu_law_encode(sample_batch) 85 | sample_batch = mu_law_decode(sample_batch) 86 | 87 | samples = torch.FloatTensor(sample_batch).transpose(1, 2) 88 | conditions = torch.FloatTensor(condition_batch).transpose(1, 2) 89 | 90 | return samples, conditions 91 | -------------------------------------------------------------------------------- /process.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import glob 4 | import argparse 5 | import numpy as np 6 | from multiprocessing import cpu_count 7 | from concurrent.futures import ProcessPoolExecutor 8 | from functools import partial 9 | from utils.audio import convert_audio, hop_length, sample_rate 10 | from tqdm import tqdm 11 | import random 12 | 13 | train_rate = 0.9995 14 | test_rate = 0.0005 15 | 16 | def find_files(path, pattren="*.wav"): 17 | filenames = [] 18 | for filename in glob.iglob(f'{path}/**/*{pattren}', recursive=True): 19 | filenames.append(filename) 20 | return filenames 21 | 22 | def data_prepare(audio_path, mel_path, wav_file): 23 | mel, audio = convert_audio(wav_file) 24 | np.save(audio_path, audio, allow_pickle=False) 25 | np.save(mel_path, mel, allow_pickle=False) 26 | return audio_path, mel_path, mel.shape[0] 27 | 28 | def process(output_dir, wav_files, train_dir, test_dir, num_workers): 29 | executor = ProcessPoolExecutor(max_workers=num_workers) 30 | results = [] 31 | names = [] 32 | 33 | random.shuffle(wav_files) 34 | train_num = int(len(wav_files) * train_rate) 35 | 36 | for wav_file in 
wav_files[0 : train_num]: 37 | fid = os.path.basename(wav_file).replace('.wav','.npy') 38 | names.append(fid) 39 | results.append(executor.submit(partial(data_prepare, os.path.join(train_dir, "audio", fid), os.path.join(train_dir, "mel", fid), wav_file))) 40 | 41 | with open(os.path.join(output_dir, "train", 'names.pkl'), 'wb') as f: 42 | pickle.dump(names, f) 43 | 44 | names = [] 45 | for wav_file in wav_files[train_num : len(wav_files)]: 46 | fid = os.path.basename(wav_file).replace('.wav','.npy') 47 | names.append(fid) 48 | results.append(executor.submit(partial(data_prepare, os.path.join(test_dir, "audio", fid), os.path.join(test_dir, "mel", fid), wav_file))) 49 | 50 | with open(os.path.join(output_dir, "test", 'names.pkl'), 'wb') as f: 51 | pickle.dump(names, f) 52 | 53 | 54 | return [result.result() for result in tqdm(results)] 55 | 56 | def preprocess(args): 57 | train_dir = os.path.join(args.output, 'train') 58 | test_dir = os.path.join(args.output, 'test') 59 | os.makedirs(args.output, exist_ok=True) 60 | os.makedirs(train_dir, exist_ok=True) 61 | os.makedirs(test_dir, exist_ok=True) 62 | os.makedirs(os.path.join(train_dir, "audio"), exist_ok=True) 63 | os.makedirs(os.path.join(train_dir, "mel"), exist_ok=True) 64 | os.makedirs(os.path.join(test_dir, "audio"), exist_ok=True) 65 | os.makedirs(os.path.join(test_dir, "mel"), exist_ok=True) 66 | 67 | wav_files = find_files(args.wav_dir) 68 | metadata = process(args.output, wav_files, train_dir, test_dir, args.num_workers) 69 | write_metadata(metadata, args.output) 70 | 71 | def write_metadata(metadata, out_dir): 72 | with open(os.path.join(out_dir, 'metadata.txt'), 'w', encoding='utf-8') as f: 73 | for m in metadata: 74 | f.write('|'.join([str(x) for x in m]) + '\n') 75 | frames = sum([m[2] for m in metadata]) 76 | frame_shift_ms = hop_length * 1000 / sample_rate 77 | hours = frames * frame_shift_ms / (3600 * 1000) 78 | print('Write %d utterances, %d frames (%.2f hours)' % (len(metadata), frames, hours)) 79 | 80 | def main(): 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('--wav_dir', default='wavs') 83 | parser.add_argument('--output', default='data') 84 | parser.add_argument('--num_workers', type=int, default=int(cpu_count())) 85 | args = parser.parse_args() 86 | preprocess(args) 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.audio import save_wav 3 | import argparse 4 | import os 5 | import time 6 | import numpy as np 7 | from models.generator import Generator 8 | from utils.util import mu_law_encode, mu_law_decode 9 | 10 | def attempt_to_restore(generate, checkpoint_dir, use_cuda): 11 | checkpoint_list = os.path.join(checkpoint_dir, 'checkpoint') 12 | 13 | if os.path.exists(checkpoint_list): 14 | checkpoint_filename = open(checkpoint_list).readline().strip() 15 | checkpoint_path = os.path.join( 16 | checkpoint_dir, "{}".format(checkpoint_filename)) 17 | print("Restore from {}".format(checkpoint_path)) 18 | checkpoint = load_checkpoint(checkpoint_path, use_cuda) 19 | generate.load_state_dict(checkpoint["generator"]) 20 | 21 | def load_checkpoint(checkpoint_path, use_cuda): 22 | if use_cuda: 23 | checkpoint = torch.load(checkpoint_path) 24 | else: 25 | checkpoint = torch.load( 26 | checkpoint_path, map_location=lambda storage, loc: storage) 27 | return checkpoint 28 | 29 | def create_model(args): 30 | 31 
| generator = Generator(args.local_condition_dim, args.z_dim) 32 | 33 | return generator 34 | 35 | def synthesis(args): 36 | 37 | model = create_model(args) 38 | if args.resume is not None: 39 | attempt_to_restore(model, args.resume, args.use_cuda) 40 | 41 | device = torch.device("cuda" if args.use_cuda else "cpu") 42 | model.to(device) 43 | 44 | output_dir = "samples" 45 | os.makedirs(output_dir, exist_ok=True) 46 | 47 | avg_rtf = [] 48 | for filename in os.listdir(os.path.join(args.input, 'mel')): 49 | start = time.time() 50 | conditions = np.load(os.path.join(args.input, 'mel', filename)) 51 | conditions = torch.FloatTensor(conditions).unsqueeze(0) 52 | conditions = conditions.transpose(1, 2).to(device) 53 | 54 | batch_size = conditions.size()[0] 55 | z = torch.randn(batch_size, args.z_dim).to(device).normal_(0.0, 1.0) 56 | audios = model(conditions, z) 57 | audios = audios.cpu().squeeze().detach().numpy() 58 | print(audios.shape) 59 | name = filename.split('.')[0] 60 | sample = np.load(os.path.join(args.input, 'audio', filename)) 61 | sample = mu_law_decode(mu_law_encode(sample)) 62 | save_wav(np.squeeze(sample), '{}/{}_target.wav'.format(output_dir, name)) 63 | save_wav(np.asarray(audios), '{}/{}.wav'.format(output_dir, name)) 64 | time_used = time.time() - start 65 | rtf = time_used / (len(audios) / 24000) 66 | avg_rtf.append(rtf) 67 | print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf)) 68 | 69 | print("Average RTF: {:.3f}".format(sum(avg_rtf) / len(avg_rtf))) 70 | 71 | def main(): 72 | 73 | def _str_to_bool(s): 74 | """Convert string to bool (in argparse context).""" 75 | if s.lower() not in ['true', 'false']: 76 | raise ValueError('Argument needs to be a ' 77 | 'boolean, got {}'.format(s)) 78 | return {'true': True, 'false': False}[s.lower()] 79 | 80 | 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('--input', type=str, default='data/test', help='Directory of tests data') 83 | parser.add_argument('--num_workers',type=int, default=4, help='Number of dataloader workers.') 84 | parser.add_argument('--resume', type=str, default="logdir") 85 | parser.add_argument('--local_condition_dim', type=int, default=80) 86 | parser.add_argument('--z_dim', type=int, default=128) 87 | parser.add_argument('--use_cuda', type=_str_to_bool, default=False) 88 | 89 | args = parser.parse_args() 90 | synthesis(args) 91 | 92 | if __name__ == "__main__": 93 | main() 94 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from encoder import Encoder 4 | from decoder import Decoder 5 | 6 | class EETS(nn.Module): 7 | def __init__(self, 8 | phone_embedding_dim=128, 9 | tone_embedding_dim=64, 10 | prosody_embedding_dim=32, 11 | seg_embedding_dim=32, 12 | spk_embedding_dim=64, 13 | z_dim=64): 14 | super(EETS, self).__init__() 15 | 16 | self.phone_embedding_dim = phone_embedding_dim 17 | self.tone_embedding_dim = tone_embedding_dim 18 | self.prosody_embedding_dim = prosody_embedding_dim 19 | self.seg_embedding_dim = seg_embedding_dim 20 | 21 | self.spk_embedding_dim = spk_embedding_dim 22 | self.z_dim = z_dim 23 | 24 | self.encoder_dim = phone_embedding_dim + tone_embedding_dim + prosody_embedding_dim + seg_embedding_dim 25 | 26 | self.phone_embed = nn.Embedding(100, phone_embedding_dim) 27 | self.tone_embed = nn.Embedding(8, tone_embedding_dim) 28 | self.prosody_embed = nn.Embedding(4, prosody_embedding_dim) 29 | 
        self.seg_embed = nn.Embedding(4, seg_embedding_dim)
30 |         self.speaker_embed = nn.Embedding(100, spk_embedding_dim)
31 | 
32 |         self.encoder = Encoder(self.encoder_dim, spk_embedding_dim, z_dim)
33 |         self.expand_model = ExpandFrame()
34 |         self.decoder = Decoder(256, spk_embedding_dim + z_dim)
35 | 
36 |     def forward(self, phone, tone, prosody, segment, noise, speakers, duration):
37 |         phone_inputs = self.phone_embed(phone).transpose(1, 2)
38 |         tone_inputs = self.tone_embed(tone).transpose(1, 2)
39 |         prosody_inputs = self.prosody_embed(prosody).transpose(1, 2)
40 |         segment_inputs = self.seg_embed(segment).transpose(1, 2)
41 | 
42 |         encoder_inputs = torch.cat([phone_inputs, tone_inputs, prosody_inputs, segment_inputs], dim=1)
43 | 
44 |         speaker_inputs = self.speaker_embed(speakers).squeeze(1)
45 | 
46 |         encoder_outputs, predict_duration = self.encoder(encoder_inputs, noise, speaker_inputs)
47 | 
48 |         decoder_inputs = self.expand_model(encoder_outputs, duration)
49 | 
50 |         z_inputs = torch.cat([noise, speaker_inputs], dim=1)
51 |         outputs = self.decoder(decoder_inputs.transpose(1, 2), z_inputs)
52 | 
53 |         return outputs, predict_duration
54 | 
55 | class ExpandFrame(nn.Module):
56 |     def __init__(self):
57 |         super(ExpandFrame, self).__init__()
58 |         pass
59 | 
60 |     def forward(self, encoder_outputs, duration):
61 |         # Gaussian upsampling (as in EATS): each output frame is a weighted sum of
62 |         # the token representations, with Gaussian weights centred at each token's
63 |         # centre position c, so the expansion stays differentiable w.r.t. duration.
64 |         t = torch.round(torch.sum(duration, dim=-1, keepdim=True))  # total frames [B, 1]
65 |         e = torch.cumsum(duration, dim=-1).float()  # token end positions [B, L]
66 |         c = e - 0.5 * torch.round(duration)  # token centre positions [B, L]
67 | 
68 |         # torch.range is deprecated and includes the end point (one extra frame),
69 |         # so build the output frame indices with torch.arange instead.
70 |         t = torch.arange(0, torch.max(t))
71 |         t = t.unsqueeze(0).unsqueeze(1)  # [1, 1, T]
72 |         c = c.unsqueeze(2)
73 |         w_1 = torch.exp(-0.1 * (t - c) ** 2)  # [B, L, T]
74 |         w_2 = torch.sum(torch.exp(-0.1 * (t - c) ** 2), dim=1, keepdim=True)  # [B, 1, T]
75 | 
76 |         w = w_1 / w_2  # normalised interpolation weights
77 | 
78 |         out = torch.matmul(w.transpose(1, 2), encoder_outputs)  # [B, T, C]
79 | 
80 |         return out
81 | 
82 | if __name__ == "__main__":
83 |     l = [[1.5, 2.3, 3.4, 4.4, 5.1, 4.2, 3.5, 2.6, 1.8, 0, 0, 0, 0],
84 |          [1.5, 2.3, 3.4, 4.4, 5.1, 4.2, 3.5, 2.6, 1.8, 4.5, 5.5, 2.3, 5.6]]
85 |     model = EETS()
86 |     outputs = torch.randn(2, 13, 512)
87 |     phone = torch.LongTensor([1,0,2,1,0,1,1,1,1,1,1,1,1]).view(1, 13).repeat(2, 1)
88 |     speaker = torch.LongTensor([0, 1]).view(2, 1)
89 |     duration = torch.FloatTensor(l)
90 |     noise = torch.randn(2, 64)
91 |     print(duration.shape)
92 |     out, pduration = model(phone, phone, phone, phone, noise, speaker, duration)
93 |     print(out.shape, pduration.shape)
94 | 
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | def stft(x, fft_size, hop_size, win_size, window):
6 |     """Perform STFT and convert to magnitude spectrogram.
7 | 
8 |     Args:
9 |         x: Input signal tensor (B, T).
10 | 
11 |     Returns:
12 |         Tensor: Magnitude spectrogram (B, T, fft_size // 2 + 1).
13 | 14 | """ 15 | x_stft = torch.stft(x, fft_size, hop_size, win_size, window) 16 | real = x_stft[..., 0] 17 | imag = x_stft[..., 1] 18 | outputs = torch.clamp(real ** 2 + imag ** 2, min=1e-7).transpose(2, 1) 19 | outputs = torch.sqrt(outputs) 20 | 21 | return outputs 22 | 23 | class SpectralConvergence(nn.Module): 24 | def __init__(self): 25 | """Initilize spectral convergence loss module.""" 26 | super(SpectralConvergence, self).__init__() 27 | 28 | def forward(self, predicts_mag, targets_mag): 29 | x = torch.norm(targets_mag - predicts_mag, p='fro') 30 | y = torch.norm(targets_mag, p='fro') 31 | 32 | return x / y 33 | 34 | class LogSTFTMagnitude(nn.Module): 35 | def __init__(self): 36 | super(LogSTFTMagnitude, self).__init__() 37 | 38 | def forward(self, predicts_mag, targets_mag): 39 | log_predicts_mag = torch.log(predicts_mag) 40 | log_targets_mag = torch.log(targets_mag) 41 | outputs = F.l1_loss(log_predicts_mag, log_targets_mag) 42 | 43 | return outputs 44 | 45 | class STFTLoss(nn.Module): 46 | def __init__(self, 47 | fft_size=1024, 48 | hop_size=120, 49 | win_size=600): 50 | super(STFTLoss, self).__init__() 51 | 52 | self.fft_size = fft_size 53 | self.hop_size = hop_size 54 | self.win_size = win_size 55 | self.window = torch.hann_window(win_size) 56 | self.sc_loss = SpectralConvergence() 57 | self.mag = LogSTFTMagnitude() 58 | 59 | 60 | def forward(self, predicts, targets): 61 | """ 62 | Args: 63 | x: predicted signal (B, T). 64 | y: truth signal (B, T). 65 | 66 | Returns: 67 | Tensor: STFT loss values. 68 | """ 69 | predicts_mag = stft(predicts, self.fft_size, self.hop_size, self.win_size, self.window) 70 | targets_mag = stft(targets, self.fft_size, self.hop_size, self.win_size, self.window) 71 | 72 | sc_loss = self.sc_loss(predicts_mag, targets_mag) 73 | mag_loss = self.mag(predicts_mag, targets_mag) 74 | 75 | return sc_loss, mag_loss 76 | 77 | class MultiResolutionSTFTLoss(nn.Module): 78 | def __init__(self, 79 | fft_sizes=[1024, 2048, 512], 80 | win_sizes=[600, 1200, 240], 81 | hop_sizes=[120, 240, 50]): 82 | super(MultiResolutionSTFTLoss, self).__init__() 83 | self.loss_layers = torch.nn.ModuleList() 84 | for (fft_size, win_size, hop_size) in zip(fft_sizes, win_sizes, hop_sizes): 85 | self.loss_layers.append(STFTLoss(fft_size, hop_size, win_size)) 86 | 87 | def forward(self, fake_signals, true_signals): 88 | sc_losses, mag_losses = [], [] 89 | for layer in self.loss_layers: 90 | sc_loss, mag_loss = layer(fake_signals, true_signals) 91 | sc_losses.append(sc_loss) 92 | mag_losses.append(mag_loss) 93 | 94 | sc_loss = sum(sc_losses) / len(sc_losses) 95 | mag_loss = sum(mag_losses) / len(mag_losses) 96 | 97 | return sc_loss, mag_loss 98 | 99 | if __name__ == "__main__": 100 | model = MultiResolutionSTFTLoss() 101 | x = torch.randn(2, 16000) 102 | y = torch.randn(2, 16000) 103 | 104 | loss = model(x, y) 105 | print(loss) 106 | -------------------------------------------------------------------------------- /models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import spectral_norm 5 | from modules import Conv1d, Linear 6 | 7 | class Encoder(nn.Module): 8 | def __init__(self, 9 | encoder_dim, 10 | spk_embedding_dim, 11 | z_dim): 12 | super(Encoder, self).__init__() 13 | 14 | self.encoder_dim = encoder_dim 15 | self.spk_embedding_dim = spk_embedding_dim 16 | self.z_dim = z_dim 17 | 18 | self.aligner = Aligner(encoder_dim, z_dim, 
spk_embedding_dim) 19 | 20 | def forward(self, encoder_inputs, z, speaker_inputs): 21 | encoder_outputs, duration = self.aligner(encoder_inputs, z, speaker_inputs) 22 | 23 | return encoder_outputs, duration 24 | 25 | class Aligner(nn.Module): 26 | def __init__(self, 27 | in_channels, 28 | z_channels, 29 | s_channels, 30 | num_dilation_layer=10): 31 | super(Aligner, self).__init__() 32 | 33 | self.in_channels = in_channels 34 | self.z_channels = z_channels 35 | self.s_channels = s_channels 36 | 37 | self.pre_process = Conv1d(in_channels, 256, kernel_size=3) 38 | 39 | self.dilated_conv_layers = nn.ModuleList() 40 | for i in range(num_dilation_layer): 41 | dilation = 2**i 42 | self.dilated_conv_layers.append(DilatedConvBlock(256, 256, 43 | z_channels, s_channels, dilation)) 44 | 45 | self.post_process = nn.Sequential( 46 | Linear(256, 256), 47 | nn.ReLU(inplace=False), 48 | Linear(256, 1), 49 | nn.ReLU(inplace=False), 50 | ) 51 | 52 | def forward(self, inputs, z, s): 53 | outputs = self.pre_process(inputs) 54 | for layer in self.dilated_conv_layers: 55 | outputs = layer(outputs, z, s) 56 | 57 | encoder_outputs = outputs.transpose(1, 2) 58 | duration = self.post_process(outputs.transpose(1, 2)) 59 | 60 | return encoder_outputs, duration.squeeze(-1) 61 | 62 | class DilatedConvBlock(nn.Module): 63 | 64 | """A stack of dilated convolutions interspersed 65 | with batch normalisation and ReLU activations """ 66 | 67 | def __init__(self, 68 | in_channels, 69 | out_channels, 70 | z_channels, 71 | s_channels, 72 | dilation): 73 | super(DilatedConvBlock, self).__init__() 74 | 75 | self.in_channels = in_channels 76 | self.out_channels = out_channels 77 | self.z_channels = z_channels 78 | self.s_channels = s_channels 79 | 80 | self.conv1d = Conv1d(in_channels, out_channels, kernel_size=3, dilation=dilation) 81 | self.batch_layer = BatchNorm1dLayer(out_channels, s_channels, z_channels) 82 | 83 | def forward(self, inputs, z, s): 84 | outputs = self.conv1d(inputs) 85 | outputs = self.batch_layer(outputs, z, s) 86 | return F.relu(outputs) 87 | 88 | class BatchNorm1dLayer(nn.Module): 89 | 90 | """The latents z and speaker embedding s modulate the scale and 91 | shift parameters of the batch normalisation layers""" 92 | 93 | def __init__(self, 94 | num_features, 95 | s_channels=128, 96 | z_channels=128): 97 | super().__init__() 98 | 99 | self.num_features = num_features 100 | self.s_channels = s_channels 101 | self.z_channels = z_channels 102 | self.batch_nrom = nn.BatchNorm1d(num_features, affine=False) 103 | 104 | self.scale_layer = spectral_norm(nn.Linear(z_channels, num_features)) 105 | self.scale_layer.weight.data.normal_(1, 0.02) # Initialise scale at N(1, 0.02) 106 | self.scale_layer.bias.data.zero_() # Initialise bias at 0 107 | 108 | self.shift_layer = spectral_norm(nn.Linear(s_channels, num_features)) 109 | self.shift_layer.weight.data.normal_(1, 0.02) # Initialise scale at N(1, 0.02) 110 | self.shift_layer.bias.data.zero_() # Initialise bias at 0 111 | 112 | def forward(self, inputs, z, s): 113 | outputs = self.batch_nrom(inputs) 114 | scale = self.scale_layer(z) 115 | scale = scale.view(-1, self.num_features, 1) 116 | 117 | shift = self.shift_layer(s) 118 | shift = shift.view(-1, self.num_features, 1) 119 | 120 | outputs = scale * outputs + shift 121 | 122 | return outputs 123 | 124 | if __name__ == "__main__": 125 | model = Encoder(256, 64, 64) 126 | encoder_inputs = torch.randn(2, 256, 10) 127 | z = torch.randn(2, 64) 128 | speaker = torch.randn(1, 64) 129 | outputs, duration = 
model(encoder_inputs, z, speaker) 130 | print(outputs.shape, duration.shape) -------------------------------------------------------------------------------- /models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.utils import spectral_norm 3 | from modules import Conv1d 4 | 5 | class Decoder(nn.Module): 6 | def __init__(self, 7 | in_channels=567, 8 | z_channels=128): 9 | super(Decoder, self).__init__() 10 | 11 | self.in_channels = in_channels 12 | self.z_channels = z_channels 13 | 14 | self.preprocess = Conv1d(in_channels, 768, kernel_size=3) 15 | self.gblocks = nn.ModuleList ([ 16 | GBlock(768, 768, z_channels, 1), 17 | GBlock(768, 768, z_channels, 1), 18 | GBlock(768, 384, z_channels, 2), 19 | GBlock(384, 384, z_channels, 2), 20 | GBlock(384, 384, z_channels, 2), 21 | GBlock(384, 192, z_channels, 3), 22 | GBlock(192, 96, z_channels, 5) 23 | ]) 24 | self.postprocess = nn.Sequential( 25 | Conv1d(96, 1, kernel_size=3), 26 | nn.Tanh() 27 | ) 28 | 29 | def forward(self, inputs, z): 30 | inputs = self.preprocess(inputs) 31 | outputs = inputs 32 | for (_, layer) in enumerate(self.gblocks): 33 | outputs = layer(outputs, z) 34 | outputs = self.postprocess(outputs) 35 | 36 | return outputs 37 | 38 | class GBlock(nn.Module): 39 | def __init__(self, 40 | in_channels, 41 | hidden_channels, 42 | z_channels, 43 | upsample_factor): 44 | super(GBlock, self).__init__() 45 | 46 | self.in_channels = in_channels 47 | self.hidden_channels = hidden_channels 48 | self.z_channels = z_channels 49 | self.upsample_factor = upsample_factor 50 | 51 | self.condition_batchnorm1 = ConditionalBatchNorm1d(in_channels, z_channels) 52 | self.first_stack = nn.Sequential( 53 | nn.ReLU(inplace=False), 54 | UpsampleNet(in_channels, in_channels, upsample_factor), 55 | Conv1d(in_channels, hidden_channels, kernel_size=3) 56 | ) 57 | 58 | self.condition_batchnorm2 = ConditionalBatchNorm1d(hidden_channels, z_channels) 59 | self.second_stack = nn.Sequential( 60 | nn.ReLU(inplace=False), 61 | Conv1d(hidden_channels, hidden_channels, kernel_size=3, dilation=2) 62 | ) 63 | 64 | self.residual1 = nn.Sequential( 65 | UpsampleNet(in_channels, in_channels, upsample_factor), 66 | Conv1d(in_channels, hidden_channels, kernel_size=1) 67 | ) 68 | 69 | self.condition_batchnorm3 = ConditionalBatchNorm1d(hidden_channels, z_channels) 70 | self.third_stack = nn.Sequential( 71 | nn.ReLU(inplace=False), 72 | Conv1d(hidden_channels, hidden_channels, kernel_size=3, dilation=4) 73 | ) 74 | 75 | self.condition_batchnorm4 = ConditionalBatchNorm1d(hidden_channels, z_channels) 76 | self.fourth_stack = nn.Sequential( 77 | nn.ReLU(inplace=False), 78 | Conv1d(hidden_channels, hidden_channels, kernel_size=3, dilation=8) 79 | ) 80 | 81 | def forward(self, condition, z): 82 | inputs = condition 83 | 84 | outputs = self.condition_batchnorm1(inputs, z) 85 | outputs = self.first_stack(outputs) 86 | outputs = self.condition_batchnorm2(outputs, z) 87 | outputs = self.second_stack(outputs) 88 | 89 | residual_outputs = self.residual1(inputs) + outputs 90 | 91 | outputs = self.condition_batchnorm3(residual_outputs, z) 92 | outputs = self.third_stack(outputs) 93 | outputs = self.condition_batchnorm4(outputs, z) 94 | outputs = self.fourth_stack(outputs) 95 | 96 | outputs = outputs + residual_outputs 97 | 98 | return outputs 99 | 100 | class UpsampleNet(nn.Module): 101 | def __init__(self, 102 | input_size, 103 | output_size, 104 | upsample_factor): 105 | 106 | super(UpsampleNet, 
self).__init__() 107 | self.input_size = input_size 108 | self.output_size = output_size 109 | self.upsample_factor = upsample_factor 110 | 111 | layer = nn.ConvTranspose1d(input_size, output_size, upsample_factor * 2, 112 | upsample_factor, padding=upsample_factor // 2) 113 | nn.init.orthogonal_(layer.weight) 114 | self.layer = spectral_norm(layer) 115 | 116 | def forward(self, inputs): 117 | outputs = self.layer(inputs) 118 | outputs = outputs[:, :, : inputs.size(-1) * self.upsample_factor] 119 | return outputs 120 | 121 | class ConditionalBatchNorm1d(nn.Module): 122 | 123 | """Conditional Batch Normalization""" 124 | 125 | def __init__(self, num_features, z_channels=128): 126 | super().__init__() 127 | 128 | self.num_features = num_features 129 | self.z_channels = z_channels 130 | self.batch_nrom = nn.BatchNorm1d(num_features, affine=False) 131 | 132 | self.layer = spectral_norm(nn.Linear(z_channels, num_features * 2)) 133 | self.layer.weight.data.normal_(1, 0.02) # Initialise scale at N(1, 0.02) 134 | self.layer.bias.data.zero_() # Initialise bias at 0 135 | 136 | def forward(self, inputs, noise): 137 | outputs = self.batch_nrom(inputs) 138 | gamma, beta = self.layer(noise).chunk(2, 1) 139 | gamma = gamma.view(-1, self.num_features, 1) 140 | beta = beta.view(-1, self.num_features, 1) 141 | 142 | outputs = gamma * outputs + beta 143 | 144 | return outputs 145 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from modules import Conv1d 4 | import numpy as np 5 | 6 | class Multiple_Random_Window_Discriminators(nn.Module): 7 | def __init__(self, 8 | lc_channels, 9 | window_size=(2, 4, 8, 16, 30), 10 | upsample_factor=120): 11 | 12 | super(Multiple_Random_Window_Discriminators, self).__init__() 13 | 14 | self.lc_channels = lc_channels 15 | self.window_size = window_size 16 | self.upsample_factor = upsample_factor 17 | 18 | self.udiscriminators = nn.ModuleList([ 19 | UnConditionalDBlocks(in_channels=1, factors=(5, 3), out_channels=(128, 256)), 20 | UnConditionalDBlocks(in_channels=2, factors=(5, 3), out_channels=(128, 256)), 21 | UnConditionalDBlocks(in_channels=4, factors=(5, 3), out_channels=(128, 256)), 22 | UnConditionalDBlocks(in_channels=8, factors=(5, 3), out_channels=(128, 256)), 23 | UnConditionalDBlocks(in_channels=15, factors=(2, 2), out_channels=(128, 256)), 24 | ]) 25 | 26 | def forward(self, samples): 27 | 28 | outputs = [] 29 | #unconditional discriminator 30 | for (size, layer) in zip(self.window_size, self.udiscriminators): 31 | size = size * self.upsample_factor 32 | index = np.random.randint(samples.size()[-1] - size) 33 | 34 | output = layer(samples[:, :, index : index + size]) 35 | outputs.append(output) 36 | 37 | return outputs 38 | 39 | class CondDBlock(nn.Module): 40 | def __init__(self, 41 | in_channels, 42 | lc_channels, 43 | downsample_factor): 44 | super(CondDBlock, self).__init__() 45 | 46 | self.in_channels = in_channels 47 | self.lc_channels = lc_channels 48 | self.downsample_factor = downsample_factor 49 | 50 | self.start = nn.Sequential( 51 | nn.AvgPool1d(downsample_factor, stride=downsample_factor), 52 | nn.ReLU(), 53 | Conv1d(in_channels, in_channels * 2, kernel_size=3) 54 | ) 55 | self.lc_conv1d = Conv1d(lc_channels, in_channels * 2, 1) 56 | self.end = nn.Sequential( 57 | nn.ReLU(), 58 | Conv1d(in_channels * 2, in_channels * 2, kernel_size=3, dilation=2) 59 | ) 60 | 
self.residual = nn.Sequential( 61 | Conv1d(in_channels, in_channels * 2, kernel_size=1), 62 | nn.AvgPool1d(downsample_factor, stride=downsample_factor) 63 | ) 64 | 65 | def forward(self, inputs, conditions): 66 | outputs = self.start(inputs) + self.lc_conv1d(conditions) 67 | outputs = self.end(outputs) 68 | residual_outputs = self.residual(inputs) 69 | outputs = outputs + residual_outputs 70 | 71 | return outputs 72 | 73 | class DBlock(nn.Module): 74 | def __init__(self, 75 | in_channels, 76 | out_channels, 77 | downsample_factor): 78 | super(DBlock, self).__init__() 79 | 80 | self.in_channels = in_channels 81 | self.out_channels = out_channels 82 | self.downsample_factor = downsample_factor 83 | 84 | self.layers = nn.Sequential( 85 | nn.AvgPool1d(downsample_factor, stride=downsample_factor), 86 | nn.ReLU(), 87 | Conv1d(in_channels, out_channels, kernel_size=3), 88 | nn.ReLU(), 89 | Conv1d(out_channels, out_channels, kernel_size=3, dilation=2) 90 | ) 91 | self.residual = nn.Sequential( 92 | Conv1d(in_channels, out_channels, kernel_size=1), 93 | nn.AvgPool1d(downsample_factor, stride=downsample_factor) 94 | ) 95 | 96 | def forward(self, inputs): 97 | outputs = self.layers(inputs) + self.residual(inputs) 98 | return outputs 99 | 100 | class ConditionalDBlocks(nn.Module): 101 | def __init__(self, 102 | in_channels, 103 | lc_channels, 104 | factors=(2, 2, 2), 105 | out_channels=(128, 256)): 106 | super(ConditionalDBlocks, self).__init__() 107 | 108 | assert len(factors) == len(out_channels) + 1 109 | 110 | self.in_channels = in_channels 111 | self.lc_channels = lc_channels 112 | self.factors = factors 113 | self.out_channels = out_channels 114 | 115 | self.layers = nn.ModuleList() 116 | self.layers.append(DBlock(in_channels, 64, 1)) 117 | in_channels = 64 118 | for (i, channel) in enumerate(out_channels): 119 | self.layers.append(DBlock(in_channels, channel, factors[i])) 120 | in_channels = channel 121 | 122 | self.cond_layer = CondDBlock(in_channels, lc_channels, factors[-1]) 123 | 124 | self.post_process = nn.ModuleList([ 125 | DBlock(in_channels * 2, in_channels * 2, 1), 126 | DBlock(in_channels * 2, in_channels * 2, 1) 127 | ]) 128 | 129 | def forward(self, inputs, conditions): 130 | batch_size = inputs.size()[0] 131 | outputs = inputs.view(batch_size, self.in_channels, -1) 132 | for layer in self.layers: 133 | outputs = layer(outputs) 134 | outputs = self.cond_layer(outputs, conditions) 135 | for layer in self.post_process: 136 | outputs = layer(outputs) 137 | 138 | return outputs 139 | 140 | class UnConditionalDBlocks(nn.Module): 141 | def __init__(self, 142 | in_channels, 143 | factors=(5, 3), 144 | out_channels=(128, 256)): 145 | super(UnConditionalDBlocks, self).__init__() 146 | 147 | self.in_channels = in_channels 148 | self.factors = factors 149 | self.out_channels = out_channels 150 | 151 | self.layers = nn.ModuleList() 152 | self.layers.append(DBlock(in_channels, 64, 1)) 153 | in_channels = 64 154 | for (i, factor) in enumerate(factors): 155 | self.layers.append(DBlock(in_channels, out_channels[i], factor)) 156 | in_channels = out_channels[i] 157 | self.layers.append(DBlock(in_channels, in_channels, 1)) 158 | self.layers.append(DBlock(in_channels, in_channels, 1)) 159 | 160 | def forward(self, inputs): 161 | batch_size = inputs.size()[0] 162 | outputs = inputs.view(batch_size, self.in_channels, -1) 163 | for layer in self.layers: 164 | outputs = layer(outputs) 165 | 166 | return outputs 167 | 168 | if __name__ == "__main__": 169 | model = Multiple_Random_Window_Discriminators(567) 
170 | 171 | x = torch.randn(2, 1, 24000) 172 | y = torch.randn(2, 1, 24000) 173 | real_outputs = model(x) 174 | for real in real_outputs: 175 | print(real.shape) 176 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.dataset import CustomerDataset, CustomerCollate 3 | from torch.utils.data import DataLoader 4 | import torch.nn.parallel.data_parallel as parallel 5 | import torch.optim as optim 6 | import torch.nn as nn 7 | import argparse 8 | import os 9 | import time 10 | from models.generator import Generator 11 | from models.discriminator import Multiple_Random_Window_Discriminators 12 | from models.v2_discriminator import Discriminator 13 | from tensorboardX import SummaryWriter 14 | from utils.optimizer import Optimizer 15 | from utils.audio import hop_length 16 | from utils.loss import MultiResolutionSTFTLoss 17 | 18 | def create_model(args): 19 | 20 | generator = Generator(args.local_condition_dim, args.z_dim) 21 | #discriminator = Multiple_Random_Window_Discriminators(args.local_condition_dim) 22 | discriminator = Discriminator() 23 | 24 | return generator, discriminator 25 | 26 | def save_checkpoint(args, generator, discriminator, 27 | g_optimizer, d_optimizer, step, ema=None): 28 | checkpoint_path = os.path.join(args.checkpoint_dir, "model.ckpt-{}.pt".format(step)) 29 | 30 | torch.save({"generator": generator.state_dict(), 31 | "discriminator": discriminator.state_dict(), 32 | "g_optimizer": g_optimizer.state_dict(), 33 | "d_optimizer": d_optimizer.state_dict(), 34 | "global_step": step 35 | }, checkpoint_path) 36 | 37 | print("Saved checkpoint: {}".format(checkpoint_path)) 38 | 39 | with open(os.path.join(args.checkpoint_dir, 'checkpoint'), 'w') as f: 40 | f.write("model.ckpt-{}.pt".format(step)) 41 | 42 | def attempt_to_restore(generator, discriminator, g_optimizer, 43 | d_optimizer, checkpoint_dir, use_cuda): 44 | checkpoint_list = os.path.join(checkpoint_dir, 'checkpoint') 45 | 46 | if os.path.exists(checkpoint_list): 47 | checkpoint_filename = open(checkpoint_list).readline().strip() 48 | checkpoint_path = os.path.join( 49 | checkpoint_dir, "{}".format(checkpoint_filename)) 50 | print("Restore from {}".format(checkpoint_path)) 51 | checkpoint = load_checkpoint(checkpoint_path, use_cuda) 52 | generator.load_state_dict(checkpoint["generator"]) 53 | g_optimizer.load_state_dict(checkpoint["g_optimizer"]) 54 | discriminator.load_state_dict(checkpoint["discriminator"]) 55 | d_optimizer.load_state_dict(checkpoint["d_optimizer"]) 56 | global_step = checkpoint["global_step"] 57 | 58 | else: 59 | global_step = 0 60 | 61 | return global_step 62 | 63 | def load_checkpoint(checkpoint_path, use_cuda): 64 | if use_cuda: 65 | checkpoint = torch.load(checkpoint_path) 66 | else: 67 | checkpoint = torch.load( 68 | checkpoint_path, map_location=lambda storage, loc: storage) 69 | 70 | return checkpoint 71 | 72 | def train(args): 73 | 74 | os.makedirs(args.checkpoint_dir, exist_ok=True) 75 | 76 | train_dataset = CustomerDataset( 77 | args.input, 78 | upsample_factor=hop_length, 79 | local_condition=True, 80 | global_condition=False) 81 | 82 | device = torch.device("cuda" if args.use_cuda else "cpu") 83 | generator, discriminator = create_model(args) 84 | 85 | print(generator) 86 | print(discriminator) 87 | 88 | num_gpu = torch.cuda.device_count() if args.use_cuda else 1 89 | 90 | global_step = 0 91 | 92 | g_parameters = 
list(generator.parameters()) 93 | g_optimizer = optim.Adam(g_parameters, lr=args.g_learning_rate) 94 | 95 | d_parameters = list(discriminator.parameters()) 96 | d_optimizer = optim.Adam(d_parameters, lr=args.d_learning_rate) 97 | 98 | writer = SummaryWriter(args.checkpoint_dir) 99 | 100 | generator.to(device) 101 | discriminator.to(device) 102 | 103 | if args.resume is not None: 104 | restore_step = attempt_to_restore(generator, discriminator, g_optimizer, 105 | d_optimizer, args.resume, args.use_cuda) 106 | global_step = restore_step 107 | 108 | customer_g_optimizer = Optimizer(g_optimizer, args.g_learning_rate, 109 | global_step, args.warmup_steps, args.decay_learning_rate) 110 | customer_d_optimizer = Optimizer(d_optimizer, args.d_learning_rate, 111 | global_step, args.warmup_steps, args.decay_learning_rate) 112 | 113 | stft_criterion = MultiResolutionSTFTLoss().to(device) 114 | criterion = nn.MSELoss().to(device) 115 | 116 | for epoch in range(args.epochs): 117 | 118 | collate = CustomerCollate( 119 | upsample_factor=hop_length, 120 | condition_window=args.condition_window, 121 | local_condition=True, 122 | global_condition=False) 123 | 124 | train_data_loader = DataLoader(train_dataset, collate_fn=collate, 125 | batch_size=args.batch_size, num_workers=args.num_workers, 126 | shuffle=True, pin_memory=True) 127 | 128 | #train one epoch 129 | for batch, (samples, conditions) in enumerate(train_data_loader): 130 | 131 | start = time.time() 132 | batch_size = int(conditions.shape[0] // num_gpu * num_gpu) 133 | 134 | samples = samples[:batch_size, :].to(device) 135 | conditions = conditions[:batch_size, :, :].to(device) 136 | z = torch.randn(batch_size, args.z_dim).to(device) 137 | 138 | losses = {} 139 | 140 | if num_gpu > 1: 141 | g_outputs = parallel(generator, (conditions, z)) 142 | else: 143 | g_outputs = generator(conditions, z) 144 | 145 | #train discriminator 146 | if global_step > args.discriminator_train_start_steps: 147 | if num_gpu > 1: 148 | real_output = parallel(discriminator, (samples, )) 149 | fake_output = parallel(discriminator, (g_outputs.detach(), )) 150 | else: 151 | real_output = discriminator(samples, ) 152 | fake_output = discriminator(g_outputs.detach(), ) 153 | 154 | fake_loss = criterion(fake_output, torch.zeros_like(fake_output)) 155 | real_loss = criterion(real_output, torch.ones_like(real_output)) 156 | 157 | d_loss = fake_loss + real_loss 158 | 159 | customer_d_optimizer.zero_grad() 160 | d_loss.backward() 161 | nn.utils.clip_grad_norm_(d_parameters, max_norm=0.5) 162 | customer_d_optimizer.step_and_update_lr() 163 | else: 164 | d_loss = torch.Tensor([0]) 165 | fake_loss = torch.Tensor([0]) 166 | real_loss = torch.Tensor([0]) 167 | 168 | losses['fake_loss'] = fake_loss.item() 169 | losses['real_loss'] = real_loss.item() 170 | losses['d_loss'] = d_loss.item() 171 | 172 | #train generator 173 | if num_gpu > 1: 174 | fake_output = parallel(discriminator, (g_outputs, )) 175 | else: 176 | fake_output = discriminator(g_outputs) 177 | 178 | adv_loss = criterion(fake_output, torch.ones_like(fake_output)) 179 | 180 | sc_loss, mag_loss = stft_criterion(g_outputs.squeeze(1), samples.squeeze(1)) 181 | 182 | if global_step > args.discriminator_train_start_steps: 183 | g_loss = adv_loss * args.lamda_adv + sc_loss + mag_loss 184 | else: 185 | g_loss = sc_loss + mag_loss 186 | 187 | losses['adv_loss'] = adv_loss.item() 188 | losses['sc_loss'] = sc_loss 189 | losses['mag_loss'] = mag_loss 190 | losses['g_loss'] = g_loss.item() 191 | 192 | customer_g_optimizer.zero_grad() 193 
| g_loss.backward() 194 | nn.utils.clip_grad_norm_(g_parameters, max_norm=0.5) 195 | customer_g_optimizer.step_and_update_lr() 196 | 197 | time_used = time.time() - start 198 | if global_step > args.discriminator_train_start_steps: 199 | print("Step: {} --adv_loss: {:.3f} --real_loss: {:.3f} --fake_loss: {:.3f} --sc_loss: {:.3f} --mag_loss: {:.3f} --Time: {:.2f} seconds".format( 200 | global_step, adv_loss, real_loss, fake_loss, sc_loss, mag_loss, time_used)) 201 | else: 202 | print("Step: {} --sc_loss: {:.3f} --mag_loss: {:.3f} --Time: {:.2f} seconds".format(global_step, sc_loss, mag_loss, time_used)) 203 | 204 | global_step += 1 205 | 206 | if global_step % args.checkpoint_step == 0: 207 | save_checkpoint(args, generator, discriminator, 208 | g_optimizer, d_optimizer, global_step) 209 | 210 | if global_step % args.summary_step == 0: 211 | for key in losses: 212 | writer.add_scalar('{}'.format(key), losses[key], global_step) 213 | 214 | def main(): 215 | 216 | def _str_to_bool(s): 217 | """Convert string to bool (in argparse context).""" 218 | if s.lower() not in ['true', 'false']: 219 | raise ValueError('Argument needs to be a ' 220 | 'boolean, got {}'.format(s)) 221 | return {'true': True, 'false': False}[s.lower()] 222 | 223 | parser = argparse.ArgumentParser() 224 | parser.add_argument('--input', type=str, default='data/train', help='Directory of training data') 225 | parser.add_argument('--num_workers',type=int, default=4, help='Number of dataloader workers.') 226 | parser.add_argument('--epochs', type=int, default=50000) 227 | parser.add_argument('--checkpoint_dir', type=str, default="logdir", help="Directory to save model") 228 | parser.add_argument('--resume', type=str, default=None, help="The model name to restore") 229 | parser.add_argument('--checkpoint_step', type=int, default=5000) 230 | parser.add_argument('--summary_step', type=int, default=100) 231 | parser.add_argument('--use_cuda', type=_str_to_bool, default=True) 232 | parser.add_argument('--g_learning_rate', type=float, default=0.0001) 233 | parser.add_argument('--d_learning_rate', type=float, default=0.0001) 234 | parser.add_argument('--warmup_steps', type=int, default=200000) 235 | parser.add_argument('--decay_learning_rate', type=float, default=0.5) 236 | parser.add_argument('--local_condition_dim', type=int, default=80) 237 | parser.add_argument('--z_dim', type=int, default=128) 238 | parser.add_argument('--batch_size', type=int, default=30) 239 | parser.add_argument('--condition_window', type=int, default=100) 240 | parser.add_argument('--lamda_adv', type=float, default=4.0) 241 | parser.add_argument('--discriminator_train_start_steps', type=int, default=100000) 242 | 243 | args = parser.parse_args() 244 | train(args) 245 | 246 | if __name__ == "__main__": 247 | main() 248 | --------------------------------------------------------------------------------
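
A minimal sketch of the auxiliary duration loss suggested in the README, assuming per-token ground-truth durations (e.g. from a forced aligner) are available for each utterance. Only `predict_duration` (returned by `EETS.forward` in models/model.py) and `lamda_adv` come from the code above; `duration_loss`, `target_duration`, `gt_duration` and `lamda_dur` are hypothetical names used here for illustration:

import torch.nn.functional as F

def duration_loss(predict_duration, target_duration, mask=None):
    # MSE between predicted and ground-truth per-token durations (in frames).
    # Predicting log-durations is also common; plain MSE is kept for brevity.
    loss = F.mse_loss(predict_duration, target_duration, reduction='none')
    if mask is not None:
        # Ignore padded token positions and average over the real ones.
        loss = loss * mask
        return loss.sum() / mask.sum()
    return loss.mean()

# During training, the ground-truth durations could also be fed to ExpandFrame
# (teacher forcing) while the auxiliary term is added to the generator loss, e.g.:
#   outputs, predict_duration = model(phone, tone, prosody, segment, noise, speakers, gt_duration)
#   g_loss = adv_loss * lamda_adv + sc_loss + mag_loss + lamda_dur * duration_loss(predict_duration, gt_duration)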