├── README.md ├── audio_processing.py ├── config.json ├── extract_emb.py ├── filelists └── example_filelist.txt ├── hparams.py ├── inference.py ├── inference_textlist.txt ├── layers.py ├── mean_i2i.py ├── model_hifigan.py ├── modules ├── __pycache__ │ ├── SubLayers.cpython-36.pyc │ ├── SubLayers.cpython-38.pyc │ ├── align_loss.cpython-36.pyc │ ├── align_loss.cpython-38.pyc │ ├── aligner.cpython-36.pyc │ ├── aligner.cpython-38.pyc │ ├── attention.cpython-36.pyc │ ├── attention.cpython-38.pyc │ ├── attn_loss_function.cpython-36.pyc │ ├── attn_loss_function.cpython-38.pyc │ ├── commons.cpython-36.pyc │ ├── flow.cpython-36.pyc │ ├── init_layer.cpython-36.pyc │ ├── init_layer.cpython-37.pyc │ ├── init_layer.cpython-38.pyc │ ├── loss.cpython-36.pyc │ ├── loss.cpython-37.pyc │ ├── loss.cpython-38.pyc │ ├── model.cpython-36.pyc │ ├── model.cpython-37.pyc │ ├── model.cpython-38.pyc │ ├── modules.cpython-36.pyc │ ├── saln.cpython-36.pyc │ ├── saln.cpython-38.pyc │ ├── style.cpython-36.pyc │ ├── style.cpython-38.pyc │ ├── transformer.cpython-36.pyc │ ├── transformer.cpython-37.pyc │ ├── transformer.cpython-38.pyc │ ├── transforms.cpython-36.pyc │ └── vae.cpython-36.pyc ├── attention.py ├── attn_loss_function.py ├── init_layer.py ├── loss.py ├── model.py ├── saln.py ├── style.py └── transformer.py ├── preprocess.py ├── requirements.txt ├── stft.py ├── text ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── cleaners.cpython-35.pyc │ ├── cleaners.cpython-36.pyc │ ├── cleaners.cpython-37.pyc │ ├── cleaners.cpython-38.pyc │ ├── cmudict.cpython-35.pyc │ ├── cmudict.cpython-36.pyc │ ├── cmudict.cpython-37.pyc │ ├── numbers.cpython-35.pyc │ ├── numbers.cpython-36.pyc │ ├── numbers.cpython-37.pyc │ ├── symbols.cpython-35.pyc │ ├── symbols.cpython-36.pyc │ ├── symbols.cpython-37.pyc │ └── symbols.cpython-38.pyc ├── cleaners.py ├── english_utils │ ├── cleaners.py │ ├── cmudict.py │ ├── numbers.py │ └── symbols.py └── symbols.py ├── train.py └── utils ├── __pycache__ ├── data_utils.cpython-36.pyc ├── data_utils.cpython-37.pyc ├── data_utils.cpython-38.pyc ├── plot_image.cpython-36.pyc ├── plot_image.cpython-37.pyc ├── plot_image.cpython-38.pyc ├── test_utils.cpython-36.pyc ├── utils.cpython-36.pyc ├── utils.cpython-37.pyc ├── utils.cpython-38.pyc ├── writer.cpython-36.pyc ├── writer.cpython-37.pyc └── writer.cpython-38.pyc ├── data_utils.py ├── plot_image.py ├── utils.py └── writer.py /README.md: -------------------------------------------------------------------------------- 1 | # FluentTTS: Text-dependent Fine-grained Style Control for Multi-style TTS 2 | 3 | Official PyTorch Implementation of [FluentTTS: Text-dependent Fine-grained Style Control for Multi-style TTS](https://www.isca-speech.org/archive/pdfs/interspeech_2022/kim22j_interspeech.pdf). 4 | Codes are based on the [Acknowledgements](https://github.com/monglechap/fluenttts#acknowledgements) below. 5 | 6 | **Abstract**: In this paper, we propose a method to flexibly control the local prosodic variation of a neural text-to-speech (TTS) model. To provide expressiveness for synthesized speech, conventional TTS models utilize utterance-wise global style embeddings that are obtained by compressing frame-level embeddings along the time axis. 
However, since utterance-wise global features do not contain sufficient information to represent the characteristics of word-level local features, they are not appropriate for directly controlling prosody at a fine scale.
7 | In multi-style TTS models, it is very important to have the capability to control local prosody, because it plays a key role in finding the most appropriate text-to-speech pair among the many one-to-many mapping candidates.
8 | To explicitly present local prosodic characteristics to the contextual information of the corresponding input text, we propose a module to predict the fundamental frequency ( $F0$ ) of each text by conditioning on the utterance-wise global style embedding.
9 | We also estimate multi-style embeddings using a multi-style encoder, which takes as inputs both a global utterance-wise embedding and a local $F0$ embedding.
10 | Our multi-style embedding enhances the naturalness and expressiveness of synthesized speech and is able to control prosody styles at the word level or phoneme level.
11 |
12 | Visit our [Demo](https://kchap0118.github.io/fluenttts/) for audio samples.
13 |
14 | ## Prerequisites
15 |
16 | - Clone this repository
17 | - Install Python requirements. Please refer to [requirements.txt](requirements.txt)
18 | - As in [Code reference](https://github.com/Deepest-Project/Transformer-TTS), please modify the return values of _torch.nn.functional.multi_head_attention_forward()_ so that the attention maps of all heads can be drawn in the validation step.
19 | ```
20 | # Before
21 | return attn_output, attn_output_weights.sum(dim=1) / num_heads
22 | # After
23 | return attn_output, attn_output_weights
24 | ```
25 |
26 | ## Preprocessing
27 |
28 | 1. Prepare text preprocessing
29 |
30 | 1-1. Our code was developed for an internal Korean dataset. If you run the code with other languages, please modify the files in [text](text/) and [hparams.py](hparams.py) that are related to symbols and text preprocessing.
31 |
32 | 1-2. Make data filelists following the format of [filelists/example_filelist.txt](filelists/example_filelist.txt). They are used for preprocessing and training.
33 |
34 | ```
35 | /your/data/path/angry_f_1234.wav|your_data_text|speaker_type
36 | /your/data/path/happy_m_5678.wav|your_data_text|speaker_type
37 | /your/data/path/sadness_f_111.wav|your_data_text|speaker_type
38 | ...
39 | ```
40 |
41 | 1-3. To find the number of speakers and emotions and to define the file names used for saving, we rely on the format of [filelists/example_filelist.txt](filelists/example_filelist.txt) (see the parsing sketch after step 2-2 below). Therefore, please modify the data-specific parts (annotated) in [utils/data_utils.py](utils/data_utils.py), [extract_emb.py](extract_emb.py), [mean_i2i.py](mean_i2i.py) and [inference.py](inference.py).
42 |
43 | 1-4. As in 1-3, the emotion classification loss is implemented based on this data format. You can use a standard classification loss such as _nn.CrossEntropyLoss()_ instead.
44 | 2. Preprocessing
45 |
46 | 2-1. Before running [preprocess.py](preprocess.py), modify `path` (data path) and `file_path` (the filelist you made in _1-2_) in lines [21](https://github.com/monglechap/fluenttts/blob/main/preprocess.py#L21) and [25](https://github.com/monglechap/fluenttts/blob/main/preprocess.py#L25).
47 |
48 | 2-2. Run
49 |
50 | ```
51 | python preprocess.py
52 | ```
53 |
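The sketch below is not part of the repository; it only illustrates how a filelist line and its wav file name are interpreted under the naming convention of steps 1-2 and 1-3. The helper `parse_filelist_line` and the speaker-resolution rule are illustrative assumptions modeled on the annotated logic in [extract_emb.py](extract_emb.py); adapt them to your own data structure.

```
import os

def parse_filelist_line(line):
    """Split '/path/emotion_gender_idx.wav|text|speaker_type' into its fields."""
    wav_path, text, speaker_type = line.strip().split('|')
    # File names follow 'emotion_gender_index.wav', e.g. 'angry_f_1234.wav'
    emo, gender, idx = os.path.splitext(os.path.basename(wav_path))[0].split('_')
    # Illustrative speaker resolution (internal dataset): 4-digit indices belong
    # to the second speaker of each gender (f2/m2), shorter ones to f1/m1.
    spk = gender + ('2' if len(idx) == 4 else '1')
    return wav_path, text, speaker_type, emo, spk, idx

print(parse_filelist_line('/your/data/path/angry_f_1234.wav|your_data_text|f2'))
# ('/your/data/path/angry_f_1234.wav', 'your_data_text', 'f2', 'angry', 'f2', '1234')
```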
54 | 2-3. Modify the data path and the train/validation filelist paths in [hparams.py](hparams.py).
55 |
56 | ## Training
57 |
58 | ```
59 | python train.py -o [SAVE DIRECTORY PATH] -m [BASE OR PROP]
60 | ```
61 |
62 | (Arguments)
63 |
64 | ```
65 | -c: Ckpt path for loading
66 | -o: Path for saving ckpt and log
67 | -m: Choose baseline or proposed model
68 | ```
69 |
70 | ## Inference
71 |
72 | 0. Mean (i2i) style embedding extraction (optional)
73 |
74 | 0-1. Extract emotion embeddings of the dataset
75 |
76 | ```
77 | python extract_emb.py -o [SAVE DIRECTORY PATH] -c [CHECKPOINT PATH] -m [BASE OR PROP]
78 | ```
79 |
80 | (Arguments)
81 |
82 | ```
83 | -o: Path for saving emotion embs
84 | -c: Ckpt path for loading
85 | -m: Choose baseline or proposed model
86 | ```
87 |
88 | 0-2. Compute mean (or I2I) embs
89 |
90 | ```
91 | python mean_i2i.py -i [EXTRACTED EMB PATH] -o [SAVE DIRECTORY PATH] -m [NEU OR ALL]
92 | ```
93 |
94 | (Arguments)
95 |
96 | ```
97 | -i: Path of saved emotion embs
98 | -o: Path for saving mean or i2i embs
99 | -m: Set the farthest emotion as neutral only, or consider all other emotions (explained in mean_i2i.py)
100 | ```
101 | 1. Inference
102 |
103 | ```
104 | python inference.py -c [CHECKPOINT PATH] -v [VOCODER PATH] -s [MEAN EMB PATH] -o [SAVE DIRECTORY PATH] -m [BASE OR PROP]
105 | ```
106 |
107 | (Arguments)
108 |
109 | ```
110 | -c: Ckpt path of acoustic model
111 | -v: Ckpt path of vocoder (HiFi-GAN)
112 | -s (optional): Path of saved mean (i2i) embs
113 | -o: Path for saving generated wavs
114 | -m: Choose baseline or proposed model
115 | --control (optional): F0 control at the utterance or phoneme level
116 | --hz (optional): Value used to modify F0
117 | --ref_dir (optional): Path of reference wavs. Use when you do not apply the mean (i2i) algorithms.
118 | --spk (optional): Use with --ref_dir
119 | --emo (optional): Use with --ref_dir
120 | ```
121 |
122 | ## Acknowledgements
123 |
124 | We referred to the following repositories for the official implementation.
125 |
126 | 1. NVIDIA/tacotron2: [Link](https://github.com/NVIDIA/tacotron2)
127 | 2. Deepest-Project/Transformer-TTS: [Link](https://github.com/Deepest-Project/Transformer-TTS)
128 | 3. NVIDIA/FastPitch: [Link](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/FastPitch)
129 | 4. KevinMIN95/StyleSpeech: [Link](https://github.com/KevinMIN95/StyleSpeech)
130 | 5. Kyubyong/g2pK: [Link](https://github.com/Kyubyong/g2pK)
131 | 6. jik876/hifi-gan: [Link](https://github.com/jik876/hifi-gan)
132 | 7. KinglittleQ/GST-Tacotron: [Link](https://github.com/KinglittleQ/GST-Tacotron)
133 |
134 | ## Citation
135 |
136 | ```
137 | @article{kim2022fluenttts,
138 |   title={FluentTTS: Text-dependent Fine-grained Style Control for Multi-style TTS},
139 |   author={Kim, Changhwan and Um, Se-yun and Yoon, Hyungchan and Kang, Hong-Goo},
140 |   journal={Proc. Interspeech 2022},
141 |   pages={4561--4565},
142 |   year={2022}
143 | }
144 | ```
145 |
--------------------------------------------------------------------------------
/audio_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.signal import get_window
4 | import librosa.util as librosa_util
5 |
6 |
7 | def window_sumsquare(window,
8 |                      n_frames,
9 |                      hop_length=200,
10 |                      win_length=800,
11 |                      n_fft=800,
12 |                      dtype=np.float32,
13 |                      norm=None):
14 |     """
15 |     # from librosa 0.6
16 |     Compute the sum-square envelope of a window function at a given hop length.
17 | 18 | This is used to estimate modulation effects induced by windowing 19 | observations in short-time fourier transforms. 20 | 21 | Parameters 22 | ---------- 23 | window : string, tuple, number, callable, or list-like 24 | Window specification, as in `get_window` 25 | 26 | n_frames : int > 0 27 | The number of analysis frames 28 | 29 | hop_length : int > 0 30 | The number of samples to advance between frames 31 | 32 | win_length : [optional] 33 | The length of the window function. By default, this matches `n_fft`. 34 | 35 | n_fft : int > 0 36 | The length of each analysis frame. 37 | 38 | dtype : np.dtype 39 | The data type of the output 40 | 41 | Returns 42 | ------- 43 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 44 | The sum-squared envelope of the window function 45 | """ 46 | if win_length is None: 47 | win_length = n_fft 48 | 49 | n = n_fft + hop_length * (n_frames - 1) 50 | x = np.zeros(n, dtype=dtype) 51 | 52 | # Compute the squared window at the desired length 53 | win_sq = get_window(window, win_length, fftbins=True) 54 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 55 | win_sq = librosa_util.pad_center(win_sq, n_fft) 56 | 57 | # Fill the envelope 58 | for i in range(n_frames): 59 | sample = i * hop_length 60 | x[sample:min(n, sample+n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 61 | return x 62 | 63 | 64 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 65 | """ 66 | PARAMS 67 | ------ 68 | magnitudes: spectrogram magnitudes 69 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 70 | """ 71 | 72 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 73 | angles = angles.astype(np.float32) 74 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 75 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 76 | 77 | for i in range(n_iters): 78 | _, angles = stft_fn.transform(signal) 79 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 80 | return signal 81 | 82 | 83 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 84 | """ 85 | PARAMS 86 | ------ 87 | C: compression factor 88 | """ 89 | return torch.log(torch.clamp(x, min=clip_val) * C) 90 | 91 | 92 | def dynamic_range_decompression(x, C=1): 93 | """ 94 | PARAMS 95 | ------ 96 | C: compression factor used to compress 97 | """ 98 | return torch.exp(x) / C -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0002, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | 11 | "upsample_rates": [5,5,4,2], 12 | "upsample_kernel_sizes": [9,9,8,4], 13 | "upsample_initial_channel": 512, 14 | "resblock_kernel_sizes": [3,7,11], 15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 16 | 17 | "segment_size": 6400, 18 | "num_mels": 80, 19 | "num_freq": 1025, 20 | "n_fft": 1024, 21 | "hop_size": 200, 22 | "win_size": 800, 23 | 24 | "sampling_rate": 16000, 25 | 26 | "fmin": 50, 27 | "fmax": 7200, 28 | "fmax_for_loss": null, 29 | 30 | "num_workers": 4, 31 | 32 | "dist_config": { 33 | "dist_backend": "nccl", 34 | "dist_url": "tcp://localhost:54321", 35 | "world_size": 1 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /extract_emb.py: -------------------------------------------------------------------------------- 1 | import os, argparse, torch, 
pdb, sys 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from tqdm import tqdm 6 | 7 | import hparams 8 | from modules.model import FluentTTS 9 | from utils.utils import * 10 | 11 | 12 | def main(args): 13 | """Extract emotion embeddings of all dataset for calculating mean of emotion embeddings""" 14 | if not os.path.isdir(args.emb_dir): 15 | os.mkdir(args.emb_dir) 16 | mode = args.mode 17 | 18 | # Prepare valid or test dataset 19 | _, val_loader, collate_fn = prepare_dataloaders(hparams) 20 | 21 | # Load acoustic model 22 | model = FluentTTS(hparams, mode).cuda() 23 | optimizer = torch.optim.Adam(model.parameters(), lr=hparams.lr, betas=(0.9, 0.98), eps=1e-09) 24 | model, _, _, _, _ = load_checkpoint(args.checkpoint_path, model, optimizer) 25 | model.eval() 26 | 27 | # Extract 28 | for batch in tqdm(val_loader): 29 | text_padded, text_lengths, mel_padded, mel_lengths, gate_padded, \ 30 | f0_padded, prior_padded, name, spk, emo = [x for x in batch] 31 | 32 | # Style embeddings 33 | style = model.Emo_encoder(mel_padded.cuda()).transpose(0,1).squeeze() 34 | # spk_emb = model.Spk_encoder(spk.cuda()).unsqueeze(0) 35 | # emb = torch.cat((spk_emb, style), dim=2) 36 | # style = model.Global_style_encoder(emb).transpose(0,1) 37 | 38 | # Data-specific name definition. You should change this codes for your own data structure 39 | # In our dataset, we divide name for 4 speakers (f1, f2, m1, m2) and 4 emotions (a, h, s, n) 40 | for k in range(len(style)): 41 | emo, spk, idx = name[k].split('_') 42 | if spk == 'f': 43 | if len(idx) == 4: 44 | spk = 'f2' 45 | else: 46 | spk = 'f1' 47 | elif spk == 'm': 48 | if len(idx) == 4: 49 | spk = 'm2' 50 | else: 51 | spk = 'm1' 52 | 53 | # Save 54 | name[k] = emo + '_' + spk + '_' + idx 55 | np.save(os.path.join(args.emb_dir, name[k]), style[k].detach().cpu().numpy()) 56 | 57 | if __name__ == '__main__': 58 | p = argparse.ArgumentParser() 59 | p.add_argument('--gpu', type=str, default='0,1') 60 | p.add_argument('-v', '--verbose', type=str, default='0') 61 | p.add_argument('-o', '--emb_dir', type=str, default='emb_dir', help='Directory for saving style embeddings') 62 | p.add_argument('-c', '--checkpoint_path', type=str, required=True) 63 | p.add_argument('-m', '--mode', type=str, help='base, prop') 64 | 65 | args = p.parse_args() 66 | 67 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 68 | torch.manual_seed(hparams.seed) 69 | torch.cuda.manual_seed(hparams.seed) 70 | 71 | if args.verbose=='0': 72 | import warnings 73 | warnings.filterwarnings("ignore") 74 | 75 | main(args) 76 | -------------------------------------------------------------------------------- /filelists/example_filelist.txt: -------------------------------------------------------------------------------- 1 | /YOUR/DATA/PATH/wavs/anger_m_999.wav|안녕하세요.|m1 2 | 3 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | from text import symbols 2 | 3 | ### Experiment Parameters 4 | seed = 118 5 | n_gpus = 1 6 | data_path = '/YOUR/PREPROCESSED/DATA/PATH' 7 | training_files = 'filelists/your_train_file.txt' 8 | validation_files = 'filelists/your_valid_file.txt' 9 | inference_files = 'inference_textlist.txt' 10 | text_cleaners = ['basic_cleaners'] # For Korean 11 | 12 | ### Audio Parameters 13 | sampling_rate = 16000 14 | filter_length = 1024 15 | hop_length = 200 # 12.5ms 16 | win_length = 800 # 50ms 17 | n_mel_channels = 80 18 | mel_fmin = 50 
19 | mel_fmax = 7200 20 | 21 | ### Model Parameters 22 | n_symbols = len(symbols) 23 | symbols_embedding_dim = 256 24 | hidden_dim = 256 25 | spk_hidden_dim = 16 26 | dprenet_dim = 32 27 | ff_dim = 1024 28 | n_heads = 2 29 | n_layers = 4 30 | sliding_window = [-1, 4] # For sliding window attention in inference 31 | 32 | # Multi-style generation 33 | ms_kernel = 3 34 | n_layers_lp_enc = 6 35 | 36 | # reference encoder 37 | E = 256 38 | ref_enc_filters = [32,32,64,64,128,128] 39 | ref_enc_size = [3,3] 40 | ref_enc_strides = [2,2] 41 | ref_enc_pad = [1,1] 42 | ref_enc_gru_size = E // 2 43 | 44 | # Loss scale 45 | emo_scale = 1.0 46 | f0_scale = 1.0 47 | kl_scale = 0.1 48 | 49 | # Dataset configuration 50 | num_spk = 4 51 | num_emo = 4 52 | 53 | ### Optimization Hyperparameters 54 | lr = 0.05 55 | batch_size = 32 56 | warmup_steps = 4000 57 | grad_clip_thresh = 1.0 58 | 59 | iters_per_validation = 5000 60 | iters_per_checkpoint = 10000 61 | 62 | training_epochs = 100000 63 | train_steps = 500000 64 | local_style_step = 20000 65 | bin_loss_enable_steps = 10000 66 | bin_loss_warmup_steps = 5000 67 | 68 | ### HiFi-GAN 69 | resblock = "1" 70 | num_gpus = 0 71 | learning_rate = 0.0002 72 | adam_b1 = 0.8 73 | adam_b2 = 0.99 74 | lr_decay = 0.999 75 | 76 | upsample_rates = [5,5,4,2] 77 | upsample_kernel_sizes = [9,9,8,4] 78 | upsample_initial_channel = 512 79 | resblock_kernel_sizes = [3,7,11] 80 | resblock_dilation_sizes = [[1,3,5], [1,3,5], [1,3,5]] 81 | 82 | segment_size = 6400 83 | num_mels = 80 84 | num_freq = 1025 85 | n_fft = 1024 86 | hop_size = 200 87 | win_size = 800 88 | 89 | fmin = 50 90 | fmax = 7200 91 | num_workers = 4 92 | 93 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os, sys, argparse, librosa, torch, pdb, glob, warnings 2 | import numpy as np 3 | import torch.nn as nn 4 | import matplotlib.pyplot as plt 5 | from scipy.io.wavfile import write 6 | from g2pK.g2pk.g2pk import G2p 7 | 8 | import hparams 9 | from text import * 10 | from text.symbols import symbols 11 | from text.cleaners import basic_cleaners 12 | from modules.model import FluentTTS 13 | from utils.utils import * 14 | from utils.data_utils import process_meta, create_id_table 15 | from model_hifigan import Generator 16 | from layers import TacotronSTFT 17 | 18 | 19 | # Prepare text preprocessing 20 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 21 | id_to_symbol = {i: s for i, s in enumerate(symbols)} 22 | stft = TacotronSTFT() 23 | 24 | def text2seq(text): 25 | """Text preprocessing""" 26 | sequence=[symbol_to_id['^']] 27 | sequence.extend(text_to_sequence(text, hparams.text_cleaners)) 28 | sequence.append(symbol_to_id['~']) 29 | return sequence 30 | 31 | def text2seq_target(text): 32 | """For word & phoneme-level f0 control""" 33 | sequence=[] 34 | sequence.extend(text_to_sequence(text, hparams.text_cleaners)) 35 | return sequence 36 | 37 | def get_mel(filename): 38 | """Prepare mel spectrogram from reference wav""" 39 | wav, sr = librosa.load(filename, sr=hparams.sampling_rate) 40 | wav = librosa.effects.trim(wav, top_db=23, frame_length=1024, hop_length=256)[0] 41 | wav = torch.FloatTensor(wav.astype(np.float32)) 42 | melspec, _ = stft.mel_spectrogram(wav.unsqueeze(0)) 43 | return melspec.squeeze(0) 44 | 45 | 46 | def synthesize(args, style_list): 47 | # Load Acoustic model 48 | mode = args.mode 49 | 50 | model = FluentTTS(hparams, mode).cuda() 51 | model, _, _, _, 
_ = load_checkpoint(args.checkpoint_path, model, None) 52 | model.cuda().eval() 53 | 54 | # Load Vocoder 55 | generator = Generator(hparams).cuda() 56 | state_dict_g = torch.load(args.vocoder) 57 | generator.load_state_dict(state_dict_g['generator']) 58 | generator.eval() 59 | generator.remove_weight_norm() 60 | 61 | # Prepare speaker ID and input text 62 | _, speakers, _ = process_meta(hparams.validation_files) 63 | sid_dict = create_id_table(speakers) # {'f1': 0, 'f2': 1, 'm1': 2, 'm2': 3} 64 | 65 | g2p = G2p() 66 | with open(hparams.inference_files, 'r') as f: 67 | text_list = f.readlines() 68 | idx = 0 # Text number 69 | 70 | # Inference for each utterance 71 | for text in text_list: 72 | # Word & Phoneme-level F0 control 73 | if args.control == 'pho': 74 | text, target = text.strip('\n').split('|') 75 | text, target = g2p(text), g2p(target) 76 | print(f'{text}|{target}') 77 | 78 | src_seq = np.array(text2seq(text)) 79 | tgt_seq = np.array(text2seq_target(target)) 80 | # Find index of target sequence in source sequence 81 | i = 0 82 | while True: 83 | if np.equal(src_seq[i:len(tgt_seq)+i], tgt_seq).all(): 84 | start, end = i, len(tgt_seq) + i 85 | break 86 | else: 87 | i += 1 88 | print(start, end) 89 | # Utterance-level 90 | else: 91 | text = text.strip('\n') 92 | text = g2p(text) 93 | print(text) 94 | 95 | # Inference for each style vectors 96 | for style_path in style_list: 97 | # Text sequence 98 | sequence = np.array(text2seq(text))[None, :] 99 | sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long() 100 | 101 | # Style vector of mean or i2i or reference wav 102 | if style_path[-4:] == '.npy': 103 | style = torch.from_numpy(np.load(style_path)).view(1,1,-1).cuda() 104 | spk = style_path.split('_')[-1][:-4] 105 | emo = style_path.split('_')[-2][0] 106 | elif style_path[-4:] == '.wav': 107 | ref_mel, _, _ = get_mel(style_path) 108 | ref_mel = torch.from_numpy(ref_mel).unsqueeze(0).float().cuda() 109 | style = model.Emo_encoder(ref_mel, logit=False).transpose(0,1) 110 | spk, emo = args.spk, args.emo 111 | 112 | # Load mean and std of f0 113 | ms_path = os.path.join(hparams.data_path, 'mean_std/') 114 | mean_name = 'mean_' + spk + '_' + emo + '.npy' 115 | std_name = 'std_' + spk + '_' + emo + '.npy' 116 | f0_mean = torch.from_numpy(np.load(os.path.join(ms_path, mean_name))).cuda() 117 | f0_std = torch.from_numpy(np.load(os.path.join(ms_path, std_name))).cuda() 118 | 119 | # Speaker ID 120 | spk_id = sid_dict[spk] 121 | spk_id = torch.LongTensor([spk_id]).cuda() 122 | 123 | # Inference 124 | with torch.no_grad(): 125 | # Word & Phoneme-level F0 control 126 | if args.control == 'pho': 127 | melspec, enc_alignments, dec_alignments, enc_dec_alignments, stop = model.inference(sequence, style, spk_id, f0_mean, f0_std, max_len=512, mode=mode, start=start, end=end, hz=args.hz) 128 | # Uttr or not controlling F0 129 | else: 130 | melspec, enc_alignments, dec_alignments, enc_dec_alignments, stop = model.inference(sequence, style, spk_id, f0_mean, f0_std, max_len=512, mode=mode, hz=args.hz) 131 | 132 | T=len(stop) 133 | melspec = melspec[:,:,:T] 134 | 135 | # Waveform generation 136 | y_g_hat = generator(melspec) 137 | audio = y_g_hat.squeeze() 138 | audio = audio*32768 139 | audio = audio.detach().cpu().numpy().astype('int16') 140 | name = style_path.split('/')[-1][:-4][4:] 141 | output = args.out_dir + '/' + str(idx) + '_' + args.mode + '_' + name + '.wav' 142 | write(output, hparams.sampling_rate, audio) 143 | print(output) 144 | 145 | # Plot mel spectrogram 146 | 
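# (added note) melspec has shape [1, n_mel_channels, T]; it is squeezed to 2-D and moved to
# CPU/NumPy before plotting. imshow with origin='lower' puts low mel bins at the bottom, and
# the figure is saved next to the generated wav with the same basename and a .png extension.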
plot_mel = melspec.squeeze(0).detach().cpu().numpy() 147 | plt.figure(figsize=(8,6)) 148 | plt.imshow(plot_mel, origin='lower', aspect='auto') 149 | name = output[:-4] + '.png' 150 | plt.savefig(name) 151 | plt.close() 152 | 153 | idx += 1 154 | 155 | if __name__ == '__main__': 156 | parser = argparse.ArgumentParser() 157 | parser.add_argument('--gpu', type=str, default='0,1') 158 | parser.add_argument('-c', '--checkpoint_path', type=str) 159 | parser.add_argument('-v', '--vocoder', type=str, help='ckpt path of Hifi-GAN') 160 | parser.add_argument('-s', '--style_dir', type=str, default='emb_mean_dir') 161 | parser.add_argument('-o', '--out_dir', type=str, default='generated_files') 162 | parser.add_argument('-m', '--mode', type=str, help='base, prop') 163 | parser.add_argument('--control', type=str, default=None, help='uttr, pho') 164 | parser.add_argument('--hz', type=float, default=None, help='value to modify f0') 165 | parser.add_argument('--ref_dir', type=str, default=None, help='use when using referece wav') 166 | parser.add_argument('--spk', type=str, default='f2', help='use when using reference wav') 167 | parser.add_argument('--emo', type=str, default='a', help='use when using reference wav') 168 | 169 | args = parser.parse_args() 170 | 171 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 172 | torch.manual_seed(hparams.seed) 173 | torch.cuda.manual_seed(hparams.seed) 174 | 175 | os.makedirs(args.out_dir, exist_ok=True) 176 | 177 | if args.ref_dir is not None: 178 | style_path = args.ref_dir + '**/*.wav' 179 | else: 180 | style_path = args.style_dir + '**/*.npy' 181 | 182 | style_list = [file for file in glob.glob(style_path, recursive=True)] 183 | print(f'Number of style vector: {len(style_list)}') 184 | 185 | warnings.filterwarnings('ignore') 186 | 187 | synthesize(args, style_list) 188 | 189 | -------------------------------------------------------------------------------- /inference_textlist.txt: -------------------------------------------------------------------------------- 1 | 이번 겨울엔, 다같이 놀러가면 좋겠다. 
2 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from librosa.filters import mel as librosa_mel_fn 5 | from audio_processing import dynamic_range_compression 6 | from audio_processing import dynamic_range_decompression 7 | from stft import STFT 8 | 9 | 10 | class TacotronSTFT(torch.nn.Module): 11 | def __init__(self, filter_length=1024, hop_length=200, win_length=800, 12 | n_mel_channels=80, sampling_rate=16000, mel_fmin=50.0, 13 | mel_fmax=7200.0): 14 | super(TacotronSTFT, self).__init__() 15 | self.n_mel_channels = n_mel_channels 16 | self.sampling_rate = sampling_rate 17 | self.stft_fn = STFT(filter_length, hop_length, win_length) 18 | mel_basis = librosa_mel_fn( 19 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 20 | mel_basis = torch.from_numpy(mel_basis).float() 21 | self.register_buffer('mel_basis', mel_basis) 22 | 23 | def spectral_normalize(self, magnitudes): 24 | output = dynamic_range_compression(magnitudes) 25 | return output 26 | 27 | def spectral_de_normalize(self, magnitudes): 28 | output = dynamic_range_decompression(magnitudes) 29 | return output 30 | 31 | def mel_spectrogram(self, y): 32 | """Computes mel-spectrograms from a batch of waves 33 | PARAMS 34 | ------ 35 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 36 | RETURNS 37 | ------- 38 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 39 | """ 40 | assert(torch.min(y.data) >= -1) 41 | assert(torch.max(y.data) <= 1) 42 | 43 | magnitudes, phases = self.stft_fn.transform(y) 44 | magnitudes = magnitudes.data 45 | mel_output = torch.matmul(self.mel_basis, magnitudes) 46 | mel_output = self.spectral_normalize(mel_output) 47 | energy = torch.norm(magnitudes, dim=1) 48 | 49 | return mel_output, energy 50 | -------------------------------------------------------------------------------- /mean_i2i.py: -------------------------------------------------------------------------------- 1 | import os, glob, argparse, pdb 2 | import numpy as np 3 | 4 | import hparams 5 | 6 | 7 | def stack_emb(style_list, emb_dim=None): 8 | """Stack all emotion embeddings""" 9 | stacked = np.array([]).astype('float32').reshape(0, emb_dim) 10 | 11 | for fpath in style_list: 12 | emb = np.load(fpath).reshape(1, emb_dim) 13 | stacked = np.concatenate((stacked, emb)) 14 | 15 | return stacked 16 | 17 | def get_distance_single(current_sample, other_stack): 18 | """Compute distance between current sample and all samples of other stack""" 19 | distance = np.mean(np.sqrt(np.sum((current_sample - other_stack) ** 2, axis=1)), axis=0) 20 | return distance 21 | 22 | def get_distance(current_stack, other_stack): 23 | """Compute distances among samples of current stack and samples of other stack""" 24 | distances = np.array([get_distance_single(sample, other_stack) for sample in current_stack]) 25 | return distances 26 | 27 | def get_i2i(current_stack, inter_distance, intra_distance, eps=np.finfo(np.float32).eps): 28 | """I2I algorithm""" 29 | ratio = inter_distance / (eps + intra_distance) 30 | return current_stack[np.argmax(ratio)] 31 | 32 | def i2i_neutral(emb_dir, out_dir, speakers, emotions): 33 | """This code assumes that setting farthest emotion as neutral for all emotion""" 34 | # I2I algorithm depends on the distribution of the data, thus, the general approach is set the farthest emotion as 
neutral. 35 | # This code assumes that setting farthest emotion as neutral. 36 | # Therefore, 'neutral' should be first in the emotions list. 37 | 38 | for spk in speakers: 39 | for emo in emotions: 40 | current_style_list = glob.glob(os.path.join(emb_dir, emo) + '_' + spk + '*') 41 | current_style = stack_emb(current_style_list, hparams.E) 42 | mean_name = 'mean_' + emo + '_' + spk 43 | i2i_name = 'i2i_' + emo + '_' + spk 44 | 45 | if emo == 'neutral': 46 | mean_emb = np.mean(current_style, axis=0) 47 | np.save(os.path.join(out_dir, mean_name), mean_emb) 48 | far_style = current_style # For i2i 49 | 50 | else: 51 | if len(current_style) != 0: 52 | # mean_emb = np.mean(current_style, axis=0) 53 | # np.save(os.path.join(out_dir, mean_name), mean_emb) 54 | 55 | inter_dist = get_distance(current_style, far_style) 56 | intra_dist = get_distance(current_style, current_style) 57 | i2i_emb = get_i2i(current_style, inter_dist, intra_dist) 58 | 59 | np.save(os.path.join(out_dir, i2i_name), i2i_emb) 60 | 61 | def i2i_all(emb_dir, out_dir, speakers, emotions): 62 | """This code is for obtain i2i embedding while considering all emotion for farthest emotion""" 63 | for spk in speakers: 64 | for emo in emotions: 65 | current_style_list = glob.glob(os.path.join(emb_dir, emo) + '_' + spk + '*') 66 | current_style = stack_emb(current_style_list, hparams.E) 67 | 68 | if len(current_style) != 0: 69 | if emo == 'anger' or emo =='angry': 70 | ang_style = current_style 71 | elif emo == 'sadness' or emo =='sad': 72 | sad_style = current_style 73 | elif emo == 'happy': 74 | hap_style = current_style 75 | elif emo == 'neutral': 76 | neu_style = current_style 77 | 78 | # Angry 79 | inter_dist = (get_distance(ang_style, hap_style) + get_distance(ang_style, sad_style) + get_distance(ang_style, neu_style)) / 3 80 | intra_dist = get_distance(ang_style, ang_style) 81 | 82 | i2i_emb = get_i2i(ang_style, inter_dist, intra_dist) 83 | 84 | np.save(os.path.join(out_dir, 'i2i_angry_' + spk + '.npy'), i2i_emb) 85 | 86 | # Happy 87 | inter_dist = (get_distance(hap_style, ang_style) + get_distance(hap_style, sad_style) + get_distance(hap_style, neu_style)) / 3 88 | intra_dist = get_distance(hap_style, hap_style) 89 | 90 | i2i_emb = get_i2i(hap_style, inter_dist, intra_dist) 91 | 92 | np.save(os.path.join(out_dir, 'i2i_happy_' + spk + '.npy'), i2i_emb) 93 | 94 | # Sad 95 | inter_dist = (get_distance(sad_style, ang_style) + get_distance(sad_style, hap_style) + get_distance(sad_style, neu_style)) / 3 96 | intra_dist = get_distance(sad_style, sad_style) 97 | 98 | i2i_emb = get_i2i(sad_style, inter_dist, intra_dist) 99 | 100 | np.save(os.path.join(out_dir, 'i2i_sad_' + spk + '.npy'), i2i_emb) 101 | 102 | # Neutral 103 | inter_dist = (get_distance(neu_style, ang_style) + get_distance(neu_style, hap_style) + get_distance(neu_style, sad_style)) / 3 104 | intra_dist = get_distance(neu_style, neu_style) 105 | 106 | i2i_emb = get_i2i(neu_style, inter_dist, intra_dist) 107 | 108 | np.save(os.path.join(out_dir, 'i2i_neutral_' + spk + '.npy'), i2i_emb) 109 | 110 | 111 | def main(emb_dir, out_dir, mode): 112 | if not os.path.isdir(out_dir): 113 | os.mkdir(out_dir) 114 | 115 | # Data-specific configuration. 
You should change this codes for your own data structure 116 | speakers = ['m1', 'm2', 'f1', 'f2'] 117 | emotions = ['neutral', 'anger', 'angry', 'sadness', 'sad', 'happy'] 118 | 119 | if mode == 'neu': 120 | i2i_neutral(emb_dir, out_dir, speakers, emotions) 121 | elif mode == 'all': 122 | i2i_all(emb_dir, out_dir, speakers, emotions) 123 | 124 | print('Processing done') 125 | 126 | 127 | if __name__ == '__main__': 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('-i', '--emb_dir', type=str, default='emb_dir', help='emotion embeddings extracted from data') 130 | parser.add_argument('-o', '--out_dir', type=str, default='emb_mean_dir', help='mean or i2i embeddings') 131 | parser.add_argument('--mode', type=str, default='neu', help='neu, all') 132 | args = parser.parse_args() 133 | 134 | main(args.emb_dir, args.out_dir, args.mode) 135 | -------------------------------------------------------------------------------- /model_hifigan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 6 | from utils.utils import init_weights, get_padding 7 | 8 | LRELU_SLOPE = 0.1 9 | 10 | 11 | class ResBlock1(torch.nn.Module): 12 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 13 | super(ResBlock1, self).__init__() 14 | self.h = h 15 | self.convs1 = nn.ModuleList([ 16 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 17 | padding=get_padding(kernel_size, dilation[0]))), 18 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 19 | padding=get_padding(kernel_size, dilation[1]))), 20 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 21 | padding=get_padding(kernel_size, dilation[2]))) 22 | ]) 23 | self.convs1.apply(init_weights) 24 | 25 | self.convs2 = nn.ModuleList([ 26 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 27 | padding=get_padding(kernel_size, 1))), 28 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 29 | padding=get_padding(kernel_size, 1))), 30 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 31 | padding=get_padding(kernel_size, 1))) 32 | ]) 33 | self.convs2.apply(init_weights) 34 | 35 | def forward(self, x): 36 | for c1, c2 in zip(self.convs1, self.convs2): 37 | xt = F.leaky_relu(x, LRELU_SLOPE) 38 | xt = c1(xt) 39 | xt = F.leaky_relu(xt, LRELU_SLOPE) 40 | xt = c2(xt) 41 | x = xt + x 42 | return x 43 | 44 | def remove_weight_norm(self): 45 | for l in self.convs1: 46 | remove_weight_norm(l) 47 | for l in self.convs2: 48 | remove_weight_norm(l) 49 | 50 | 51 | class ResBlock2(torch.nn.Module): 52 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 53 | super(ResBlock2, self).__init__() 54 | self.h = h 55 | self.convs = nn.ModuleList([ 56 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 57 | padding=get_padding(kernel_size, dilation[0]))), 58 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 59 | padding=get_padding(kernel_size, dilation[1]))) 60 | ]) 61 | self.convs.apply(init_weights) 62 | 63 | def forward(self, x): 64 | for c in self.convs: 65 | xt = F.leaky_relu(x, LRELU_SLOPE) 66 | xt = c(xt) 67 | x = xt + x 68 | return x 69 | 70 | def remove_weight_norm(self): 71 | for l in self.convs: 72 | 
remove_weight_norm(l) 73 | 74 | 75 | class Generator(torch.nn.Module): 76 | def __init__(self, h): 77 | super(Generator, self).__init__() 78 | self.h = h 79 | self.num_kernels = len(h.resblock_kernel_sizes) 80 | self.num_upsamples = len(h.upsample_rates) 81 | self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) 82 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2 83 | 84 | self.ups = nn.ModuleList() 85 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 86 | self.ups.append(weight_norm( 87 | ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), 88 | k, u, padding=(k-u)//2))) 89 | 90 | self.resblocks = nn.ModuleList() 91 | for i in range(len(self.ups)): 92 | ch = h.upsample_initial_channel//(2**(i+1)) 93 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): 94 | self.resblocks.append(resblock(h, ch, k, d)) 95 | 96 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 97 | self.ups.apply(init_weights) 98 | self.conv_post.apply(init_weights) 99 | #self.fc_post = nn.Linear(h.upsample_initial_channel*(2**len(h.upsample_rates)), h.segment_size) 100 | #self.fc_post_val = nn.Linear((h.upsample_initial_channel*(2**len(h.upsample_rates)), h.sampling_rate*) 101 | 102 | def forward(self, x): # [B, 80, 32(16kHz, frames)] 103 | x = self.conv_pre(x) # [B, 512, 32] 104 | for i in range(self.num_upsamples): 105 | x = F.leaky_relu(x, LRELU_SLOPE) 106 | x = self.ups[i](x) # [B, 256, 256] -> [B, 128, 2048] -> [B, 64, 4096] -> [B, 32, 8192] 107 | xs = None 108 | for j in range(self.num_kernels): 109 | if xs is None: 110 | xs = self.resblocks[i*self.num_kernels+j](x) 111 | #print('res 1: {}'.format(xs.size())) 112 | else: 113 | xs += self.resblocks[i*self.num_kernels+j](x) 114 | #print('res 2: {}'.format(xs.size())) 115 | x = xs / self.num_kernels 116 | #if x.size(-1) == 8192: 117 | # x = self.fc_post(x) 118 | x = F.leaky_relu(x) 119 | x = self.conv_post(x) 120 | x = torch.tanh(x) 121 | 122 | return x 123 | 124 | def remove_weight_norm(self): 125 | print('Removing weight norm...') 126 | for l in self.ups: 127 | remove_weight_norm(l) 128 | for l in self.resblocks: 129 | l.remove_weight_norm() 130 | remove_weight_norm(self.conv_pre) 131 | remove_weight_norm(self.conv_post) 132 | 133 | 134 | class DiscriminatorP(torch.nn.Module): 135 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 136 | super(DiscriminatorP, self).__init__() 137 | self.period = period 138 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 139 | self.convs = nn.ModuleList([ 140 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 141 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 142 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 143 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 144 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 145 | ]) 146 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 147 | 148 | def forward(self, x): 149 | fmap = [] 150 | 151 | # 1d to 2d 152 | b, c, t = x.shape 153 | if t % self.period != 0: # pad first 154 | n_pad = self.period - (t % self.period) 155 | x = F.pad(x, (0, n_pad), "reflect") 156 | t = t + n_pad 157 | x = x.view(b, c, t // self.period, self.period) 158 | 159 | for l in self.convs: 160 | x = l(x) 161 | x = 
F.leaky_relu(x, LRELU_SLOPE) 162 | fmap.append(x) 163 | x = self.conv_post(x) 164 | fmap.append(x) 165 | x = torch.flatten(x, 1, -1) 166 | 167 | return x, fmap 168 | 169 | 170 | class MultiPeriodDiscriminator(torch.nn.Module): 171 | def __init__(self): 172 | super(MultiPeriodDiscriminator, self).__init__() 173 | self.discriminators = nn.ModuleList([ 174 | DiscriminatorP(2), 175 | DiscriminatorP(3), 176 | DiscriminatorP(5), 177 | DiscriminatorP(7), 178 | DiscriminatorP(11), 179 | ]) 180 | 181 | def forward(self, y, y_hat): 182 | y_d_rs = [] 183 | y_d_gs = [] 184 | fmap_rs = [] 185 | fmap_gs = [] 186 | for i, d in enumerate(self.discriminators): 187 | y_d_r, fmap_r = d(y) 188 | y_d_g, fmap_g = d(y_hat) 189 | y_d_rs.append(y_d_r) 190 | fmap_rs.append(fmap_r) 191 | y_d_gs.append(y_d_g) 192 | fmap_gs.append(fmap_g) 193 | 194 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 195 | 196 | 197 | class DiscriminatorS(torch.nn.Module): 198 | def __init__(self, use_spectral_norm=False): 199 | super(DiscriminatorS, self).__init__() 200 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 201 | self.convs = nn.ModuleList([ 202 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 203 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 204 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 205 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 206 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 207 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 208 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 209 | ]) 210 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 211 | 212 | def forward(self, x): 213 | fmap = [] 214 | for l in self.convs: 215 | x = l(x) 216 | x = F.leaky_relu(x, LRELU_SLOPE) 217 | fmap.append(x) 218 | x = self.conv_post(x) 219 | fmap.append(x) 220 | x = torch.flatten(x, 1, -1) 221 | 222 | return x, fmap 223 | 224 | 225 | class MultiScaleDiscriminator(torch.nn.Module): 226 | def __init__(self): 227 | super(MultiScaleDiscriminator, self).__init__() 228 | self.discriminators = nn.ModuleList([ 229 | DiscriminatorS(use_spectral_norm=True), 230 | DiscriminatorS(), 231 | DiscriminatorS(), 232 | ]) 233 | self.meanpools = nn.ModuleList([ 234 | AvgPool1d(4, 2, padding=2), 235 | AvgPool1d(4, 2, padding=2) 236 | ]) 237 | 238 | def forward(self, y, y_hat): 239 | y_d_rs = [] 240 | y_d_gs = [] 241 | fmap_rs = [] 242 | fmap_gs = [] 243 | for i, d in enumerate(self.discriminators): 244 | if i != 0: 245 | y = self.meanpools[i-1](y) 246 | y_hat = self.meanpools[i-1](y_hat) 247 | y_d_r, fmap_r = d(y) 248 | y_d_g, fmap_g = d(y_hat) 249 | y_d_rs.append(y_d_r) 250 | fmap_rs.append(fmap_r) 251 | y_d_gs.append(y_d_g) 252 | fmap_gs.append(fmap_g) 253 | 254 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 255 | 256 | 257 | def feature_loss(fmap_r, fmap_g): 258 | loss = 0 259 | for dr, dg in zip(fmap_r, fmap_g): 260 | for rl, gl in zip(dr, dg): 261 | loss += torch.mean(torch.abs(rl - gl)) 262 | 263 | return loss*2 264 | 265 | 266 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 267 | loss = 0 268 | r_losses = [] 269 | g_losses = [] 270 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 271 | r_loss = torch.mean((1-dr)**2) 272 | g_loss = torch.mean(dg**2) 273 | loss += (r_loss + g_loss) 274 | r_losses.append(r_loss.item()) 275 | g_losses.append(g_loss.item()) 276 | 277 | return loss, r_losses, g_losses 278 | 279 | 280 | def generator_loss(disc_outputs): 281 | loss = 0 282 | gen_losses = [] 283 | for dg in disc_outputs: 284 | l = 
torch.mean((1-dg)**2) 285 | gen_losses.append(l) 286 | loss += l 287 | 288 | return loss, gen_losses 289 | 290 | -------------------------------------------------------------------------------- /modules/__pycache__/SubLayers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/SubLayers.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/SubLayers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/SubLayers.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/align_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/align_loss.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/align_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/align_loss.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/aligner.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/aligner.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/aligner.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/aligner.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/attention.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/attn_loss_function.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/attn_loss_function.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/attn_loss_function.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/attn_loss_function.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/commons.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/commons.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/flow.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/flow.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/init_layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/init_layer.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/init_layer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/init_layer.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/init_layer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/init_layer.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/modules.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/modules.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/saln.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/saln.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/saln.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/saln.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/style.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/style.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/style.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/style.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/transformer.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /modules/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/vae.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/modules/__pycache__/vae.cpython-36.pyc -------------------------------------------------------------------------------- /modules/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pdb 16 | import numpy as np 17 | import torch 18 | from torch import nn 19 | from torch.nn import functional as F 20 | 21 | 22 | class LinearNorm(torch.nn.Module): 23 | def __init__(self, in_features, out_features, bias=False): 24 | super(LinearNorm, self).__init__() 25 | self.linear = nn.Linear(in_features, out_features, bias) 26 | 27 | nn.init.xavier_uniform_(self.linear.weight) 28 | if bias: nn.init.constant_(self.linear.bias, 0.0) 29 | 30 | def forward(self, x): 31 | x = self.linear(x) 32 | return x 33 | 34 | 35 | class ConvNorm(torch.nn.Module): 36 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 37 | padding=None, dilation=1, bias=True, w_init_gain='linear'): 38 | super(ConvNorm, self).__init__() 39 | if padding is None: 40 | assert(kernel_size % 2 == 1) 41 | padding = int(dilation * (kernel_size - 1) / 2) 42 | 43 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 44 | kernel_size=kernel_size, stride=stride, 45 | padding=padding, dilation=dilation, 46 | bias=bias) 47 | 48 | torch.nn.init.xavier_uniform_( 49 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) 50 | 51 | def forward(self, signal): 52 | conv_signal = self.conv(signal) 53 | return conv_signal 54 | 55 | 56 | class ConvAttention(torch.nn.Module): 57 | def __init__(self, n_mel_channels=80, n_text_channels=512, n_att_channels=80, temperature=1.0): 58 | super(ConvAttention, self).__init__() 59 | self.temperature = temperature 60 | self.softmax = torch.nn.Softmax(dim=3) 61 | self.log_softmax = torch.nn.LogSoftmax(dim=3) 62 | 63 | self.query_proj = nn.Sequential( 64 | ConvNorm(n_mel_channels, 65 | n_mel_channels * 2, 66 | kernel_size=3, 67 | bias=True, 68 | w_init_gain='relu'), 69 | torch.nn.ReLU(), 70 | ConvNorm(n_mel_channels * 2, 71 | n_mel_channels, 72 | kernel_size=1, 73 | bias=True), 74 | torch.nn.ReLU(), 75 | ConvNorm(n_mel_channels, 76 | n_att_channels, 77 | kernel_size=1, 78 | bias=True)) 79 | 80 | self.key_proj = nn.Sequential( 81 | ConvNorm(n_text_channels, 82 | n_text_channels * 2, 83 | kernel_size=3, 84 | bias=True, 85 | w_init_gain='relu'), 86 | torch.nn.ReLU(), 87 | ConvNorm(n_text_channels * 2, 88 | n_att_channels, 89 | kernel_size=1, 90 | bias=True)) 
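# (added note) query_proj maps mel frames (n_mel_channels) and key_proj maps text features
# (n_text_channels) into a shared n_att_channels space. In forward(), the style embedding is
# first added to the text keys, each text/mel pair is scored with a Gaussian-style negative
# squared distance (-0.0005 * ||query - key||^2), an optional attention prior is added in the
# log domain, and a softmax over the text axis yields the soft alignment.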
91 | 92 | self.key_style_proj = LinearNorm(n_text_channels, n_mel_channels) 93 | 94 | def forward(self, queries, keys, query_lens, mask=None, attn_prior=None, 95 | style_emb=None): 96 | """ 97 | Args: 98 | queries (torch.tensor): B x C x T1 tensor 99 | (probably going to be mel data) 100 | keys (torch.tensor): B x C2 x T2 tensor (text data) 101 | query_lens: lengths for sorting the queries in descending order 102 | mask (torch.tensor): uint8 binary mask for variable length entries 103 | (should be in the T2 domain) 104 | Output: 105 | attn (torch.tensor): B x 1 x T1 x T2 attention mask. 106 | Final dim T2 should sum to 1 107 | """ 108 | keys = keys + style_emb.transpose(1,2) 109 | 110 | keys_enc = self.key_proj(keys) # B x n_attn_dims x T2 111 | 112 | # Beware can only do this since query_dim = attn_dim = n_mel_channels 113 | queries_enc = self.query_proj(queries) 114 | 115 | # Simplistic Gaussian Isotopic Attention 116 | # B x n_attn_dims x T1 x T2 117 | attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2 118 | # compute log likelihood from a gaussian 119 | attn = -0.0005 * attn.sum(1, keepdim=True) 120 | if attn_prior is not None: 121 | attn = self.log_softmax(attn) + torch.log(attn_prior[:, None]+1e-8) 122 | 123 | attn_logprob = attn.clone() 124 | 125 | if mask is not None: 126 | attn.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2), 127 | -float("inf")) 128 | 129 | attn = self.softmax(attn) # Softmax along T2 130 | return attn, attn_logprob 131 | -------------------------------------------------------------------------------- /modules/attn_loss_function.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
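# The two losses below supervise the internal aligner: AttentionCTCLoss pads a
# blank column onto the per-frame attention log-probabilities and applies CTC so
# that every text token is visited monotonically and in order, while
# AttentionBinarizationLoss maximizes the soft-attention mass at the positions
# selected by the binarized (hard) alignment, pulling soft and hard attention together.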
14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | 20 | class AttentionCTCLoss(torch.nn.Module): 21 | def __init__(self, blank_logprob=-1): 22 | super(AttentionCTCLoss, self).__init__() 23 | self.log_softmax = torch.nn.LogSoftmax(dim=3) 24 | self.blank_logprob = blank_logprob 25 | self.CTCLoss = nn.CTCLoss(zero_infinity=True) 26 | 27 | def forward(self, attn_logprob, in_lens, out_lens): 28 | key_lens = in_lens 29 | query_lens = out_lens 30 | attn_logprob_padded = F.pad(input=attn_logprob, 31 | pad=(1, 0, 0, 0, 0, 0, 0, 0), 32 | value=self.blank_logprob) 33 | cost_total = 0.0 34 | for bid in range(attn_logprob.shape[0]): 35 | target_seq = torch.arange(1, key_lens[bid]+1).unsqueeze(0) 36 | curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2) 37 | curr_logprob = curr_logprob[:query_lens[bid], :, :key_lens[bid]+1] 38 | curr_logprob = self.log_softmax(curr_logprob[None])[0] 39 | ctc_cost = self.CTCLoss( 40 | curr_logprob, target_seq, input_lengths=query_lens[bid:bid+1], 41 | target_lengths=key_lens[bid:bid+1]) 42 | cost_total += ctc_cost 43 | cost = cost_total/attn_logprob.shape[0] 44 | return cost 45 | 46 | 47 | class AttentionBinarizationLoss(torch.nn.Module): 48 | def __init__(self): 49 | super(AttentionBinarizationLoss, self).__init__() 50 | 51 | def forward(self, hard_attention, soft_attention, eps=1e-12): 52 | log_sum = torch.log(torch.clamp(soft_attention[hard_attention == 1], 53 | min=eps)).sum() 54 | return -log_sum / hard_attention.sum() 55 | -------------------------------------------------------------------------------- /modules/init_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import hparams as hp 6 | 7 | 8 | class Linear(nn.Linear): 9 | def __init__(self, 10 | in_dim, 11 | out_dim, 12 | bias=True, 13 | w_init_gain='linear'): 14 | super(Linear, self).__init__(in_dim, 15 | out_dim, 16 | bias) 17 | nn.init.xavier_uniform_(self.weight, 18 | gain=nn.init.calculate_gain(w_init_gain)) 19 | 20 | 21 | class Conv1d(nn.Conv1d): 22 | def __init__(self, 23 | in_channels, 24 | out_channels, 25 | kernel_size, 26 | stride=1, 27 | padding=0, 28 | dilation=1, 29 | groups=1, 30 | bias=True, 31 | padding_mode='zeros', 32 | w_init_gain='linear'): 33 | super(Conv1d, self).__init__(in_channels, 34 | out_channels, 35 | kernel_size, 36 | stride, 37 | padding, 38 | dilation, 39 | groups, 40 | bias, 41 | padding_mode) 42 | nn.init.xavier_uniform_(self.weight, 43 | gain=nn.init.calculate_gain(w_init_gain)) 44 | 45 | 46 | -------------------------------------------------------------------------------- /modules/loss.py: -------------------------------------------------------------------------------- 1 | import torch, pdb 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from utils.utils import get_mask_from_lengths 6 | 7 | 8 | class TransformerLoss(nn.Module): 9 | def __init__(self): 10 | super(TransformerLoss, self).__init__() 11 | self.g = 0.2 # For guided attention loss 12 | 13 | def forward(self, pred, target, guide): 14 | mel_out, gate_out = pred 15 | mel_target, gate_target = target 16 | alignments, text_lengths, mel_lengths = guide 17 | 18 | mask = ~get_mask_from_lengths(mel_lengths) 19 | 20 | mel_target = mel_target.masked_select(mask.unsqueeze(1)) 21 | mel_out = mel_out.masked_select(mask.unsqueeze(1)) 22 | 23 | gate_target = gate_target.masked_select(mask) 24 | gate_out = gate_out.masked_select(mask) 
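        # With padding removed by the mel-length mask, the three terms below are:
        # L1 on mel frames, BCE-with-logits on the stop token, and a guided-attention
        # penalty (guide_loss) that discourages encoder-decoder attention mass far
        # from the diagonal alignment path in the last two decoder layers.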
25 | 26 | mel_loss = nn.L1Loss()(mel_out, mel_target) 27 | bce_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) 28 | guide_loss = self.guide_loss(alignments, text_lengths, mel_lengths) 29 | 30 | return mel_loss, bce_loss, guide_loss 31 | 32 | def guide_loss(self, alignments, text_lengths, mel_lengths): 33 | B, n_layers, n_heads, T, L = alignments.size() 34 | 35 | # B, T, L 36 | W = alignments.new_zeros(B, T, L) 37 | mask = alignments.new_zeros(B, T, L) 38 | 39 | for i, (t, l) in enumerate(zip(mel_lengths, text_lengths)): 40 | mel_seq = alignments.new_tensor( torch.arange(t).to(torch.float32).unsqueeze(-1).cuda()/t) 41 | text_seq = alignments.new_tensor( torch.arange(l).to(torch.float32).unsqueeze(0).cuda()/l) 42 | x = torch.pow(text_seq - mel_seq, 2) 43 | W[i, :t, :l] += alignments.new_tensor(1-torch.exp(-x/(2*(self.g**2)))) 44 | 45 | mask[i, :t, :l] = 1 46 | 47 | # Apply guided_loss to 1 heads of the last 2 layers 48 | applied_align = alignments[:, -2:, :1] 49 | losses = applied_align*(W.unsqueeze(1).unsqueeze(1)) 50 | 51 | return torch.mean(losses.masked_select(mask.unsqueeze(1).unsqueeze(1).to(torch.bool))) 52 | 53 | 54 | class EmotionLoss(nn.Module): 55 | def __init__(self): 56 | super().__init__() 57 | def forward(self, logit, emos): 58 | one_hot_name = torch.zeros((logit.size(0), logit.size(1))).cuda() 59 | 60 | # Data-specific name definition. You should change this codes for your own data structure 61 | for i in range(logit.size(0)): 62 | emo = emos[i].lower() 63 | if emo == 'h': 64 | one_hot_name[i][0]=1 65 | elif emo == 'a': 66 | one_hot_name[i][1]=1 67 | elif emo == 's': 68 | one_hot_name[i][2]=1 69 | elif emo == 'n': 70 | one_hot_name[i][3]=1 71 | 72 | loss = torch.sum(logit*one_hot_name, dim=1) 73 | threshold = 1e-5*torch.ones_like(loss).cuda() # For stability 74 | loss = torch.max(loss, threshold) 75 | loss = torch.mean(-torch.log(loss)) 76 | 77 | return loss 78 | 79 | 80 | -------------------------------------------------------------------------------- /modules/model.py: -------------------------------------------------------------------------------- 1 | import torch, pdb 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .init_layer import * 7 | from .transformer import * 8 | from utils.utils import get_mask_from_lengths, binarize_attention_parallel 9 | from modules.saln import StyleAdaptiveLayerNorm 10 | from .style import Emotion_encoder 11 | from .attention import ConvAttention 12 | 13 | 14 | class Prenet_D(nn.Module): 15 | '''Prenet of decoder''' 16 | def __init__(self, hp): 17 | super(Prenet_D, self).__init__() 18 | self.linear1 = Linear(hp.n_mel_channels, hp.dprenet_dim, w_init_gain='relu') 19 | self.linear2 = Linear(hp.dprenet_dim, hp.dprenet_dim, w_init_gain='relu') 20 | self.linear3 = Linear(hp.dprenet_dim, hp.hidden_dim) 21 | 22 | def forward(self, x): 23 | x = F.dropout(F.relu(self.linear1(x)), p=0.5, training=True) 24 | x = F.dropout(F.relu(self.linear2(x)), p=0.5, training=True) 25 | x = F.relu(self.linear3(x)) 26 | return x 27 | 28 | 29 | class Speaker_encoder(nn.Module): 30 | '''Speaker encoder based on Deep voice 3''' 31 | def __init__(self, hp): 32 | super(Speaker_encoder, self).__init__() 33 | 34 | self.hp = hp 35 | self.embedding = nn.Embedding(hp.num_spk, hp.spk_hidden_dim) 36 | self.linear = nn.Linear(hp.spk_hidden_dim, hp.hidden_dim) 37 | self.softsign = nn.Softsign() 38 | 39 | def forward(self, spk_id): 40 | embedding = self.embedding(spk_id) 41 | spk_emb = self.softsign(self.linear(embedding)) 42 | 
return spk_emb 43 | 44 | 45 | class Global_style_encoder(nn.Module): 46 | '''FC layer for combining spk & emo embeddings''' 47 | def __init__(self, hp): 48 | super(Global_style_encoder, self).__init__() 49 | self.hp = hp 50 | self.linear = nn.Linear(hp.hidden_dim*2, hp.hidden_dim) 51 | self.softsign = nn.Softsign() 52 | 53 | def forward(self, x): 54 | x = self.linear(x) 55 | x = self.softsign(x) 56 | 57 | return x 58 | 59 | 60 | class F0_predictor(nn.Module): 61 | '''F0 predictor''' 62 | def __init__(self, hp): 63 | super(F0_predictor, self).__init__() 64 | 65 | self.hp = hp 66 | self.conv_layers = nn.ModuleList([Conv1d(hp.hidden_dim, hp.hidden_dim, 67 | kernel_size=hp.ms_kernel, padding=(hp.ms_kernel-1)//2, w_init_gain='relu') 68 | for _ in range(hp.n_layers_lp_enc)]) 69 | self.saln_layers = nn.ModuleList([StyleAdaptiveLayerNorm(hp.hidden_dim, hp.hidden_dim) for _ in range(hp.n_layers_lp_enc)]) 70 | 71 | self.drop = nn.Dropout(0.1) 72 | self.linear = nn.Linear(hp.hidden_dim, 1) 73 | 74 | def forward(self, x, cond, mask=None): 75 | x = x.transpose(1, 2) # [B, 256, L] 76 | 77 | for i in range(self.hp.n_layers_lp_enc): 78 | x = F.relu(self.conv_layers[i](x).transpose(1,2)) # [B, L, 256] 79 | x = self.saln_layers[i](x, cond) 80 | x = self.drop(x).transpose(1,2) # [B, 256, L] 81 | 82 | out = self.linear(x.transpose(1,2)) # [B, L, 1] 83 | 84 | if mask is not None: 85 | out = out.masked_fill(mask.unsqueeze(2), 0.) 86 | 87 | return out 88 | 89 | 90 | class F0_encoder(nn.Module): 91 | '''F0 embedding''' 92 | def __init__(self, hp): 93 | super(F0_encoder, self).__init__() 94 | self.hp = hp 95 | self.conv = nn.Conv1d(1, hp.hidden_dim, kernel_size=hp.ms_kernel, padding=(hp.ms_kernel-1)//2) 96 | 97 | def forward(self, x): 98 | x = x.transpose(1,2) 99 | x = self.conv(x).transpose(1,2) 100 | 101 | return x 102 | 103 | 104 | class Multi_style_encoder(nn.Module): 105 | '''Combining local F0 embeddings and global style embeddings''' 106 | def __init__(self, hp): 107 | super(Multi_style_encoder, self).__init__() 108 | self.hp = hp 109 | self.conv1 = Conv1d(hp.hidden_dim*2, hp.hidden_dim, kernel_size=hp.ms_kernel, padding=(hp.ms_kernel-1)//2, w_init_gain='relu') 110 | self.conv2 = Conv1d(hp.hidden_dim, hp.hidden_dim, kernel_size=hp.ms_kernel, padding=(hp.ms_kernel-1)//2, w_init_gain='relu') 111 | self.norm = nn.LayerNorm(hp.hidden_dim) 112 | self.drop = nn.Dropout(0.1) 113 | self.linear = nn.Linear(hp.hidden_dim, hp.hidden_dim) 114 | self.softsign = nn.Softsign() 115 | 116 | def forward(self, x, mask=None): 117 | x = x.transpose(1,2) 118 | x = F.relu(self.conv1(x).transpose(1,2)) 119 | x = self.drop(self.norm(x)) 120 | 121 | x = F.relu(self.conv2(x.transpose(1,2)).transpose(1,2)) 122 | x = self.drop(self.norm(x)) 123 | x = self.softsign(self.linear(x)) 124 | 125 | if mask is not None: 126 | x = x.masked_fill(mask.unsqueeze(2), 0.) 
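        # x is now a phoneme-level multi-style embedding of shape [B, L, hidden_dim];
        # FluentTTS.outputs concatenates it with the text-encoder memory along the
        # channel axis so the decoder sees contextual and local-prosody features together.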
127 | 128 | return x 129 | 130 | 131 | class FluentTTS(nn.Module): 132 | '''FluentTTS''' 133 | def __init__(self, hp, mode): 134 | super(FluentTTS, self).__init__() 135 | self.hp = hp 136 | self.mode = mode 137 | 138 | # Text encoder 139 | self.Layernorm = nn.LayerNorm(hp.symbols_embedding_dim) 140 | self.Embedding = nn.Embedding(hp.n_symbols, hp.symbols_embedding_dim) 141 | self.alpha1 = nn.Parameter(torch.ones(1)) 142 | self.register_buffer('pe', PositionalEncoding(hp.hidden_dim).pe) 143 | 144 | self.Text_encoder = nn.ModuleList([TransformerEncoderLayer(d_model=hp.hidden_dim, nhead=hp.n_heads, 145 | ff_dim=hp.ff_dim) for _ in range(hp.n_layers)]) 146 | 147 | # Global style encoder 148 | self.Spk_encoder = Speaker_encoder(hp) 149 | self.Emo_encoder = Emotion_encoder(hp) 150 | self.Global_style_encoder = Global_style_encoder(hp) 151 | 152 | # Multi-style generation 153 | if self.mode == 'prop': 154 | self.Internal_aligner = ConvAttention(hp.n_mel_channels, hp.hidden_dim) 155 | self.F0_predictor = F0_predictor(hp) 156 | self.F0_encoder = F0_encoder(hp) 157 | self.Multi_style_encoder = Multi_style_encoder(hp) 158 | 159 | # Mel decoder 160 | self.Prenet_D = Prenet_D(hp) 161 | self.alpha2 = nn.Parameter(torch.ones(1)) 162 | self.register_buffer('pe_d', PositionalEncoding(hp.hidden_dim*2).pe) 163 | 164 | self.Decoder = nn.ModuleList([TransformerDecoderLayer(d_model=hp.hidden_dim*2, nhead=hp.n_heads, 165 | ff_dim=hp.ff_dim) for _ in range(hp.n_layers)]) 166 | 167 | self.Projection = nn.Linear(hp.hidden_dim*2, hp.n_mel_channels) 168 | self.Stop = nn.Linear(hp.n_mel_channels, 1) 169 | self.sigmoid = nn.Sigmoid() 170 | 171 | def outputs(self, text, melspec, text_lengths, mel_lengths, spk, emo, f0, prior, iteration): 172 | # Input data size 173 | B, L, T = text.size(0), text.size(1), melspec.size(2) 174 | 175 | # Speaker embedding (Deep voice 3) 176 | spk_embedding = self.Spk_encoder(spk).unsqueeze(0) # [1, B, 256] 177 | 178 | # Emotion embedding (Reference encoder) 179 | emo_embedding, emo_logit = self.Emo_encoder(melspec, logit=True) 180 | emo_embedding = emo_embedding.transpose(0,1) # [1, B, 256] 181 | 182 | # Style embedding (FC layer) 183 | style_embedding = torch.cat((spk_embedding, emo_embedding), dim=2) # [1, B, 512] 184 | style_embedding = self.Global_style_encoder(style_embedding) # [1, B, 256] 185 | 186 | # Text encoder input 187 | encoder_input = self.Layernorm(self.Embedding(text).transpose(0,1)) 188 | encoder_input = encoder_input + self.alpha1*(self.pe[:L].unsqueeze(1)) 189 | 190 | # Mel decoder input 191 | mel_input = F.pad(melspec, (1,-1)).transpose(1,2) # [B, T, 80] 192 | decoder_input = self.Prenet_D(mel_input).transpose(0,1) # [T, B, 256] 193 | 194 | # Masks 195 | text_mask = get_mask_from_lengths(text_lengths) 196 | mel_mask = get_mask_from_lengths(mel_lengths) 197 | diag_mask = torch.triu(melspec.new_ones(T,T)).transpose(0, 1) 198 | diag_mask[diag_mask == 0] = -float('inf') 199 | diag_mask[diag_mask == 1] = 0 200 | 201 | # Text encoder 202 | memory = encoder_input 203 | enc_alignments = [] 204 | for layer in self.Text_encoder: 205 | memory, enc_align = layer(memory, src_key_padding_mask=text_mask) # [L,B,256] 206 | enc_alignments.append(enc_align.unsqueeze(1)) 207 | enc_alignments = torch.cat(enc_alignments, 1) 208 | 209 | # Internal aligner 210 | if self.mode == 'prop': 211 | soft_A, attn_logprob = self.Internal_aligner(mel_input.transpose(1,2), memory.permute(1,2,0), mel_lengths, 212 | text_mask.unsqueeze(-1), prior.transpose(1,2), 213 | 
style_embedding.detach().transpose(0,1)) # [B, 1, T, L] [B, L, T] 214 | hard_A = binarize_attention_parallel(soft_A, text_lengths, mel_lengths) 215 | else: 216 | soft_A, hard_A, attn_logprob = torch.zeros(B, 1, T, L).cuda(), torch.zeros(B, 1, T, L).cuda(), torch.zeros(B, L, T).cuda() 217 | 218 | # Multi-style generation 219 | if iteration > self.hp.local_style_step and self.mode == 'prop': 220 | # Phoneme-level target F0 221 | aligned_f0 = torch.bmm(hard_A.squeeze().transpose(1,2), f0.unsqueeze(2)) # [B, L, 1] 222 | nonzero = torch.count_nonzero(hard_A.squeeze().transpose(1,2), dim=2) # [B, L] 223 | aligned_f0 = torch.div(aligned_f0.squeeze(2), nonzero).nan_to_num().unsqueeze(2) # [B, L, 1] 224 | 225 | # Phoneme-level predicted F0 226 | pred_f0 = self.F0_predictor(memory.detach().transpose(0,1), 227 | style_embedding.detach().transpose(0,1), text_mask) # [B, L, 1] 228 | 229 | f0_emb = self.F0_encoder(aligned_f0) 230 | 231 | expand_style_enc = style_embedding.expand(memory.size(0), -1, -1).transpose(0,1) # [B, L, 256] 232 | local_style_emb = torch.cat((expand_style_enc, f0_emb), dim=2) # [B, L, 512] 233 | local_style_emb = self.Multi_style_encoder(local_style_emb).transpose(0,1) # [L, B, 256] 234 | 235 | memory = torch.cat((memory, local_style_emb), dim=2) 236 | 237 | # For initial training part and Baseline 238 | else: 239 | expand_style_enc = style_embedding.expand(memory.size(0), -1, -1) # [L, B, 256] 240 | memory = torch.cat((memory, expand_style_enc), dim=2) # [L, B, 512] 241 | 242 | # Mel decoder 243 | expand_style_dec = style_embedding.expand(decoder_input.size(0), -1, -1) # [T, B, 256] 244 | tgt = torch.cat((decoder_input, expand_style_dec), dim=2) + self.alpha2*(self.pe_d[:T].unsqueeze(1)) # [T, B, 512] 245 | 246 | dec_alignments, enc_dec_alignments = [], [] 247 | for layer in self.Decoder: 248 | tgt, dec_align, enc_dec_align = layer(tgt, 249 | memory, 250 | tgt_mask=diag_mask, 251 | tgt_key_padding_mask=mel_mask, 252 | memory_key_padding_mask=text_mask) 253 | dec_alignments.append(dec_align.unsqueeze(1)) 254 | enc_dec_alignments.append(enc_dec_align.unsqueeze(1)) 255 | dec_alignments = torch.cat(dec_alignments, 1) 256 | enc_dec_alignments = torch.cat(enc_dec_alignments, 1) 257 | 258 | # Projection + Stop token 259 | mel_out = self.Projection(tgt.transpose(0, 1)).transpose(1, 2) 260 | gate_out = self.Stop(mel_out.transpose(1, 2)).squeeze(-1) 261 | 262 | # Return 263 | if iteration > self.hp.local_style_step and self.mode == 'prop': 264 | return mel_out, enc_alignments, dec_alignments, enc_dec_alignments, gate_out, \ 265 | soft_A, hard_A, attn_logprob, emo_logit, aligned_f0, pred_f0 266 | else: 267 | return mel_out, enc_alignments, dec_alignments, enc_dec_alignments, gate_out, \ 268 | soft_A, hard_A, attn_logprob, emo_logit 269 | 270 | 271 | def forward(self, text, melspec, gate, f0, prior, text_lengths, mel_lengths, 272 | criterion, criterion_emo, criterion_ctc, criterion_bin, criterion_f0, 273 | spk, emo, iteration, valid=None): 274 | # Input data 275 | text = text[:,:text_lengths.max().item()] 276 | melspec = melspec[:,:,:mel_lengths.max().item()] 277 | gate = gate[:, :mel_lengths.max().item()] 278 | f0 = f0[:, :mel_lengths.max().item()] 279 | prior = prior[:, :text_lengths.max().item(), :mel_lengths.max().item()] 280 | 281 | # Model outputs 282 | outputs = self.outputs(text, melspec, text_lengths, mel_lengths, spk, emo, f0, prior, iteration) 283 | 284 | # Parse 285 | mel_out = outputs[0] 286 | enc_dec_alignments = outputs[3] 287 | gate_out = outputs[4] 288 | soft_A, hard_A, 
attn_logprob = outputs[5], outputs[6], outputs[7] 289 | emo_logit = outputs[8] 290 | 291 | # TTS loss 292 | mel_loss, bce_loss, guide_loss = criterion((mel_out, gate_out), 293 | (melspec, gate), 294 | (enc_dec_alignments, text_lengths, mel_lengths)) 295 | 296 | # Internal aligner loss 297 | if self.mode == 'prop': 298 | ctc_loss = criterion_ctc(attn_logprob, text_lengths, mel_lengths) 299 | 300 | if iteration < self.hp.bin_loss_enable_steps: 301 | bin_loss_weight = 0. 302 | else: 303 | bin_loss_weight = self.hp.kl_scale 304 | bin_loss = criterion_bin(hard_A, soft_A) * bin_loss_weight 305 | else: 306 | ctc_loss, bin_loss = torch.FloatTensor([0]).cuda(), torch.FloatTensor([0]).cuda() 307 | 308 | # Emotion classification loss 309 | emo_loss = criterion_emo(emo_logit.squeeze(1), emo) 310 | emo_loss = emo_loss * self.hp.emo_scale 311 | 312 | # F0 loss 313 | if iteration > self.hp.local_style_step and self.mode == 'prop': 314 | aligned_f0, pred_f0 = outputs[9], outputs[10] 315 | f0_loss = criterion_f0(pred_f0, aligned_f0) 316 | f0_loss = f0_loss * self.hp.f0_scale 317 | if valid: 318 | return mel_loss, bce_loss, guide_loss, ctc_loss, bin_loss, f0_loss, emo_loss, outputs, mel_lengths 319 | else: 320 | return mel_loss, bce_loss, guide_loss, ctc_loss, bin_loss, f0_loss, emo_loss 321 | else: 322 | if valid: 323 | return mel_loss, bce_loss, guide_loss, ctc_loss, bin_loss, emo_loss, outputs, mel_lengths 324 | else: 325 | return mel_loss, bce_loss, guide_loss, ctc_loss, bin_loss, emo_loss 326 | 327 | 328 | def inference(self, text, emo_embedding, spk_id, f0_mean, f0_std, 329 | max_len=512, mode=None, start=None, end=None, hz=None): 330 | # Input data size 331 | (B, L), T = text.size(), max_len 332 | 333 | # Speaker embedding 334 | spk_embedding = self.Spk_encoder(spk_id).unsqueeze(0) # [1, 1, 256] 335 | 336 | # Emotion embedding (from reference mel or mean style embedding) 337 | emo_embedding = emo_embedding 338 | 339 | # Style embedding 340 | style_embedding = torch.cat((spk_embedding, emo_embedding), dim=2) # [1, 1, 512] 341 | style_embedding = self.Global_style_encoder(style_embedding) # [1, 1, 256] 342 | 343 | # Text encoder input 344 | encoder_input = self.Layernorm(self.Embedding(text).transpose(0,1)) 345 | encoder_input = encoder_input + self.alpha1*(self.pe[:L].unsqueeze(1)) 346 | 347 | # Masks 348 | text_mask = text.new_zeros(1, L).to(torch.bool) 349 | mel_mask = text.new_zeros(1, T).to(torch.bool) 350 | diag_mask = torch.triu(text.new_ones(T, T)).transpose(0, 1).contiguous() 351 | diag_mask[diag_mask == 0] = -1e9 352 | diag_mask[diag_mask == 1] = 0 353 | diag_mask = diag_mask.float() 354 | 355 | # Text encoder 356 | memory = encoder_input 357 | enc_alignments = [] 358 | for layer in self.Text_encoder: 359 | memory, enc_align = layer(memory, src_key_padding_mask=text_mask) # [L, 1, 256] 360 | enc_alignments.append(enc_align) 361 | enc_alignments = torch.cat(enc_alignments, dim=0) 362 | 363 | if mode == 'prop': 364 | # Multi-style generation 365 | pred_f0 = self.F0_predictor(memory.transpose(0,1), style_embedding.transpose(0,1), text_mask) # [1, L, 1] 366 | 367 | # Dynamic-level F0 control 368 | # Case 1. Word or phoneme-level 369 | if start is not None: 370 | print(f'Word/Phoneme-level F0 control | hz = {hz}') 371 | norm_hz = hz / f0_std 372 | pred_f0[:, start:end] = pred_f0[:, start:end] + norm_hz 373 | # Case 2. 
Utterance-level 374 | elif hz is not None: 375 | print(f'Utterance-level, F0 shift | hz = {hz}') 376 | norm_hz = hz / f0_std 377 | pred_f0 = pred_f0 + norm_hz 378 | 379 | f0_emb = self.F0_encoder(pred_f0) # [1, L, 256] 380 | 381 | expand_style_enc = style_embedding.expand(memory.size(0), -1, -1).transpose(0,1) # [1, L, 256] 382 | local_style_emb = torch.cat((expand_style_enc, f0_emb), dim=2) # [1, L, 512] 383 | local_style_emb = self.Multi_style_encoder(local_style_emb).transpose(0,1) # [L, 1, 256] 384 | 385 | memory = torch.cat((memory, local_style_emb), dim=2) # [L, 1, 512] 386 | 387 | else: 388 | expand_style_enc = style_embedding.expand(memory.size(0), -1, -1) # [1, L, 256] 389 | memory = torch.cat((memory, expand_style_enc), dim=2) # [L, 1, 512] 390 | 391 | # Decoder inputs 392 | mel_input = text.new_zeros(1, 393 | self.hp.n_mel_channels, 394 | max_len).to(torch.float32) 395 | dec_alignments = text.new_zeros(self.hp.n_layers, 396 | self.hp.n_heads, 397 | max_len, 398 | max_len).to(torch.float32) 399 | enc_dec_alignments = text.new_zeros(self.hp.n_layers, 400 | self.hp.n_heads, 401 | max_len, 402 | text.size(1)).to(torch.float32) 403 | 404 | # Autoregressive generation 405 | stop = [] # Stop token 406 | 407 | for i in range(max_len): 408 | # Preparation 409 | decoder_input = self.Prenet_D(mel_input.transpose(1,2).contiguous()).transpose(0,1).contiguous() 410 | expand_style_dec = style_embedding.expand(decoder_input.size(0), -1, -1) 411 | tgt = torch.cat((decoder_input, expand_style_dec), dim=2) + self.alpha2*(self.pe_d[:T].unsqueeze(1)) 412 | 413 | # Decoder 414 | for j, layer in enumerate(self.Decoder): 415 | tgt, dec_align, enc_dec_align = layer(tgt, 416 | memory, 417 | tgt_mask=diag_mask, 418 | tgt_key_padding_mask=mel_mask, 419 | memory_key_padding_mask=text_mask) 420 | 421 | dec_alignments[j, :, i] = dec_align[0, :, i] 422 | enc_dec_alignments[j, :, i] = enc_dec_align[0, :, i] 423 | 424 | # Outputs 425 | mel_out = self.Projection(tgt.transpose(0,1).contiguous()) 426 | stop.append(torch.sigmoid(self.Stop(mel_out[:,i]))[0,0].item()) 427 | 428 | # Store generated frame 429 | if i < max_len - 1: 430 | mel_input[0, :, i+1] = mel_out[0, i] # [1,80,1024] 431 | 432 | # Break point 433 | if stop[-1]>0.5: 434 | break 435 | 436 | return mel_out.transpose(1,2), enc_alignments, dec_alignments, enc_dec_alignments, stop 437 | 438 | 439 | -------------------------------------------------------------------------------- /modules/saln.py: -------------------------------------------------------------------------------- 1 | import torch, pdb 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class AffineLinear(nn.Module): 6 | def __init__(self, in_dim, out_dim): 7 | super(AffineLinear, self).__init__() 8 | affine = nn.Linear(in_dim, out_dim) 9 | self.affine = affine 10 | 11 | def forward(self, input_data): 12 | return self.affine(input_data) 13 | 14 | class StyleAdaptiveLayerNorm(nn.Module): 15 | def __init__(self, in_channel, style_dim): 16 | super(StyleAdaptiveLayerNorm, self).__init__() 17 | self.in_channel = in_channel 18 | self.norm = nn.LayerNorm(in_channel, elementwise_affine=False) 19 | 20 | self.style = AffineLinear(style_dim, in_channel * 2) 21 | self.style.affine.bias.data[:in_channel] = 1 22 | self.style.affine.bias.data[in_channel:] = 0 23 | 24 | def forward(self, input_data, style_code): 25 | style = self.style(style_code) 26 | 27 | gamma, beta = style.chunk(2, dim=-1) 28 | 29 | out = self.norm(input_data) 30 | out = gamma * out + beta 31 | 32 | return out 33 | 34 | 
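The style-adaptive layer normalization above is how the F0 predictor in modules/model.py injects the utterance-level style embedding into its convolutional stack: a single linear layer maps the style code to per-channel gamma and beta (initialized to 1 and 0 via the bias values set in the constructor), which then scale and shift the normalized features. A minimal shape-check sketch follows; it assumes hidden_dim = style_dim = 256 (as suggested by the shape comments in modules/model.py), and the batch size and phoneme count are arbitrary illustration values.

```
import torch
from modules.saln import StyleAdaptiveLayerNorm

# Assumed sizes: hidden_dim = style_dim = 256; 8 utterances, 40 phonemes each.
saln = StyleAdaptiveLayerNorm(in_channel=256, style_dim=256)

features = torch.randn(8, 40, 256)  # [B, L, hidden_dim] phoneme-level features
style    = torch.randn(8, 1, 256)   # [B, 1, style_dim] global style embedding

out = saln(features, style)         # gamma/beta predicted from style, broadcast over L
print(out.shape)                    # torch.Size([8, 40, 256])
```

This mirrors the call pattern inside F0_predictor.forward, where cond is the detached global style embedding of shape [B, 1, 256] and x is the [B, L, 256] text-encoder output.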
-------------------------------------------------------------------------------- /modules/style.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import pdb 7 | import random 8 | from utils.utils import * 9 | 10 | 11 | class Emotion_encoder(nn.Module): 12 | def __init__(self,hparams): 13 | super().__init__() 14 | self.encoder = ReferenceEncoder(hparams) 15 | 16 | def forward(self, inputs, logit=None): 17 | emo_embed, emo_logit = self.encoder(inputs) 18 | 19 | if logit: 20 | return emo_embed, emo_logit 21 | else: 22 | return emo_embed 23 | 24 | 25 | class ReferenceEncoder(nn.Module): 26 | ''' 27 | inputs --- [N, Ty/r, n_mels*r] mels 28 | outputs --- [N, ref_enc_gru_size] 29 | ''' 30 | def __init__(self,hparams): 31 | super().__init__() 32 | self.ref_enc_filters = hparams.ref_enc_filters 33 | self.n_mel_channels = hparams.n_mel_channels 34 | self.E = hparams.E 35 | K = len(self.ref_enc_filters) 36 | filters = [1] + self.ref_enc_filters 37 | convs = [nn.Conv2d(in_channels=filters[i], 38 | out_channels=filters[i + 1], 39 | kernel_size=(3, 3), 40 | stride=(2, 2), 41 | padding=(1, 1)) for i in range(K)] 42 | self.convs = nn.ModuleList(convs) 43 | self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=self.ref_enc_filters[i]) for i in range(K)]) 44 | 45 | out_channels = self.calculate_channels(self.n_mel_channels, 3, 2, 1, K) 46 | self.gru = nn.GRU(input_size=self.ref_enc_filters[-1] * out_channels, 47 | hidden_size=self.E // 2, 48 | batch_first=True) 49 | self.out_fc = nn.Linear(self.E//2, self.E) 50 | self.softsign = nn.Softsign() 51 | self.emo_logit_extractor = torch.nn.Sequential(torch.nn.Linear(self.E, self.E//2), 52 | torch.nn.ReLU(), 53 | torch.nn.Linear(self.E//2, hparams.num_emo), 54 | torch.nn.Softmax(dim=-1)) 55 | 56 | def forward(self, inputs): 57 | N = inputs.size(0) 58 | out = inputs.view(N, 1, -1, self.n_mel_channels) # [N, 1, Ty, n_mels] 59 | for conv, bn in zip(self.convs, self.bns): 60 | out = conv(out) 61 | out = bn(out) 62 | out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] 63 | 64 | out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] 65 | T = out.size(1) 66 | N = out.size(0) 67 | out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] 68 | 69 | self.gru.flatten_parameters() 70 | memory, out = self.gru(out) # out --- [1, N, E//2] 71 | out = out.squeeze(0) 72 | out = self.softsign(self.out_fc(out.unsqueeze(1))) 73 | emo_logit = self.emo_logit_extractor(out) 74 | return out, emo_logit 75 | 76 | def calculate_channels(self, L, kernel_size, stride, pad, n_convs): 77 | for i in range(n_convs): 78 | L = (L - kernel_size + 2 * pad) // stride + 1 79 | return L 80 | 81 | 82 | -------------------------------------------------------------------------------- /modules/transformer.py: -------------------------------------------------------------------------------- 1 | import torch, pdb 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .init_layer import * 6 | 7 | 8 | class TransformerEncoderLayer(nn.Module): 9 | def __init__(self, 10 | d_model, nhead, ff_dim, 11 | dropout=0.1, 12 | activation="relu"): 13 | super(TransformerEncoderLayer, self).__init__() 14 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 15 | 16 | self.ffn = nn.Sequential( 17 | Linear(d_model, ff_dim, w_init_gain='relu'), 18 | nn.ReLU(), 19 | nn.Dropout(dropout), 20 | Linear(ff_dim, 
d_model) 21 | ) 22 | 23 | self.norm1 = nn.LayerNorm(d_model) 24 | self.norm2 = nn.LayerNorm(d_model) 25 | 26 | self.dropout = nn.Dropout(dropout) 27 | 28 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 29 | # Self attention 30 | slf_attn_out, enc_align = self.self_attn(src, 31 | src, 32 | src, 33 | attn_mask=src_mask, 34 | key_padding_mask=src_key_padding_mask) 35 | # Add & Norm 36 | src = src + self.dropout(slf_attn_out) 37 | src = self.norm1(src) 38 | 39 | # FFN 40 | ffn_out = self.ffn(src) 41 | 42 | # Add & Norm 43 | src = src + self.dropout(ffn_out) 44 | src = self.norm2(src) 45 | 46 | return src, enc_align 47 | 48 | 49 | class TransformerDecoderLayer(nn.Module): 50 | def __init__(self, 51 | d_model, nhead, ff_dim, 52 | dropout=0.1, 53 | activation="relu"): 54 | super(TransformerDecoderLayer, self).__init__() 55 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 56 | self.cros_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 57 | 58 | self.ffn = nn.Sequential( 59 | Linear(d_model, ff_dim, w_init_gain='relu'), 60 | nn.ReLU(), 61 | nn.Dropout(dropout), 62 | Linear(ff_dim, d_model) 63 | ) 64 | 65 | self.norm1 = nn.LayerNorm(d_model) 66 | self.norm2 = nn.LayerNorm(d_model) 67 | self.norm3 = nn.LayerNorm(d_model) 68 | 69 | self.dropout = nn.Dropout(dropout) 70 | 71 | def forward(self, 72 | tgt, memory, 73 | tgt_mask=None, memory_mask=None, 74 | tgt_key_padding_mask=None, memory_key_padding_mask=None): 75 | # Self attention 76 | slf_attn_out, dec_align = self.self_attn(tgt, 77 | tgt, 78 | tgt, 79 | attn_mask=tgt_mask, 80 | key_padding_mask=tgt_key_padding_mask) 81 | # Add & Norm 82 | tgt = tgt + self.dropout(slf_attn_out) 83 | tgt = self.norm1(tgt) 84 | 85 | # Cross attention 86 | crs_attn_out, enc_dec_align = self.cros_attn(tgt, 87 | memory, 88 | memory, 89 | attn_mask=memory_mask, 90 | key_padding_mask=memory_key_padding_mask) 91 | # Add & Norm 92 | tgt = tgt + self.dropout(crs_attn_out) 93 | tgt = self.norm2(tgt) 94 | 95 | # FFN 96 | ffn_out = self.ffn(tgt) 97 | 98 | # Add & Norm 99 | tgt = tgt + self.dropout(ffn_out) 100 | tgt = self.norm3(tgt) 101 | 102 | return tgt, dec_align, enc_dec_align 103 | 104 | 105 | class PositionalEncoding(nn.Module): 106 | def __init__(self, d_model, max_len=5000): 107 | super(PositionalEncoding, self).__init__() 108 | self.register_buffer('pe', self._get_pe_matrix(d_model, max_len)) 109 | 110 | def forward(self, x): 111 | return x + self.pe[:x.size(0)].unsqueeze(1) 112 | 113 | def _get_pe_matrix(self, d_model, max_len): 114 | pe = torch.zeros(max_len, d_model) 115 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 116 | div_term = torch.pow(10000, torch.arange(0, d_model, 2).float() / d_model) 117 | 118 | pe[:, 0::2] = torch.sin(position / div_term) 119 | pe[:, 1::2] = torch.cos(position / div_term) 120 | 121 | return pe 122 | 123 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os, librosa, torch, pdb, sys 2 | import numpy as np 3 | import pyworld as pw 4 | from statistics import mean 5 | from tqdm import tqdm 6 | from scipy.stats import betabinom 7 | 8 | import hparams 9 | from text import * 10 | from text.cleaners import basic_cleaners 11 | from text.symbols import symbols 12 | from layers import TacotronSTFT 13 | from utils.data_utils import process_meta, create_id_table 14 | 15 | stft = TacotronSTFT() 16 | 17 | ### Mappings from symbol to 
numeric ID and vice versa: 18 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 19 | id_to_symbol = {i: s for i, s in enumerate(symbols)} 20 | 21 | ### Prepare data path 22 | path = 'Data path to save preprocessed files' 23 | for fname in ('mels', 'texts', 'f0', 'mean_std', 'alignment_priors'): 24 | os.makedirs(os.path.join(path, fname), exist_ok=True) 25 | 26 | file_path = 'Your filelist path' 27 | mean_std_txt = os.path.join(os.path.join(path, 'mean_std'), 'mean_std.txt') 28 | 29 | ### Save filelists 30 | metadata={} 31 | 32 | with open(file_path, 'r') as fid: 33 | for line in fid.readlines(): 34 | # Data-specific name definition. You should change this codes for your own data structure 35 | wav_path, text, spk = line.strip('\n').split("|") 36 | emo = wav_path.split('/')[-1][0] # Ex) 'a' 37 | 38 | clean_char = basic_cleaners(text.rstrip()) 39 | 40 | metadata[wav_path] = {'phone':clean_char, 'spk': spk, 'emo': emo} 41 | 42 | ### Define functions 43 | def text2seq(text): 44 | sequence=[symbol_to_id['^']] 45 | sequence.extend(text_to_sequence(text, hparams.text_cleaners)) 46 | sequence.append(symbol_to_id['~']) 47 | return sequence 48 | 49 | def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0): 50 | P, M = phoneme_count, mel_count 51 | x = np.arange(0, P) 52 | mel_text_probs = [] 53 | for i in range(1, M+1): 54 | a, b = scaling_factor*i, scaling_factor*(M+1-i) 55 | rv = betabinom(P, a, b) 56 | mel_i_prob = rv.pmf(x) 57 | mel_text_probs.append(mel_i_prob) 58 | return np.array(mel_text_probs) 59 | 60 | def get_mel(filename): 61 | wav, sr = librosa.load(filename, sr=hparams.sampling_rate) 62 | wav = librosa.effects.trim(wav, top_db=23, frame_length=1024, hop_length=256)[0] 63 | wav_32 = wav.astype(np.float32) 64 | wav = torch.FloatTensor(wav.astype(np.float32)) 65 | melspec, _ = stft.mel_spectrogram(wav.unsqueeze(0)) 66 | return melspec.squeeze(0), wav, wav_32 67 | 68 | def get_wav(filename): 69 | wav, sr = librosa.load(filename, sr=hparams.sampling_rate) 70 | wav = librosa.effects.trim(wav, top_db=23, frame_length=1024, hop_length=256)[0] 71 | wav_32 = wav.astype(np.float32) 72 | return wav_32 73 | 74 | def compute_mean_f0(fname, current_spk, current_emo, mean_list, std_list): 75 | spk_id = metadata[fname]['spk'] 76 | emo_id = metadata[fname]['emo'] 77 | 78 | if spk_id == current_spk and emo_id == current_emo: 79 | wav_32 = get_wav(fname) 80 | 81 | f0, _ = pw.harvest(wav_32.astype(np.float64), hparams.sampling_rate, frame_period=hparams.hop_length/hparams.sampling_rate*1000) 82 | 83 | nonzero_f0 = np.array([x for x in f0 if x!=0]) # Collect only voiced region 84 | 85 | mean_f0 = nonzero_f0.mean() 86 | std_f0 = nonzero_f0.std() 87 | 88 | mean_list.append(mean_f0) 89 | std_list.append(std_f0) 90 | 91 | return mean_list, std_list 92 | 93 | def get_norm_f0(f0, mean, std): 94 | out = [(x-mean)/std if x!=0 else x for x in f0] 95 | return out 96 | 97 | def save_file(fname): 98 | phone_seq = torch.LongTensor(text2seq(metadata[fname]['phone'])) 99 | spk_id = metadata[fname]['spk'] 100 | emo_id = metadata[fname]['emo'] 101 | 102 | melspec, wav, wav_32 = get_mel(fname) 103 | 104 | f0, _ = pw.harvest(wav_32.astype(np.float64), hparams.sampling_rate, frame_period=hparams.hop_length/hparams.sampling_rate*1000) 105 | 106 | name_mean = 'mean_' + spk_id + '_' + emo_id 107 | name_std = 'std_' + spk_id + '_' + emo_id 108 | 109 | mean_f0 = np.load(os.path.join(os.path.join(path, 'mean_std'), name_mean + '.npy')) 110 | std_f0 = np.load(os.path.join(os.path.join(path, 
'mean_std'), name_std + '.npy')) 111 | 112 | norm_f0 = get_norm_f0(f0, mean_f0, std_f0) 113 | 114 | attn_prior = beta_binomial_prior_distribution(len(phone_seq), melspec.size(1), 1) 115 | 116 | wav_name = fname.split('/')[-1][:-4] 117 | np.save(os.path.join(os.path.join(path, 'mels'), wav_name), melspec) 118 | np.save(os.path.join(os.path.join(path, 'texts'), wav_name), phone_seq) 119 | np.save(os.path.join(os.path.join(path, 'f0'), wav_name), norm_f0) 120 | np.save(os.path.join(os.path.join(path, 'alignment_priors'), wav_name), attn_prior) 121 | 122 | return wav_name 123 | 124 | ##### Preprocessing Start ##### 125 | ### Search the number of speakers and emotions ### 126 | name, speaker, emotion = process_meta(file_path) 127 | sid_dict = create_id_table(speaker) 128 | eid_dict = create_id_table(emotion) 129 | 130 | spk_label = [key for key in sid_dict] 131 | emo_label = [key for key in eid_dict] 132 | print(spk_label) 133 | print(emo_label) 134 | 135 | ### Compute mean and std of f0 ### 136 | dist_save_txt = open(mean_std_txt, 'w') 137 | 138 | for spk in spk_label: 139 | for emo in emo_label: 140 | mean_list = [] 141 | std_list = [] 142 | for filepath in tqdm(metadata.keys(), desc=f'{spk}|{emo}'): 143 | mean_list, std_list = compute_mean_f0(filepath, spk, emo, mean_list, std_list) 144 | 145 | mean_list, std_list = np.array([mean_list]), np.array([std_list]) 146 | mean, std = mean_list.mean(), std_list.mean() 147 | std_of_mean = mean_list.std() 148 | 149 | name_mean = 'mean_' + spk + '_' + emo 150 | name_std = 'std_' + spk + '_' + emo 151 | 152 | np.save(os.path.join(os.path.join(path, 'mean_std'), name_mean), mean) 153 | np.save(os.path.join(os.path.join(path, 'mean_std'), name_std), std) 154 | 155 | print(f'{spk} - {emo}: mean={mean:.2f}, std={std:.2f}, std of mean={std_of_mean:.2f}') 156 | dist_save_txt.write(f'{spk} - {emo}: mean={mean:.2f}, std={std:.2f}, std of mean={std_of_mean:.2f}\n') 157 | 158 | dist_save_txt.close() 159 | 160 | ### Prepare Data with Normalized F0 ### 161 | for filepath in tqdm(metadata.keys(), desc='Data preprocessing'): 162 | _ = save_file(filepath) 163 | 164 | print('Data preprocessing done!!!') 165 | 166 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python==3.6.9 2 | pytorch==1.10.2 3 | scipy==1.5.2 4 | numpy==1.19.5 5 | librosa==0.8.1 6 | pyworld==0.3.0 7 | tqdm==4.62.3 -------------------------------------------------------------------------------- /stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) 2017, Prem Seetharaman 5 | All rights reserved. 6 | 7 | * Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, this 14 | list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from this 19 | software without specific prior written permission. 
20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | """ 32 | 33 | import torch 34 | import numpy as np 35 | import torch.nn.functional as F 36 | from torch.autograd import Variable 37 | from scipy.signal import get_window 38 | from librosa.util import pad_center, tiny 39 | from audio_processing import window_sumsquare 40 | 41 | 42 | class STFT(torch.nn.Module): 43 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 44 | def __init__(self, filter_length=800, hop_length=200, win_length=800, 45 | window='hann'): 46 | super(STFT, self).__init__() 47 | self.filter_length = filter_length 48 | self.hop_length = hop_length 49 | self.win_length = win_length 50 | self.window = window 51 | self.forward_transform = None 52 | scale = self.filter_length / self.hop_length 53 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 54 | 55 | cutoff = int((self.filter_length / 2 + 1)) 56 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 57 | np.imag(fourier_basis[:cutoff, :])]) 58 | 59 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 60 | inverse_basis = torch.FloatTensor( 61 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 62 | 63 | if window is not None: 64 | assert(filter_length >= win_length) 65 | # get window and zero center pad it to filter_length 66 | fft_window = get_window(window, win_length, fftbins=True) 67 | fft_window = pad_center(fft_window, filter_length) 68 | fft_window = torch.from_numpy(fft_window).float() 69 | 70 | # window the bases 71 | forward_basis *= fft_window 72 | inverse_basis *= fft_window 73 | 74 | self.register_buffer('forward_basis', forward_basis.float()) 75 | self.register_buffer('inverse_basis', inverse_basis.float()) 76 | 77 | def transform(self, input_data): 78 | num_batches = input_data.size(0) 79 | num_samples = input_data.size(1) 80 | 81 | self.num_samples = num_samples 82 | 83 | # similar to librosa, reflect-pad the input 84 | input_data = input_data.view(num_batches, 1, num_samples) 85 | input_data = F.pad( 86 | input_data.unsqueeze(1), 87 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 88 | mode='reflect') 89 | input_data = input_data.squeeze(1) 90 | 91 | forward_transform = F.conv1d( 92 | input_data, 93 | Variable(self.forward_basis, requires_grad=False), 94 | stride=self.hop_length, 95 | padding=0) 96 | 97 | cutoff = int((self.filter_length / 2) + 1) 98 | real_part = forward_transform[:, :cutoff, :] 99 | imag_part = forward_transform[:, cutoff:, :] 100 | 101 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 102 | phase = torch.autograd.Variable( 103 | torch.atan2(imag_part.data, real_part.data)) 104 | 105 | return magnitude, phase 106 | 107 | def inverse(self, magnitude, phase): 108 | recombine_magnitude_phase = torch.cat( 109 | 
[magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 110 | 111 | inverse_transform = F.conv_transpose1d( 112 | recombine_magnitude_phase, 113 | Variable(self.inverse_basis, requires_grad=False), 114 | stride=self.hop_length, 115 | padding=0) 116 | 117 | if self.window is not None: 118 | window_sum = window_sumsquare( 119 | self.window, magnitude.size(-1), hop_length=self.hop_length, 120 | win_length=self.win_length, n_fft=self.filter_length, 121 | dtype=np.float32) 122 | # remove modulation effects 123 | approx_nonzero_indices = torch.from_numpy( 124 | np.where(window_sum > tiny(window_sum))[0]) 125 | window_sum = torch.autograd.Variable( 126 | torch.from_numpy(window_sum), requires_grad=False) 127 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 128 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 129 | 130 | # scale by hop ratio 131 | inverse_transform *= float(self.filter_length) / self.hop_length 132 | 133 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 134 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] 135 | 136 | return inverse_transform 137 | 138 | def forward(self, input_data): 139 | self.magnitude, self.phase = self.transform(input_data) 140 | reconstruction = self.inverse(self.magnitude, self.phase) 141 | return reconstruction 142 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | from text.symbols import symbols 5 | 6 | 7 | # Mappings from symbol to numeric ID and vice versa: 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | # Regular expression matching text enclosed in curly braces: 12 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 13 | 14 | 15 | def text_to_sequence(text, cleaner_names): 16 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 17 | 18 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 19 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
20 | 21 | Args: 22 | text: string to convert to a sequence 23 | cleaner_names: names of the cleaner functions to run the text through 24 | 25 | Returns: 26 | List of integers corresponding to the symbols in the text 27 | ''' 28 | sequence = [] 29 | 30 | # Check for curly braces and treat their contents as ARPAbet: 31 | while len(text): 32 | m = _curly_re.match(text) 33 | if not m: 34 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 35 | break 36 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 37 | sequence += _arpabet_to_sequence(m.group(2)) 38 | text = m.group(3) 39 | 40 | return sequence 41 | 42 | 43 | def sequence_to_text(sequence): 44 | '''Converts a sequence of IDs back to a string''' 45 | result = '' 46 | for symbol_id in sequence: 47 | if symbol_id in _id_to_symbol: 48 | s = _id_to_symbol[symbol_id] 49 | # Enclose ARPAbet back in curly braces: 50 | if len(s) > 1 and s[0] == '@': 51 | s = '{%s}' % s[1:] 52 | result += s 53 | return result.replace('}{', ' ') 54 | 55 | 56 | def _clean_text(text, cleaner_names): 57 | for name in cleaner_names: 58 | cleaner = getattr(cleaners, name) 59 | if not cleaner: 60 | raise Exception('Unknown cleaner: %s' % name) 61 | text = cleaner(text) 62 | return text 63 | 64 | 65 | def _symbols_to_sequence(symbols): 66 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 67 | 68 | 69 | def _arpabet_to_sequence(text): 70 | return _symbols_to_sequence(['@' + s for s in text.split()]) 71 | 72 | 73 | def _should_keep_symbol(s): 74 | return s in _symbol_to_id and s is not '_' and s is not '~' 75 | -------------------------------------------------------------------------------- /text/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /text/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /text/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /text/__pycache__/cleaners.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/cleaners.cpython-35.pyc -------------------------------------------------------------------------------- /text/__pycache__/cleaners.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/cleaners.cpython-36.pyc -------------------------------------------------------------------------------- /text/__pycache__/cleaners.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/cleaners.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/cleaners.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/cleaners.cpython-38.pyc -------------------------------------------------------------------------------- /text/__pycache__/cmudict.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/cmudict.cpython-35.pyc -------------------------------------------------------------------------------- /text/__pycache__/cmudict.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/cmudict.cpython-36.pyc -------------------------------------------------------------------------------- /text/__pycache__/cmudict.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/cmudict.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/numbers.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/numbers.cpython-35.pyc -------------------------------------------------------------------------------- /text/__pycache__/numbers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/numbers.cpython-36.pyc -------------------------------------------------------------------------------- /text/__pycache__/numbers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/numbers.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/symbols.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/symbols.cpython-35.pyc -------------------------------------------------------------------------------- /text/__pycache__/symbols.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/symbols.cpython-36.pyc 
-------------------------------------------------------------------------------- /text/__pycache__/symbols.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/symbols.cpython-37.pyc -------------------------------------------------------------------------------- /text/__pycache__/symbols.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/text/__pycache__/symbols.cpython-38.pyc -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ modified for Korean language 2 | Forked from https://github.com/keithito/tacotron """ 3 | 4 | ''' 5 | Cleaners are transformations that run over the input text at both training and eval time. 6 | 7 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 8 | hyperparameter. Most of the cleaners are Korean-specific. 9 | ''' 10 | 11 | import re 12 | from jamo import h2j 13 | 14 | # Regular expression matching whitespace: 15 | _whitespace_re = re.compile(r'\s+') 16 | 17 | 18 | def collapse_whitespace(text): 19 | return re.sub(_whitespace_re, ' ', text) 20 | 21 | 22 | def basic_cleaners(text): 23 | '''Basic pipeline that supports Korean only''' 24 | text = collapse_whitespace(text) 25 | text = h2j(text) 26 | return text 27 | 28 | -------------------------------------------------------------------------------- /text/english_utils/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | import re 16 | from unidecode import unidecode 17 | from .numbers import normalize_numbers 18 | 19 | 20 | # Regular expression matching whitespace: 21 | _whitespace_re = re.compile(r'\s+') 22 | 23 | # List of (regular expression, replacement) pairs for abbreviations: 24 | _abbreviations = [(re.compile('\\b%s\\.'
% x[0], re.IGNORECASE), x[1]) for x in [ 25 | ('mrs', 'misess'), 26 | ('mr', 'mister'), 27 | ('dr', 'doctor'), 28 | ('st', 'saint'), 29 | ('co', 'company'), 30 | ('jr', 'junior'), 31 | ('maj', 'major'), 32 | ('gen', 'general'), 33 | ('drs', 'doctors'), 34 | ('rev', 'reverend'), 35 | ('lt', 'lieutenant'), 36 | ('hon', 'honorable'), 37 | ('sgt', 'sergeant'), 38 | ('capt', 'captain'), 39 | ('esq', 'esquire'), 40 | ('ltd', 'limited'), 41 | ('col', 'colonel'), 42 | ('ft', 'fort'), 43 | ]] 44 | 45 | 46 | def expand_abbreviations(text): 47 | for regex, replacement in _abbreviations: 48 | text = re.sub(regex, replacement, text) 49 | return text 50 | 51 | 52 | def expand_numbers(text): 53 | return normalize_numbers(text) 54 | 55 | 56 | def lowercase(text): 57 | return text.lower() 58 | 59 | 60 | def collapse_whitespace(text): 61 | return re.sub(_whitespace_re, ' ', text) 62 | 63 | 64 | def convert_to_ascii(text): 65 | return unidecode(text) 66 | 67 | 68 | def basic_cleaners(text): 69 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 70 | text = lowercase(text) 71 | text = collapse_whitespace(text) 72 | return text 73 | 74 | 75 | def transliteration_cleaners(text): 76 | '''Pipeline for non-English text that transliterates to ASCII.''' 77 | text = convert_to_ascii(text) 78 | text = lowercase(text) 79 | text = collapse_whitespace(text) 80 | return text 81 | 82 | 83 | def english_cleaners(text): 84 | '''Pipeline for English text, including number and abbreviation expansion.''' 85 | text = convert_to_ascii(text) 86 | text = lowercase(text) 87 | text = expand_numbers(text) 88 | text = expand_abbreviations(text) 89 | text = collapse_whitespace(text) 90 | return text 91 | 92 | -------------------------------------------------------------------------------- /text/english_utils/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 14 | ] 15 | 16 | _valid_symbol_set = set(valid_symbols) 17 | 18 | 19 | class CMUDict: 20 | '''Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 21 | def __init__(self, file_or_path, keep_ambiguous=True): 22 | if isinstance(file_or_path, str): 23 | with open(file_or_path, encoding='latin-1') as f: 24 | entries = _parse_cmudict(f) 25 | else: 26 | entries = _parse_cmudict(file_or_path) 27 | if not keep_ambiguous: 28 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 29 | self._entries = entries 30 | 31 | 32 | def __len__(self): 33 | return len(self._entries) 34 | 35 | 36 | def lookup(self, word): 37 | '''Returns list of ARPAbet pronunciations of the given word.''' 38 | return self._entries.get(word.upper()) 39 | 40 | 41 | 42 | _alt_re = re.compile(r'\([0-9]+\)') 43 | 44 | 45 | def _parse_cmudict(file): 46 | cmudict = {} 47 | for line in file: 48 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 49 | parts = line.split(' ') 50 | word = re.sub(_alt_re, '', parts[0]) 51 | pronunciation = _get_pronunciation(parts[1]) 52 | if pronunciation: 53 | if word in cmudict: 54 | cmudict[word].append(pronunciation) 55 | else: 56 | cmudict[word] = [pronunciation] 57 | return cmudict 58 | 59 | 60 | def _get_pronunciation(s): 61 | parts = s.strip().split(' ') 62 | for part in parts: 63 | if part not in _valid_symbol_set: 64 | return None 65 | return ' '.join(parts) 66 | -------------------------------------------------------------------------------- /text/english_utils/numbers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ from https://github.com/keithito/tacotron """ 3 | 4 | import inflect 5 | import re 6 | 7 | 8 | _inflect = inflect.engine() 9 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 10 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 11 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 12 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 13 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 14 | _number_re = re.compile(r'[0-9]+') 15 | 16 | 17 | def _remove_commas(m): 18 | return m.group(1).replace(',', '') 19 | 20 | 21 | def _expand_decimal_point(m): 22 | return m.group(1).replace('.', ' point ') 23 | 24 | 25 | def _expand_dollars(m): 26 | match = m.group(1) 27 | parts = match.split('.') 28 | if len(parts) > 2: 29 | return match + ' dollars' # Unexpected format 30 | dollars = int(parts[0]) if parts[0] else 0 31 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 32 | if dollars and cents: 33 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 34 | cent_unit = 'cent' if cents == 1 else 'cents' 35 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 36 | elif dollars: 37 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 38 | return '%s %s' % (dollars, dollar_unit) 39 | elif cents: 40 | cent_unit = 'cent' if cents == 1 else 'cents' 41 | return '%s %s' % (cents, cent_unit) 42 | else: 43 | return 'zero dollars' 44 | 45 | 46 | def _expand_ordinal(m): 47 | return _inflect.number_to_words(m.group(0)) 48 | 49 | 50 | def _expand_number(m): 51 | num = int(m.group(0)) 52 | if num > 1000 and num < 3000: 53 | if num == 2000: 54 | return 'two thousand' 55 | elif num > 2000 and num < 2010: 56 | return 'two thousand ' + _inflect.number_to_words(num % 100) 57 | elif num % 100 == 0: 58 | return _inflect.number_to_words(num // 100) + ' hundred' 59 | else: 60 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 61 | else: 62 | return _inflect.number_to_words(num, andword='') 63 | 64 | 65 | def 
normalize_numbers(text): 66 | text = re.sub(_comma_number_re, _remove_commas, text) 67 | text = re.sub(_pounds_re, r'\1 pounds', text) 68 | text = re.sub(_dollars_re, _expand_dollars, text) 69 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 70 | text = re.sub(_ordinal_re, _expand_ordinal, text) 71 | text = re.sub(_number_re, _expand_number, text) 72 | return text 73 | -------------------------------------------------------------------------------- /text/english_utils/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. ''' 7 | from . import cmudict 8 | 9 | _pad = '_' 10 | _punctuation = '!\'(),.:;? ' 11 | _special = '-' 12 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 13 | 14 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 15 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 16 | 17 | # Export all symbols: 18 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet 19 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Defines the set of symbols used in text input to the model. 3 | 4 | The default is the set of Korean jamo characters in the U+11xx block. 5 | You can check the code points with Python 3's 'ord' function. ''' 6 | 7 | _pad = '_' 8 | _sos = '^' 9 | _eos = '~' 10 | _special = '-' 11 | _punctuation = '!\'(),.:;? 
' 12 | 13 | _jamo_leads = "".join(chr(c) for c in range(0x1100, 0x1113)) 14 | _jamo_vowels = "".join(chr(c) for c in range(0x1161, 0x1176)) 15 | _jamo_tails = "".join(chr(c) for c in range(0x11a8, 0x11c3)) 16 | _letters = _jamo_leads + _jamo_vowels + _jamo_tails 17 | 18 | symbols = [_pad] + [_sos] + [_eos] + list(_special) + list(_punctuation) + list(_letters) 19 | 20 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, argparse, pdb, random 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import hparams 7 | from modules.model import FluentTTS 8 | from modules.loss import TransformerLoss, EmotionLoss 9 | from modules.attn_loss_function import AttentionCTCLoss, AttentionBinarizationLoss 10 | from text import * 11 | from utils.utils import * 12 | from utils.writer import get_writer, plot_attn 13 | 14 | 15 | def validate(model, criterion, criterion_emo, criterion_ctc, criterion_bin, criterion_f0, val_loader, writer, iteration): 16 | model.eval() 17 | with torch.no_grad(): 18 | n_data, val_loss = 0, 0 19 | for i, batch in enumerate(val_loader): 20 | n_data += len(batch[0]) 21 | 22 | # Get mini-batch 23 | text_padded, text_lengths, mel_padded, mel_lengths, gate_padded, \ 24 | f0_padded, prior_padded, name, spk, emo = [reorder_batch(x, 1) for x in batch] 25 | 26 | # Forward (w/ multi-style generation) 27 | if iteration > hparams.local_style_step and model.mode == 'prop': 28 | mel_loss, bce_loss, guide_loss, \ 29 | ctc_loss, bin_loss, f0_loss, emo_loss, outputs, mel_lengths = model(text_padded, mel_padded, gate_padded, f0_padded, prior_padded, text_lengths, mel_lengths, 30 | criterion, criterion_emo, criterion_ctc, criterion_bin, criterion_f0, 31 | spk, emo, iteration, valid=True) 32 | 33 | mel_loss, bce_loss, guide_loss, \ 34 | ctc_loss, bin_loss, f0_loss, emo_loss = [torch.mean(x) for x in [mel_loss, bce_loss, guide_loss, \ 35 | ctc_loss, bin_loss, f0_loss, emo_loss]] 36 | sub_loss = mel_loss + bce_loss + guide_loss + ctc_loss + bin_loss + f0_loss + emo_loss 37 | 38 | # Forward (w/o multi-style generation) 39 | else: 40 | mel_loss, bce_loss, guide_loss, \ 41 | ctc_loss, bin_loss, emo_loss, outputs, mel_lengths = model(text_padded, mel_padded, gate_padded, f0_padded, prior_padded, text_lengths, mel_lengths, 42 | criterion, criterion_emo, criterion_ctc, criterion_bin, criterion_f0, 43 | spk, emo, iteration, valid=True) 44 | 45 | mel_loss, bce_loss, guide_loss, \ 46 | ctc_loss, bin_loss, emo_loss = [torch.mean(x) for x in [mel_loss, bce_loss, guide_loss, \ 47 | ctc_loss, bin_loss, emo_loss]] 48 | sub_loss = mel_loss + bce_loss + guide_loss + ctc_loss + bin_loss + emo_loss 49 | 50 | val_loss += sub_loss.item() * len(batch[0]) 51 | 52 | val_loss /= n_data 53 | 54 | # Tensorboard 55 | if iteration > hparams.local_style_step and model.mode == 'prop': 56 | writer.add_losses(val_loss, mel_loss.item(), bce_loss.item(), guide_loss.item(), 57 | ctc_loss.item(), bin_loss.item(), emo_loss.item(), 58 | iteration, 'Validation', f0_loss.item()) 59 | else: 60 | writer.add_losses(val_loss, mel_loss.item(), bce_loss.item(), guide_loss.item(), 61 | ctc_loss.item(), bin_loss.item(), emo_loss.item(), 62 | iteration, 'Validation') 63 | 64 | # Plot 65 | mel_out, enc_alignments, dec_alignments, enc_dec_alignments, gate_out = outputs[0], outputs[1], outputs[2], outputs[3], outputs[4] 66 | idx = random.randint(0, len(mel_out)-1) 67 | 68 | 
writer.add_specs(mel_padded.detach().cpu(), 69 | mel_out.detach().cpu(), 70 | mel_lengths.detach().cpu(), 71 | iteration, 'Validation', idx) 72 | 73 | writer.add_alignments(enc_alignments.detach().cpu(), 74 | dec_alignments.detach().cpu(), 75 | enc_dec_alignments.detach().cpu(), 76 | text_padded.detach().cpu(), 77 | mel_lengths.detach().cpu(), 78 | text_lengths.detach().cpu(), 79 | iteration, 'Validation', idx) 80 | 81 | if model.mode == 'prop': 82 | soft_A, hard_A = outputs[5], outputs[6] 83 | soft_A = soft_A[idx].squeeze().transpose(0,1)[:text_lengths[idx], :mel_lengths[idx]] 84 | hard_A = hard_A[idx].squeeze().transpose(0,1)[:text_lengths[idx], :mel_lengths[idx]] 85 | 86 | plot_attn(writer, soft_A, hard_A, iteration) 87 | 88 | writer.add_gates(gate_out[idx].detach().cpu(), iteration, 'Validation') 89 | 90 | print(f'\nValidation: {iteration} | loss {val_loss:.4f}') 91 | 92 | model.train() 93 | 94 | 95 | def main(args): 96 | # Preparation 97 | if not os.path.isdir(args.outdir): 98 | os.mkdir(args.outdir) 99 | mode = args.mode 100 | 101 | train_loader, val_loader, collate_fn = prepare_dataloaders(hparams) 102 | model = FluentTTS(hparams, mode).cuda() 103 | optimizer = torch.optim.Adam(model.parameters(), lr=hparams.lr, betas=(0.9, 0.98), eps=1e-09, weight_decay=1e-6) 104 | model, optimizer, last_epoch, learning_rate, iteration = load_checkpoint(args.checkpoint_path, model, optimizer) 105 | 106 | criterion = TransformerLoss() 107 | criterion_emo = EmotionLoss() 108 | criterion_f0 = nn.L1Loss() 109 | criterion_ctc = AttentionCTCLoss() 110 | criterion_bin = AttentionBinarizationLoss() 111 | 112 | writer = get_writer(args.outdir, 'logdir') 113 | 114 | loss = 0 115 | num_param = 0 116 | for name, param in model.named_parameters(): 117 | if param.requires_grad: 118 | num_param += param.numel() 119 | print(f'Model parameters: {num_param/1000000:.2f}M') 120 | 121 | # Training 122 | model.train() 123 | print("Training start!") 124 | for epoch in range(max(0, last_epoch), hparams.training_epochs): 125 | for i, batch in enumerate(train_loader): 126 | # Get mini-batch 127 | text_padded, text_lengths, mel_padded, mel_lengths, gate_padded, \ 128 | f0_padded, prior_padded, name, spk, emo = [reorder_batch(x, hparams.n_gpus) for x in batch] 129 | 130 | # Forward (w/ multi-style generation) 131 | if iteration > hparams.local_style_step and mode == 'prop': 132 | mel_loss, bce_loss, guide_loss, \ 133 | ctc_loss, bin_loss, f0_loss, emo_loss = model(text_padded, mel_padded, gate_padded, f0_padded, prior_padded, text_lengths, mel_lengths, 134 | criterion, criterion_emo, criterion_ctc, criterion_bin, criterion_f0, 135 | spk, emo, iteration) 136 | 137 | mel_loss, bce_loss, guide_loss, \ 138 | ctc_loss, bin_loss, f0_loss, emo_loss = [torch.mean(x) for x in [mel_loss, bce_loss, guide_loss, \ 139 | ctc_loss, bin_loss, f0_loss, emo_loss]] 140 | sub_loss = mel_loss + bce_loss + guide_loss + ctc_loss + bin_loss + f0_loss + emo_loss 141 | 142 | # Forward (w/o multi-style generation) 143 | else: 144 | mel_loss, bce_loss, guide_loss, \ 145 | ctc_loss, bin_loss, emo_loss = model(text_padded, mel_padded, gate_padded, f0_padded, prior_padded, text_lengths, mel_lengths, 146 | criterion, criterion_emo, criterion_ctc, criterion_bin, criterion_f0, 147 | spk, emo, iteration) 148 | 149 | mel_loss, bce_loss, guide_loss, \ 150 | ctc_loss, bin_loss, emo_loss = [torch.mean(x) for x in [mel_loss, bce_loss, guide_loss, \ 151 | ctc_loss, bin_loss, emo_loss]] 152 | sub_loss = mel_loss + bce_loss + guide_loss + ctc_loss + bin_loss + 
emo_loss 153 | 154 | # Backward 155 | sub_loss.backward() 156 | loss = loss+sub_loss.item() 157 | 158 | iteration += 1 159 | lr_scheduling(optimizer, iteration) 160 | nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh) 161 | optimizer.step() 162 | model.zero_grad() 163 | optimizer.zero_grad() 164 | print(f"[Epoch {epoch}] Train: {iteration} step | loss {loss:.4f}", end='\r') 165 | 166 | # Tensorboard 167 | if iteration > hparams.local_style_step + 1 and mode == 'prop': 168 | writer.add_losses(loss, mel_loss.item(), bce_loss.item(), guide_loss.item(), 169 | ctc_loss.item(), bin_loss.item(), emo_loss.item(), 170 | iteration, 'Train', f0_loss.item()) 171 | else: 172 | writer.add_losses(loss, mel_loss.item(), bce_loss.item(), guide_loss.item(), 173 | ctc_loss.item(), bin_loss.item(), emo_loss.item(), 174 | iteration, 'Train') 175 | 176 | loss = 0 177 | 178 | # Validation & Save 179 | if iteration % hparams.iters_per_validation == 0: 180 | validate(model, criterion, criterion_emo, criterion_ctc, criterion_bin, criterion_f0, 181 | val_loader, writer, iteration) 182 | 183 | if iteration % hparams.iters_per_checkpoint == 0: 184 | save_checkpoint(model, optimizer, epoch, hparams.lr, iteration, filepath=f'{args.outdir}/logdir') 185 | 186 | 187 | if __name__ == '__main__': 188 | p = argparse.ArgumentParser() 189 | p.add_argument('--gpu', type=str, default='0,1') 190 | p.add_argument('-v', '--verbose', type=str, default='0') 191 | p.add_argument('-c', '--checkpoint_path', default=None) 192 | p.add_argument('-o', '--outdir', default='outdir') 193 | p.add_argument('-m', '--mode', type=str, help='base, prop') 194 | args = p.parse_args() 195 | 196 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 197 | torch.manual_seed(hparams.seed) 198 | torch.cuda.manual_seed(hparams.seed) 199 | 200 | if args.verbose=='0': 201 | import warnings 202 | warnings.filterwarnings("ignore") 203 | 204 | main(args) 205 | -------------------------------------------------------------------------------- /utils/__pycache__/data_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/utils/__pycache__/data_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/data_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/utils/__pycache__/data_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/data_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/utils/__pycache__/data_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plot_image.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/utils/__pycache__/plot_image.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plot_image.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/utils/__pycache__/plot_image.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plot_image.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/utils/__pycache__/plot_image.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/test_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/utils/__pycache__/test_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/writer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/utils/__pycache__/writer.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/writer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/utils/__pycache__/writer.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/writer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/monglechap/fluenttts/ad23e09e9aebe95cbe0c28603b3f3b6fbb054e93/utils/__pycache__/writer.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import os, random, torch, pdb 2 | import numpy as np 3 | import torch.utils.data 4 | import torch.nn.functional as F 5 | 6 | import hparams 7 | 8 | 9 | # Prepare filelists 10 | def process_meta(meta_path): 11 | with open(meta_path, 'r', encoding='utf-8') as f: 12 | name, speaker, emotion = [], [], [] 13 | for line in f.readlines(): 14 | # Data-specific configuration. 
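# (Illustration with a hypothetical filelist entry: a line such as
#  "/data/wavs/h_0001.wav|some transcript|spk01" is split on '|', the basename
#  without its extension becomes the utterance name ("h_0001"), and the first
#  character of that name ("h") is taken as the emotion tag below, so file
#  names are expected to start with an emotion identifier.)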
15 | # You should change this code for your own data structure 16 | path, text, spk = line.strip('\n').split('|') 17 | filename = path.split('/')[-1][:-4] 18 | emo = filename[0] 19 | 20 | name.append(filename) 21 | speaker.append(spk) 22 | emotion.append(emo) 23 | 24 | return name, speaker, emotion 25 | 26 | # Sort unique IDs 27 | def create_id_table(ids): 28 | sorted_ids = np.sort(np.unique(ids)) 29 | d = {sorted_ids[i]: i for i in range(len(sorted_ids))} 30 | return d 31 | 32 | # Read filelists 33 | def load_filepaths_and_text(metadata, split="|"): 34 | with open(metadata, encoding='utf-8') as f: 35 | filepaths_and_text = [line.strip().split(split) for line in f] 36 | return filepaths_and_text 37 | 38 | # Custom dataset 39 | class TextMelSet(torch.utils.data.Dataset): 40 | def __init__(self, audiopaths_and_text, hparams): 41 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) 42 | self.seq_dir = os.path.join(hparams.data_path, 'texts') 43 | self.mel_dir = os.path.join(hparams.data_path, 'mels') 44 | self.norm_f0_dir = os.path.join(hparams.data_path, 'pitch_norm') 45 | self.prior_dir = os.path.join(hparams.data_path, 'alignment_priors') 46 | 47 | _, self.speaker, self.emotion = process_meta(audiopaths_and_text) 48 | self.sid_dict = create_id_table(self.speaker) 49 | self.eid_dict = create_id_table(self.emotion) 50 | # print(self.sid_dict) 51 | 52 | def get_mel_text_pair(self, audiopath_and_text): 53 | # separate filename and text 54 | # Data-specific configuration. 55 | # You should change this code for your own data structure 56 | wav_path = audiopath_and_text[0] 57 | name = wav_path.split('/')[-1][:-4] 58 | emo = name[0] 59 | spk = audiopath_and_text[2] 60 | 61 | spk_id = self.sid_dict[spk] 62 | emo_id = self.eid_dict[emo] 63 | 64 | text = np.load(os.path.join(self.seq_dir, name+'.npy')) 65 | mel = np.load(os.path.join(self.mel_dir, name+'.npy')) 66 | f0 = np.load(os.path.join(self.norm_f0_dir, name+'.npy')) 67 | prior = np.load(os.path.join(self.prior_dir, name+'.npy')) 68 | 69 | return (torch.IntTensor(text), torch.FloatTensor(mel), torch.FloatTensor(f0), 70 | torch.FloatTensor(prior), name, torch.LongTensor([spk_id]), emo) 71 | 72 | def __getitem__(self, index): 73 | return self.get_mel_text_pair(self.audiopaths_and_text[index]) 74 | 75 | def __len__(self): 76 | return len(self.audiopaths_and_text) 77 | 78 | # Collate function 79 | class TextMelCollate(): 80 | def __init__(self): 81 | return 82 | 83 | def __call__(self, batch): 84 | # Right zero-pad all one-hot text sequences to max input length 85 | input_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor([len(x[0]) for x in batch]), 86 | dim=0, descending=True) 87 | max_input_len = input_lengths[0] 88 | 89 | text_padded = torch.zeros(len(batch), max_input_len, dtype=torch.long) 90 | 91 | for i in range(len(ids_sorted_decreasing)): 92 | text = batch[ids_sorted_decreasing[i]][0] 93 | text_padded[i, :text.size(0)] = text 94 | 95 | # Right zero-pad 96 | num_mels = batch[0][1].size(0) 97 | max_target_len = max([x[1].size(1) for x in batch]) 98 | 99 | # include Spec padded and gate padded 100 | mel_padded = torch.zeros(len(batch), num_mels, max_target_len) 101 | gate_padded = torch.zeros(len(batch), max_target_len) 102 | f0_padded = torch.zeros(len(batch), max_target_len) 103 | prior_padded = torch.zeros(len(batch), max_input_len, max_target_len) 104 | 105 | output_lengths = torch.LongTensor(len(batch)) 106 | name = [] 107 | spk = torch.LongTensor(len(batch)) 108 | emo = [] 109 | 110 | for i in 
range(len(ids_sorted_decreasing)): 111 | mel = batch[ids_sorted_decreasing[i]][1] 112 | f0 = batch[ids_sorted_decreasing[i]][2] 113 | prior = batch[ids_sorted_decreasing[i]][3].contiguous().transpose(0,1) 114 | 115 | mel_padded[i, :, :mel.size(1)] = mel 116 | gate_padded[i, mel.size(1)-1:] = 1 117 | f0_padded[i, :mel.size(1)] = f0 118 | 119 | prior_padded[i, :prior.size(0), :prior.size(1)] = prior 120 | 121 | output_lengths[i] = mel.size(1) 122 | name.append(batch[ids_sorted_decreasing[i]][4]) 123 | spk[i] = batch[ids_sorted_decreasing[i]][5] 124 | emo.append(batch[ids_sorted_decreasing[i]][6]) 125 | 126 | return text_padded, input_lengths, mel_padded, output_lengths, gate_padded, \ 127 | f0_padded, prior_padded, name, spk, emo 128 | -------------------------------------------------------------------------------- /utils/plot_image.py: -------------------------------------------------------------------------------- 1 | import torch, random, pdb 2 | import matplotlib.pyplot as plt 3 | import matplotlib 4 | matplotlib.use('Agg') 5 | 6 | import hparams 7 | from text import * 8 | 9 | 10 | def plot_melspec(target, melspec, mel_lengths, idx): 11 | fig, axes = plt.subplots(2, 1, figsize=(20,30)) 12 | T = mel_lengths[idx] 13 | 14 | axes[0].imshow(target[idx][:,:T], 15 | origin='lower', 16 | aspect='auto') 17 | 18 | axes[1].imshow(melspec[idx][:,:T], 19 | origin='lower', 20 | aspect='auto') 21 | 22 | return fig 23 | 24 | 25 | def plot_alignments(alignments, text, mel_lengths, text_lengths, att_type, idx): 26 | fig, axes = plt.subplots(hparams.n_layers, hparams.n_heads, figsize=(5*hparams.n_heads,5*hparams.n_layers)) 27 | L, T = text_lengths[idx], mel_lengths[idx] 28 | n_layers, n_heads = alignments.size(1), alignments.size(2) 29 | 30 | for layer in range(n_layers): 31 | for head in range(n_heads): 32 | if att_type=='enc': 33 | align = alignments[idx, layer, head].contiguous() 34 | axes[layer,head].imshow(align[:L, :L], aspect='auto') 35 | axes[layer,head].xaxis.tick_top() 36 | 37 | elif att_type=='dec': 38 | align = alignments[idx, layer, head].contiguous() 39 | axes[layer,head].imshow(align[:T, :T], aspect='auto') 40 | axes[layer,head].xaxis.tick_top() 41 | 42 | elif att_type=='enc_dec': 43 | align = alignments[idx, layer, head].transpose(0,1).contiguous() 44 | axes[layer,head].imshow(align[:L, :T], origin='lower', aspect='auto') 45 | 46 | return fig 47 | 48 | def plot_alignment(alignment): 49 | fig, ax = plt.subplots(figsize=(6, 4)) 50 | im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none') 51 | plt.colorbar(im, ax=ax) 52 | fig.canvas.draw() 53 | plt.close() 54 | 55 | return fig 56 | 57 | def plot_gate(gate_out): 58 | fig = plt.figure(figsize=(10,5)) 59 | plt.plot(torch.sigmoid(gate_out)) 60 | return fig 61 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os, random, torch, pdb 2 | import matplotlib.pyplot as plt 3 | from torch.utils.data import DataLoader 4 | from numba import jit, prange 5 | import numpy as np 6 | 7 | import hparams 8 | from .data_utils import TextMelSet, TextMelCollate 9 | from text import * 10 | 11 | 12 | def init_weights(m, mean=0.0, std=0.01): 13 | classname = m.__class__.__name__ 14 | if classname.find("Conv") != -1: 15 | m.weight.data.normal_(mean, std) 16 | 17 | 18 | def apply_weight_norm(m): 19 | classname = m.__class__.__name__ 20 | if classname.find("Conv") != -1: 21 | weight_norm(m) 22 | 23 | 24 | def 
get_padding(kernel_size, dilation=1): 25 | return int((kernel_size*dilation - dilation)/2) 26 | 27 | 28 | def prepare_dataloaders(hparams): 29 | trainset = TextMelSet(hparams.training_files, hparams) 30 | valset = TextMelSet(hparams.validation_files, hparams) 31 | collate_fn = TextMelCollate() 32 | 33 | train_loader = DataLoader(trainset, 34 | num_workers=hparams.n_gpus-1, 35 | shuffle=True, 36 | batch_size=hparams.batch_size, 37 | drop_last=True, 38 | collate_fn=collate_fn) 39 | 40 | val_loader = DataLoader(valset, 41 | batch_size=hparams.batch_size//hparams.n_gpus, 42 | collate_fn=collate_fn) 43 | 44 | spk_label = [key for key in valset.sid_dict] 45 | emo_label = [key for key in valset.eid_dict] 46 | print(spk_label) 47 | print(emo_label) 48 | 49 | return train_loader, val_loader, collate_fn 50 | 51 | 52 | def load_checkpoint(checkpoint_path, model, optimizer): 53 | if checkpoint_path == None: return model, optimizer, 0, hparams.lr, 0 54 | assert os.path.isfile(checkpoint_path) 55 | print("Loading checkpoint '{}'".format(checkpoint_path)) 56 | 57 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') # pretrained 58 | model_dict = model.state_dict() 59 | pretrained_dict = {k: v for k,v in checkpoint_dict['state_dict'].items() if k in model_dict.keys()} 60 | model_dict.update(pretrained_dict) 61 | model.load_state_dict(model_dict) 62 | 63 | learning_rate = checkpoint_dict['learning_rate'] 64 | epoch = checkpoint_dict['epoch'] 65 | iteration = checkpoint_dict['iteration'] 66 | print("Loaded checkpoint '{}' from iteration {}" .format(checkpoint_path, iteration)) 67 | 68 | if optimizer is None: return model, optimizer, epoch, learning_rate, iteration 69 | 70 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 71 | return model, optimizer, epoch, learning_rate, iteration 72 | 73 | 74 | def save_checkpoint(model, optimizer, epoch, learning_rate, iteration, filepath): 75 | print(f"Saving model and optimizer state at iteration {iteration} to {filepath}") 76 | torch.save({'iteration': iteration, 77 | 'state_dict': model.state_dict(), 78 | 'optimizer': optimizer.state_dict(), 79 | 'epoch': epoch, 80 | 'learning_rate': learning_rate}, f'{filepath}/checkpoint_{iteration}') 81 | 82 | 83 | def lr_scheduling(opt, step, init_lr=hparams.lr, warmup_steps=hparams.warmup_steps): 84 | opt.param_groups[0]['lr'] = init_lr * min(step ** -0.5, (warmup_steps ** -1.5) * step) 85 | return 86 | 87 | 88 | def get_mask_from_lengths(lengths): 89 | max_len = torch.max(lengths).item() 90 | ids = lengths.new_tensor(torch.arange(0, max_len)) 91 | mask = (lengths.unsqueeze(1) <= ids.cuda()).to(torch.bool) 92 | return mask 93 | 94 | 95 | def get_mask(lengths): 96 | mask = torch.zeros(len(lengths), torch.max(lengths)).cuda() 97 | for i in range(len(mask)): 98 | mask[i] = torch.nn.functional.pad(torch.arange(1,lengths[i]+1),[0,torch.max(lengths)-lengths[i]],'constant') 99 | return mask.cuda() 100 | 101 | 102 | def reorder_batch(x, n_gpus): 103 | assert (len(x)%n_gpus)==0, 'Batch size must be a multiple of the number of GPUs.' 
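# (Note: TextMelCollate sorts the batch by text length in descending order, so the
#  interleaved copy below presumably spreads long and short utterances across the
#  contiguous per-GPU slices that nn.DataParallel later splits off; plain Python
#  lists such as the file names are returned unchanged by the early return that follows.)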
104 | if isinstance(x, list): 105 | return x 106 | new_x = x.new_zeros(x.size()) 107 | chunk_size = x.size(0)//n_gpus 108 | 109 | for i in range(n_gpus): 110 | new_x[i::n_gpus] = x[i*chunk_size:(i+1)*chunk_size] 111 | 112 | return new_x.cuda() 113 | 114 | 115 | @jit(nopython=True) 116 | def mas_width1(attn_map): 117 | """mas with hardcoded width=1""" 118 | # assumes mel x text 119 | opt = np.zeros_like(attn_map) 120 | attn_map = np.log(attn_map) 121 | attn_map[0, 1:] = -np.inf 122 | log_p = np.zeros_like(attn_map) 123 | log_p[0, :] = attn_map[0, :] 124 | prev_ind = np.zeros_like(attn_map, dtype=np.int64) 125 | for i in range(1, attn_map.shape[0]): 126 | for j in range(attn_map.shape[1]): # for each text dim 127 | prev_log = log_p[i - 1, j] 128 | prev_j = j 129 | 130 | if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]: 131 | prev_log = log_p[i - 1, j - 1] 132 | prev_j = j - 1 133 | 134 | log_p[i, j] = attn_map[i, j] + prev_log 135 | prev_ind[i, j] = prev_j 136 | 137 | # now backtrack 138 | curr_text_idx = attn_map.shape[1] - 1 139 | for i in range(attn_map.shape[0] - 1, -1, -1): 140 | opt[i, curr_text_idx] = 1 141 | curr_text_idx = prev_ind[i, curr_text_idx] 142 | opt[0, curr_text_idx] = 1 143 | return opt 144 | 145 | 146 | @jit(nopython=True, parallel=True) 147 | def b_mas(b_attn_map, in_lens, out_lens, width=1): 148 | assert width == 1 149 | attn_out = np.zeros_like(b_attn_map) 150 | 151 | for b in prange(b_attn_map.shape[0]): 152 | out = mas_width1(b_attn_map[b, 0, : out_lens[b], : in_lens[b]]) 153 | attn_out[b, 0, : out_lens[b], : in_lens[b]] = out 154 | return attn_out 155 | 156 | def binarize_attention_parallel(attn, in_lens, out_lens): 157 | with torch.no_grad(): 158 | attn_cpu = attn.data.cpu().numpy() 159 | attn_out = b_mas(attn_cpu, in_lens.cpu().numpy(), out_lens.cpu().numpy(), width=1) 160 | return torch.from_numpy(attn_out).cuda() 161 | 162 | 163 | -------------------------------------------------------------------------------- /utils/writer.py: -------------------------------------------------------------------------------- 1 | import os, random, pdb 2 | from torch.utils.tensorboard import SummaryWriter 3 | 4 | from .plot_image import * 5 | 6 | def get_writer(output_directory, log_directory): 7 | logging_path=f'{output_directory}/{log_directory}' 8 | writer = TTSWriter(logging_path) 9 | 10 | return writer 11 | 12 | 13 | class TTSWriter(SummaryWriter): 14 | def __init__(self, log_dir): 15 | super(TTSWriter, self).__init__(log_dir) 16 | 17 | def add_losses(self, total_loss, mel_loss, bce_loss, guide_loss, ctc_loss, bin_loss, emo_loss, global_step, phase, f0_loss=None): 18 | self.add_scalar(f'{phase}/mel_loss', mel_loss, global_step) 19 | self.add_scalar(f'{phase}/bce_loss', bce_loss, global_step) 20 | self.add_scalar(f'{phase}/guide_loss', guide_loss, global_step) 21 | self.add_scalar(f'{phase}/ctc_loss', ctc_loss, global_step) 22 | self.add_scalar(f'{phase}/bin_loss', bin_loss, global_step) 23 | self.add_scalar(f'{phase}/emo_loss', emo_loss, global_step) 24 | self.add_scalar(f'{phase}/total_loss', total_loss, global_step) 25 | if f0_loss is not None: 26 | self.add_scalar(f'{phase}/F0_loss', f0_loss, global_step) 27 | 28 | def add_lr(self, current_lr, global_step, phase): 29 | self.add_scalar(f'{phase}/learning_rate', current_lr, global_step) 30 | 31 | def add_specs(self, mel_padded, mel_out, mel_lengths, global_step, phase, idx): 32 | mel_fig = plot_melspec(mel_padded, mel_out, mel_lengths, idx) 33 | self.add_figure(f'Plot/melspec', mel_fig, global_step) 34 | 35 | def 
add_alignments(self, enc_alignments, dec_alignments, enc_dec_alignments, 36 | text_padded, mel_lengths, text_lengths, global_step, phase, idx): 37 | enc_align_fig = plot_alignments(enc_alignments, text_padded, mel_lengths, text_lengths, 'enc', idx) 38 | self.add_figure(f'Alignment/encoder', enc_align_fig, global_step) 39 | 40 | dec_align_fig = plot_alignments(dec_alignments, text_padded, mel_lengths, text_lengths, 'dec', idx) 41 | self.add_figure(f'Alignment/decoder', dec_align_fig, global_step) 42 | 43 | enc_dec_align_fig = plot_alignments(enc_dec_alignments, text_padded, mel_lengths, text_lengths, 'enc_dec', idx) 44 | self.add_figure(f'Alignment/encoder_decoder', enc_dec_align_fig, global_step) 45 | 46 | def add_gates(self, gate_out, global_step, phase): 47 | gate_fig = plot_gate(gate_out) 48 | self.add_figure(f'Plot/gate_out', gate_fig, global_step) 49 | 50 | def plot_attn(logger, soft_A, hard_A, iteration): 51 | logger.add_figure("Alignment/soft_A", plot_alignment(soft_A.data.cpu().numpy()), iteration) 52 | logger.add_figure("Alignment/hard_A", plot_alignment(hard_A.data.cpu().numpy()), iteration) 53 | 54 | 55 | 56 | --------------------------------------------------------------------------------
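For reference, a minimal sketch of driving the TTSWriter from utils/writer.py on its own, e.g. to smoke-test TensorBoard logging. It assumes the script is run from the repository root (so that hparams and the text package resolve), and the directory names and all loss values below are dummy placeholders, not values produced by the model.

```
from utils.writer import get_writer

# get_writer creates a TensorBoard SummaryWriter at 'outdir/logdir' (hypothetical paths).
writer = get_writer('outdir', 'logdir')

# Log one set of dummy scalar losses, matching the add_losses signature above.
writer.add_losses(total_loss=1.23, mel_loss=0.45, bce_loss=0.05, guide_loss=0.10,
                  ctc_loss=0.30, bin_loss=0.08, emo_loss=0.25,
                  global_step=100, phase='Train', f0_loss=0.12)
writer.add_lr(1e-4, 100, 'Train')
writer.close()
```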