├── .gitignore ├── data ├── hifigan ├── audio ├── __init__.py ├── __pycache__ │ ├── stft.cpython-38.pyc │ ├── tools.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ └── audio_processing.cpython-38.pyc ├── tools.py ├── audio_processing.py └── stft.py ├── output └── result │ └── Smart │ ├── 1.wav │ ├── 2.wav │ ├── 3.wav │ ├── 4.wav │ └── 5.wav ├── transformer ├── __init__.py ├── __pycache__ │ ├── Modules.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ └── SubLayers.cpython-38.pyc ├── Constants.py ├── Modules.py ├── SubLayers.py ├── Layers.py └── Models.py ├── text ├── transformer │ ├── __init__.py │ ├── __pycache__ │ │ ├── Layers.cpython-38.pyc │ │ ├── Models.cpython-38.pyc │ │ ├── Modules.cpython-38.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── Constants.cpython-38.pyc │ │ └── SubLayers.cpython-38.pyc │ ├── Constants.py │ ├── Modules.py │ ├── SubLayers.py │ ├── Layers.py │ └── Models.py ├── __pycache__ │ ├── cmudict.cpython-38.pyc │ ├── numbers.cpython-38.pyc │ ├── pinyin.cpython-38.pyc │ ├── symbols.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ └── cleaners.cpython-38.pyc ├── utils │ ├── __pycache__ │ │ ├── model.cpython-38.pyc │ │ └── tools.cpython-38.pyc │ ├── model.py │ └── tools.py ├── symbols.py ├── numbers.py ├── __init__.py ├── cleaners.py ├── cmudict.py └── pinyin.py ├── model ├── __pycache__ │ ├── loss.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── modules.cpython-38.pyc │ ├── optimizer.cpython-38.pyc │ └── fastspeech2_align.cpython-38.pyc ├── __init__.py ├── optimizer.py ├── fastspeech2_align.py ├── modules.py └── loss.py ├── utils ├── __pycache__ │ ├── model.cpython-38.pyc │ └── tools.cpython-38.pyc ├── model.py └── tools.py ├── preprocessor ├── __pycache__ │ ├── ljspeech.cpython-38.pyc │ └── preprocessor.cpython-38.pyc ├── ljspeech.py └── preprocessor.py ├── requirements.txt ├── preprocess.py ├── config └── LJSpeech │ ├── train.yaml │ ├── model.yaml │ └── preprocess.yaml ├── README.md ├── synthesize.py ├── train.py ├── dataset.py └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data: -------------------------------------------------------------------------------- 1 | ../../Pytorch/FastSpeech2/raw_data -------------------------------------------------------------------------------- /hifigan: -------------------------------------------------------------------------------- 1 | ../../Pytorch/FastSpeech2/hifigan -------------------------------------------------------------------------------- /audio/__init__.py: -------------------------------------------------------------------------------- 1 | import audio.tools 2 | import audio.stft 3 | import audio.audio_processing 4 | -------------------------------------------------------------------------------- /output/result/Smart/1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/output/result/Smart/1.wav -------------------------------------------------------------------------------- /output/result/Smart/2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/output/result/Smart/2.wav -------------------------------------------------------------------------------- /output/result/Smart/3.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/output/result/Smart/3.wav -------------------------------------------------------------------------------- /output/result/Smart/4.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/output/result/Smart/4.wav -------------------------------------------------------------------------------- /output/result/Smart/5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/output/result/Smart/5.wav -------------------------------------------------------------------------------- /transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .Models import TxtEncoder, MelEncoder, MelDecoder 2 | from .Layers import PostNet, Prenet -------------------------------------------------------------------------------- /text/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .Models import TxtEncoder, MelEncoder, MelDecoder 2 | from .Layers import PostNet, Prenet -------------------------------------------------------------------------------- /audio/__pycache__/stft.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/audio/__pycache__/stft.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/model/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /audio/__pycache__/tools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/audio/__pycache__/tools.cpython-38.pyc -------------------------------------------------------------------------------- /text/__pycache__/cmudict.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/text/__pycache__/cmudict.cpython-38.pyc -------------------------------------------------------------------------------- /text/__pycache__/numbers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/text/__pycache__/numbers.cpython-38.pyc -------------------------------------------------------------------------------- /text/__pycache__/pinyin.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/text/__pycache__/pinyin.cpython-38.pyc -------------------------------------------------------------------------------- /text/__pycache__/symbols.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/text/__pycache__/symbols.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/model.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/utils/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/tools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/utils/__pycache__/tools.cpython-38.pyc -------------------------------------------------------------------------------- /audio/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/audio/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .fastspeech2_align import FastSpeech2Align 2 | from .loss import FastSpeech2Loss 3 | from .optimizer import ScheduledOptim -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/model/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/optimizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/model/__pycache__/optimizer.cpython-38.pyc -------------------------------------------------------------------------------- /text/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/text/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /text/__pycache__/cleaners.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/text/__pycache__/cleaners.cpython-38.pyc -------------------------------------------------------------------------------- /text/utils/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/text/utils/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /text/utils/__pycache__/tools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/text/utils/__pycache__/tools.cpython-38.pyc -------------------------------------------------------------------------------- /transformer/__pycache__/Modules.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/transformer/__pycache__/Modules.cpython-38.pyc -------------------------------------------------------------------------------- /transformer/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/transformer/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /audio/__pycache__/audio_processing.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/audio/__pycache__/audio_processing.cpython-38.pyc -------------------------------------------------------------------------------- /preprocessor/__pycache__/ljspeech.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/preprocessor/__pycache__/ljspeech.cpython-38.pyc -------------------------------------------------------------------------------- /transformer/__pycache__/SubLayers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/transformer/__pycache__/SubLayers.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/fastspeech2_align.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/model/__pycache__/fastspeech2_align.cpython-38.pyc -------------------------------------------------------------------------------- /preprocessor/__pycache__/preprocessor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/preprocessor/__pycache__/preprocessor.cpython-38.pyc -------------------------------------------------------------------------------- /text/transformer/__pycache__/Layers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/text/transformer/__pycache__/Layers.cpython-38.pyc -------------------------------------------------------------------------------- /text/transformer/__pycache__/Models.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/text/transformer/__pycache__/Models.cpython-38.pyc -------------------------------------------------------------------------------- /text/transformer/__pycache__/Modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/text/transformer/__pycache__/Modules.cpython-38.pyc -------------------------------------------------------------------------------- /text/transformer/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/text/transformer/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- 
/text/transformer/__pycache__/Constants.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/text/transformer/__pycache__/Constants.cpython-38.pyc -------------------------------------------------------------------------------- /text/transformer/__pycache__/SubLayers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SMART-TTS/SMART-NAR_Fast_TTS/HEAD/text/transformer/__pycache__/SubLayers.cpython-38.pyc -------------------------------------------------------------------------------- /transformer/Constants.py: -------------------------------------------------------------------------------- 1 | PAD = 0 2 | UNK = 1 3 | BOS = 2 4 | EOS = 3 5 | 6 | PAD_WORD = "" 7 | UNK_WORD = "" 8 | BOS_WORD = "" 9 | EOS_WORD = "" 10 | -------------------------------------------------------------------------------- /text/transformer/Constants.py: -------------------------------------------------------------------------------- 1 | PAD = 0 2 | UNK = 1 3 | BOS = 2 4 | EOS = 3 5 | 6 | PAD_WORD = "" 7 | UNK_WORD = "" 8 | BOS_WORD = "" 9 | EOS_WORD = "" 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | g2p-en == 2.1.0 2 | inflect == 4.1.0 3 | librosa == 0.7.2 4 | matplotlib == 3.2.2 5 | numba == 0.48 6 | numpy == 1.19.0 7 | pypinyin == 0.39.0 8 | pyworld == 0.2.10 9 | PyYAML == 5.4.1 10 | scikit-learn == 0.23.2 11 | scipy == 1.5.0 12 | soundfile == 0.10.3.post1 13 | tensorboard == 2.2.2 14 | tgt == 1.4.4 15 | torch == 1.7.0 16 | tqdm == 4.46.1 17 | unidecode == 1.1.1 18 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import yaml 4 | 5 | from preprocessor.preprocessor import Preprocessor 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("config", type=str, help="path to preprocess.yaml") 10 | args = parser.parse_args() 11 | 12 | config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader) 13 | preprocessor = Preprocessor(config) 14 | preprocessor.build_from_path() 15 | -------------------------------------------------------------------------------- /config/LJSpeech/train.yaml: -------------------------------------------------------------------------------- 1 | path: 2 | ckpt_path: "./output/ckpt/LJSpeech" 3 | log_path: "./output/log/LJSpeech" 4 | result_path: "./output/result/LJSpeech" 5 | optimizer: 6 | batch_size: 48 7 | betas: [0.9, 0.98] 8 | eps: 0.000000001 9 | weight_decay: 0.0 10 | grad_clip_thresh: 1.0 11 | grad_acc_step: 1 12 | warm_up_step: 4000 13 | anneal_steps: [] 14 | anneal_rate: 1.0 15 | step: 16 | total_step: 160000 17 | log_step: 100 18 | synth_step: 100 19 | val_step: 100 20 | save_step: 10000 21 | -------------------------------------------------------------------------------- /transformer/Modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class ScaledDotProductAttention(nn.Module): 7 | """ Scaled Dot-Product Attention """ 8 | 9 | def __init__(self, temperature): 10 | super().__init__() 11 | self.temperature = temperature 12 | self.softmax = 
nn.Softmax(dim=2) 13 | 14 | def forward(self, q, k, v, mask=None): 15 | 16 | attn = torch.bmm(q, k.transpose(1, 2)) 17 | attn = attn / self.temperature 18 | 19 | if mask is not None: 20 | attn = attn.masked_fill(mask, -np.inf) 21 | 22 | attn = self.softmax(attn) 23 | output = torch.bmm(attn, v) 24 | 25 | return output, attn 26 | -------------------------------------------------------------------------------- /text/transformer/Modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class ScaledDotProductAttention(nn.Module): 7 | """ Scaled Dot-Product Attention """ 8 | 9 | def __init__(self, temperature): 10 | super().__init__() 11 | self.temperature = temperature 12 | self.softmax = nn.Softmax(dim=2) 13 | 14 | def forward(self, q, k, v, mask=None): 15 | 16 | attn = torch.bmm(q, k.transpose(1, 2)) 17 | attn = attn / self.temperature 18 | 19 | if mask is not None: 20 | attn = attn.masked_fill(mask, -np.inf) 21 | 22 | attn = self.softmax(attn) 23 | output = torch.bmm(attn, v) 24 | 25 | return output, attn 26 | -------------------------------------------------------------------------------- /config/LJSpeech/model.yaml: -------------------------------------------------------------------------------- 1 | transformer: 2 | encoder_layer: 4 3 | encoder_head: 2 4 | encoder_hidden: 256 5 | decoder_layer: 4 6 | decoder_head: 2 7 | decoder_hidden: 256 8 | conv_filter_size: 1024 9 | conv_kernel_size: [9, 1] 10 | encoder_dropout: 0.2 11 | decoder_dropout: 0.2 12 | 13 | variance_predictor: 14 | filter_size: 256 15 | kernel_size: 3 16 | dropout: 0.5 17 | 18 | variance_embedding: 19 | pitch_quantization: "log" # support 'linear' or 'log', 'log' is allowed only if the pitch values are not normalized during preprocessing 20 | energy_quantization: "linear" # support 'linear' or 'log', 'log' is allowed only if the energy values are not normalized during preprocessing 21 | n_bins: 256 22 | 23 | multi_speaker: False 24 | 25 | max_seq_len: 1000 26 | 27 | vocoder: 28 | model: "HiFi-GAN" # support 'HiFi-GAN', 'MelGAN' 29 | speaker: "LJSpeech" # support 'LJSpeech', 'universal' 30 | -------------------------------------------------------------------------------- /config/LJSpeech/preprocess.yaml: -------------------------------------------------------------------------------- 1 | dataset: "LJSpeech" 2 | 3 | path: 4 | lexicon_path: "lexicon/librispeech-lexicon.txt" 5 | data_path: "data/LJSpeech" 6 | preprocessed_path: "./preprocessed_data/LJSpeech" 7 | 8 | preprocessing: 9 | val_size: 512 10 | text: 11 | text_cleaners: ["english_cleaners"] 12 | language: "en" 13 | phone: 14 | use_hierarchical_aligner: true 15 | spm_level: [64, 4096] 16 | audio: 17 | sampling_rate: 22050 18 | max_wav_value: 32768.0 19 | stft: 20 | filter_length: 1024 21 | hop_length: 256 22 | win_length: 1024 23 | mel: 24 | n_mel_channels: 80 25 | mel_fmin: 0 26 | mel_fmax: 8000 # please set to 8000 for HiFi-GAN vocoder, set to null for MelGAN vocoder 27 | pitch: 28 | feature: "frame_level" # support 'phoneme_level' or 'frame_level' 29 | normalization: True 30 | energy: 31 | feature: "frame_level" # support 'phoneme_level' or 'frame_level' 32 | normalization: True 33 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Defines the set of symbols 
used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """ 7 | 8 | from text import cmudict, pinyin 9 | 10 | _pad = "_" 11 | _punctuation = "!'(),.:;? " 12 | _special = "-" 13 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 14 | _silences = ["@sp", "@spn", "@sil"] 15 | 16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 17 | _arpabet = ["@" + s for s in cmudict.valid_symbols] 18 | _pinyin = ["@" + s for s in pinyin.valid_symbols] 19 | 20 | # Export all symbols: 21 | symbols = ( 22 | [_pad] 23 | + list(_special) 24 | + list(_punctuation) 25 | + list(_letters) 26 | + _arpabet 27 | + _pinyin 28 | + _silences 29 | ) 30 | -------------------------------------------------------------------------------- /audio/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.io.wavfile import write 4 | 5 | from audio.audio_processing import griffin_lim 6 | 7 | 8 | def get_mel_from_wav(audio, _stft): 9 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 10 | audio = torch.autograd.Variable(audio, requires_grad=False) 11 | melspec, energy = _stft.mel_spectrogram(audio) 12 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) 13 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32) 14 | 15 | return melspec, energy 16 | 17 | 18 | def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60): 19 | mel = torch.stack([mel]) 20 | mel_decompress = _stft.spectral_de_normalize(mel) 21 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu() 22 | spec_from_mel_scaling = 1000 23 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) 24 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) 25 | spec_from_mel = spec_from_mel * spec_from_mel_scaling 26 | 27 | audio = griffin_lim( 28 | torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters 29 | ) 30 | 31 | audio = audio.squeeze() 32 | audio = audio.cpu().numpy() 33 | audio_path = out_filename 34 | write(audio_path, _stft.sampling_rate, audio) 35 | -------------------------------------------------------------------------------- /preprocessor/ljspeech.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | from scipy.io import wavfile 6 | from tqdm import tqdm 7 | 8 | from text import _clean_text 9 | 10 | 11 | def prepare_align(config): 12 | in_dir = config["path"]["corpus_path"] 13 | out_dir = config["path"]["raw_path"] 14 | sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 15 | max_wav_value = config["preprocessing"]["audio"]["max_wav_value"] 16 | cleaners = config["preprocessing"]["text"]["text_cleaners"] 17 | speaker = "LJSpeech" 18 | with open(os.path.join(in_dir, "metadata.csv"), encoding="utf-8") as f: 19 | for line in tqdm(f): 20 | parts = line.strip().split("|") 21 | base_name = parts[0] 22 | text = parts[2] 23 | text = _clean_text(text, cleaners) 24 | 25 | wav_path = os.path.join(in_dir, "wavs", "{}.wav".format(base_name)) 26 | if os.path.exists(wav_path): 27 | os.makedirs(os.path.join(out_dir, speaker), exist_ok=True) 28 | wav, _ = librosa.load(wav_path, sampling_rate) 29 | wav = wav / max(abs(wav)) * max_wav_value 30 | wavfile.write( 31 | 
os.path.join(out_dir, speaker, "{}.wav".format(base_name)), 32 | sampling_rate, 33 | wav.astype(np.int16), 34 | ) 35 | with open( 36 | os.path.join(out_dir, speaker, "{}.lab".format(base_name)), 37 | "w", 38 | ) as f1: 39 | f1.write(text) 40 | -------------------------------------------------------------------------------- /model/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ScheduledOptim: 6 | """ A simple wrapper class for learning rate scheduling """ 7 | 8 | def __init__(self, model, train_config, model_config, current_step): 9 | 10 | self._optimizer = torch.optim.Adam( 11 | model.parameters(), 12 | betas=train_config["optimizer"]["betas"], 13 | eps=train_config["optimizer"]["eps"], 14 | weight_decay=train_config["optimizer"]["weight_decay"], 15 | ) 16 | self.n_warmup_steps = train_config["optimizer"]["warm_up_step"] 17 | self.anneal_steps = train_config["optimizer"]["anneal_steps"] 18 | self.anneal_rate = train_config["optimizer"]["anneal_rate"] 19 | self.current_step = current_step 20 | self.init_lr = np.power(model_config["transformer"]["encoder_hidden"], -0.5) 21 | 22 | def step_and_update_lr(self): 23 | self._update_learning_rate() 24 | self._optimizer.step() 25 | 26 | def zero_grad(self): 27 | # print(self.init_lr) 28 | self._optimizer.zero_grad() 29 | 30 | def load_state_dict(self, path): 31 | self._optimizer.load_state_dict(path) 32 | 33 | def _get_lr_scale(self): 34 | lr = np.min( 35 | [ 36 | np.power(self.current_step, -0.5), 37 | np.power(self.n_warmup_steps, -1.5) * self.current_step, 38 | ] 39 | ) 40 | for s in self.anneal_steps: 41 | if self.current_step > s: 42 | lr = lr * self.anneal_rate 43 | return lr 44 | 45 | def _update_learning_rate(self): 46 | """ Learning rate scheduling per step """ 47 | self.current_step += 1 48 | lr = self.init_lr * self._get_lr_scale() 49 | 50 | for param_group in self._optimizer.param_groups: 51 | param_group["lr"] = lr 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SMART-NAR_Fast_TTS 2 | FastSpeech2 기반의 SMART-TTS의 Non-autoregressive TTS 모델입니다. 공개된 코드는 2021년도 과학기술통신부의 재원으로 정보통신기획평가원(IITP)의 지원을 받아 수행한 "소량 데이터만을 이용한 고품질 종단형 기반의 딥러닝 다화자 운율 및 감정 복제 기술 개발" 과제의 일환으로 공개된 코드입니다. 3 | 4 | SMART-TTS_NAR_Fast_TTS 모델 v2.0.0 은 [FastSpeech2 모델](https://github.com/ming024/FastSpeech2)을 기반으로 alignment를 external duration label 없이 모델링하는 non-autoregressive 구조의 TTS 모델입니다. 5 | 6 | FastSpeeche2 모델을 기반으로 하여 아래 부분들을 개선하였습니다. 7 | 8 | Done 9 | * Acoustic feature 를 encoding 하는 reference encoder 추가 10 | * Linguistic feature 와 acoustic feature 사이의 alignment를 학습하기 위한 attention module 추가 11 | * Alignment 로부터 duration predictor 학습을 위한 duration label 추출 12 | * Predicted duration 을 기반으로 Gaussian upsampling 적용 13 | 14 | # Environment 15 | Under Python 3.6 16 | 17 | # Requirements 18 | To install requirements: 19 |
20 | 
21 | pip install -r requirements.txt
22 | 
23 | 
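The pinned package versions are fairly old, so installing them into an isolated virtual environment is one way to avoid conflicts with an existing setup, e.g.:

python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt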
24 | 25 | # Preprocessing 26 | To preprocess: 27 |
28 | 
29 | python3 preprocess.py {preprocess config file path}
30 | 
31 | 
32 | 
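For example, using the LJSpeech configuration shipped in this repository (path relative to the project root, assuming the data and lexicon paths referenced in the YAML exist on your machine):

python3 preprocess.py config/LJSpeech/preprocess.yaml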
33 | 34 | # Training 35 | To train the NAR TTS model, run this command: 36 |
37 | 
38 | python3 train.py -p {preprocess config file path} -m {model config file path} -t {training config file path}
39 | 
40 | 
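For example, with the provided LJSpeech configuration files (checkpoints, logs, and results are written to the paths set in config/LJSpeech/train.yaml):

python3 train.py -p config/LJSpeech/preprocess.yaml -m config/LJSpeech/model.yaml -t config/LJSpeech/train.yaml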
41 | 42 | # Evaluation 43 | To evaluate, run: 44 |
45 | 
46 | python3 synthesize.py --text {input text} --restore_step {restore step} -p {preprocess config file path} -m {model config file path} -t {training config file path}
47 | 
48 | 
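For example, assuming a checkpoint saved at step 160000 exists under ./output/ckpt/LJSpeech (the checkpoint directory set in train.yaml):

python3 synthesize.py --text "This is a test sentence." --restore_step 160000 -p config/LJSpeech/preprocess.yaml -m config/LJSpeech/model.yaml -t config/LJSpeech/train.yaml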
49 | 50 | # Results 51 | Synthesized audio samples can be found in ./output/results 52 | 53 | 현재 ./output/results 저장된 샘플들은 연구실 보유중인 DB를 사용해 학습한 샘플입니다. 54 | 55 | # Reference 56 | * <1> [ming024's FastSpeech 2 implementation](https://github.com/ming024/FastSpeech2) 57 | * <2> [FastSpeech 2: Fast and High-Quality End-to-End Text to Speech](https://arxiv.org/abs/2006.04558), Y. Ren, *et al*. 58 | 59 | # Technical Document 60 | 본 프로젝트 관련 개선사항들에 대한 기술문서는 [여기](https://drive.google.com/file/d/1G1Afg1PwdW5TQcXuTKeZDZ_67XTlchiY/view?usp=sharing)를 참고해 주세요. 61 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return "%s %s" % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return "%s %s" % (cents, cent_unit) 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words( 60 | num, andword="", zero="oh", group=2 61 | ).replace(", ", " ") 62 | else: 63 | return _inflect.number_to_words(num, andword="") 64 | 65 | 66 | def normalize_numbers(text): 67 | text = re.sub(_comma_number_re, _remove_commas, text) 68 | text = re.sub(_pounds_re, r"\1 pounds", text) 69 | text = re.sub(_dollars_re, _expand_dollars, text) 70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 71 | text = re.sub(_ordinal_re, _expand_ordinal, text) 72 | text = re.sub(_number_re, _expand_number, text) 73 | return text 74 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | from text.symbols import symbols 5 | 6 | 7 | # Mappings from symbol to numeric ID and vice versa: 8 | _symbol_to_id = {s: i for i, s in 
enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | # Regular expression matching text enclosed in curly braces: 12 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") 13 | 14 | 15 | def text_to_sequence(text, cleaner_names): 16 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 17 | 18 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 19 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 20 | 21 | Args: 22 | text: string to convert to a sequence 23 | cleaner_names: names of the cleaner functions to run the text through 24 | 25 | Returns: 26 | List of integers corresponding to the symbols in the text 27 | """ 28 | sequence = [] 29 | 30 | # Check for curly braces and treat their contents as ARPAbet: 31 | while len(text): 32 | m = _curly_re.match(text) 33 | 34 | if "korean_cleaners" in cleaner_names: 35 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 36 | break 37 | 38 | if not m: 39 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 40 | break 41 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 42 | sequence += _arpabet_to_sequence(m.group(2)) 43 | text = m.group(3) 44 | 45 | return sequence 46 | 47 | 48 | def sequence_to_text(sequence): 49 | """Converts a sequence of IDs back to a string""" 50 | result = "" 51 | for symbol_id in sequence: 52 | if symbol_id in _id_to_symbol: 53 | s = _id_to_symbol[symbol_id] 54 | # Enclose ARPAbet back in curly braces: 55 | if len(s) > 1 and s[0] == "@": 56 | s = "{%s}" % s[1:] 57 | result += s 58 | return result.replace("}{", " ") 59 | 60 | 61 | def _clean_text(text, cleaner_names): 62 | for name in cleaner_names: 63 | cleaner = getattr(cleaners, name) 64 | if not cleaner: 65 | raise Exception("Unknown cleaner: %s" % name) 66 | text = cleaner(text) 67 | return text 68 | 69 | 70 | def _symbols_to_sequence(symbols): 71 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 72 | 73 | 74 | def _arpabet_to_sequence(text): 75 | return _symbols_to_sequence(["@" + s for s in text.split()]) 76 | 77 | 78 | def _should_keep_symbol(s): 79 | return s in _symbol_to_id and s != "_" and s != "~" 80 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | 16 | # Regular expression matching whitespace: 17 | import re 18 | from unidecode import unidecode 19 | from .numbers import normalize_numbers 20 | _whitespace_re = re.compile(r'\s+') 21 | 22 | # List of (regular expression, replacement) pairs for abbreviations: 23 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 24 | ('mrs', 'misess'), 25 | ('mr', 'mister'), 26 | ('dr', 'doctor'), 27 | ('st', 'saint'), 28 | ('co', 'company'), 29 | ('jr', 'junior'), 30 | ('maj', 'major'), 31 | ('gen', 'general'), 32 | ('drs', 'doctors'), 33 | ('rev', 'reverend'), 34 | ('lt', 'lieutenant'), 35 | ('hon', 'honorable'), 36 | ('sgt', 'sergeant'), 37 | ('capt', 'captain'), 38 | ('esq', 'esquire'), 39 | ('ltd', 'limited'), 40 | ('col', 'colonel'), 41 | ('ft', 'fort'), 42 | ]] 43 | 44 | 45 | def expand_abbreviations(text): 46 | for regex, replacement in _abbreviations: 47 | text = re.sub(regex, replacement, text) 48 | return text 49 | 50 | 51 | def expand_numbers(text): 52 | return normalize_numbers(text) 53 | 54 | 55 | def lowercase(text): 56 | return text.lower() 57 | 58 | 59 | def collapse_whitespace(text): 60 | return re.sub(_whitespace_re, ' ', text) 61 | 62 | 63 | def convert_to_ascii(text): 64 | return unidecode(text) 65 | 66 | 67 | def basic_cleaners(text): 68 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 69 | text = lowercase(text) 70 | text = collapse_whitespace(text) 71 | return text 72 | 73 | 74 | def transliteration_cleaners(text): 75 | '''Pipeline for non-English text that transliterates to ASCII.''' 76 | text = convert_to_ascii(text) 77 | text = lowercase(text) 78 | text = collapse_whitespace(text) 79 | return text 80 | 81 | 82 | def english_cleaners(text): 83 | '''Pipeline for English text, including number and abbreviation expansion.''' 84 | text = convert_to_ascii(text) 85 | text = lowercase(text) 86 | text = expand_numbers(text) 87 | text = expand_abbreviations(text) 88 | text = collapse_whitespace(text) 89 | return text 90 | -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import numpy as np 6 | 7 | import hifigan 8 | from model import FastSpeech2Align, ScheduledOptim 9 | 10 | 11 | def get_model(args, configs, device, train=False): 12 | (preprocess_config, model_config, train_config) = configs 13 | 14 | model = FastSpeech2Align(preprocess_config, model_config).to(device) 15 | 16 | if args.restore_step: 17 | ckpt_path = os.path.join( 18 | train_config["path"]["ckpt_path"], 19 | "{}.pth.tar".format(args.restore_step), 20 | ) 21 | ckpt = torch.load(ckpt_path) 22 | model.load_state_dict(ckpt["model"]) 23 | 24 | if train: 25 | scheduled_optim = ScheduledOptim( 26 | model, train_config, model_config, args.restore_step 27 | ) 28 | if args.restore_step: 29 | scheduled_optim.load_state_dict(ckpt["optimizer"]) 30 | model.train() 31 | return model, scheduled_optim 32 | 33 | model.eval() 34 | model.requires_grad_ = False 35 | return model 36 | 37 | 38 | def get_vocoder(config, device): 39 | name = config["vocoder"]["model"] 40 | speaker = config["vocoder"]["speaker"] 41 | 42 | if name == "MelGAN": 43 | if speaker == "LJSpeech": 44 | vocoder = torch.hub.load( 45 | "descriptinc/melgan-neurips", "load_melgan", "linda_johnson" 46 | ) 47 | elif speaker == "universal": 48 | vocoder = torch.hub.load( 49 | "descriptinc/melgan-neurips", "load_melgan", "multi_speaker" 50 | ) 51 | vocoder.mel2wav.eval() 52 | vocoder.mel2wav.to(device) 53 | elif name == "HiFi-GAN": 54 | with open("hifigan/config.json", "r") as f: 55 | config = json.load(f) 56 | config = hifigan.AttrDict(config) 57 | vocoder = hifigan.Generator(config) 58 | if speaker == "LJSpeech": 59 | ckpt = 
torch.load("hifigan/generator_LJSpeech.pth.tar") 60 | elif speaker == "universal": 61 | ckpt = torch.load("hifigan/generator_universal.pth.tar") 62 | vocoder.load_state_dict(ckpt["generator"]) 63 | vocoder.eval() 64 | vocoder.remove_weight_norm() 65 | vocoder.to(device) 66 | 67 | return vocoder 68 | 69 | 70 | def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None): 71 | name = model_config["vocoder"]["model"] 72 | with torch.no_grad(): 73 | if name == "MelGAN": 74 | wavs = vocoder.inverse(mels / np.log(10)) 75 | elif name == "HiFi-GAN": 76 | wavs = vocoder(mels).squeeze(1) 77 | 78 | wavs = ( 79 | wavs.cpu().numpy() 80 | * preprocess_config["preprocessing"]["audio"]["max_wav_value"] 81 | ).astype("int16") 82 | wavs = [wav for wav in wavs] 83 | 84 | for i in range(len(mels)): 85 | if lengths is not None: 86 | wavs[i] = wavs[i][: lengths[i]] 87 | 88 | return wavs 89 | -------------------------------------------------------------------------------- /text/utils/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import numpy as np 6 | 7 | import hifigan 8 | from model import FastSpeech2Align, ScheduledOptim 9 | 10 | 11 | def get_model(args, configs, device, train=False): 12 | (preprocess_config, model_config, train_config) = configs 13 | 14 | model = FastSpeech2Align(preprocess_config, model_config).to(device) 15 | 16 | if args.restore_step: 17 | ckpt_path = os.path.join( 18 | train_config["path"]["ckpt_path"], 19 | "{}.pth.tar".format(args.restore_step), 20 | ) 21 | ckpt = torch.load(ckpt_path) 22 | model.load_state_dict(ckpt["model"]) 23 | 24 | if train: 25 | scheduled_optim = ScheduledOptim( 26 | model, train_config, model_config, args.restore_step 27 | ) 28 | if args.restore_step: 29 | scheduled_optim.load_state_dict(ckpt["optimizer"]) 30 | model.train() 31 | return model, scheduled_optim 32 | 33 | model.eval() 34 | model.requires_grad_ = False 35 | return model 36 | 37 | 38 | def get_param_num(model): 39 | num_param = sum(param.numel() for param in model.parameters()) 40 | return num_param 41 | 42 | 43 | def get_vocoder(config, device): 44 | name = config["vocoder"]["model"] 45 | speaker = config["vocoder"]["speaker"] 46 | 47 | if name == "MelGAN": 48 | if speaker == "LJSpeech": 49 | vocoder = torch.hub.load( 50 | "descriptinc/melgan-neurips", "load_melgan", "linda_johnson" 51 | ) 52 | elif speaker == "universal": 53 | vocoder = torch.hub.load( 54 | "descriptinc/melgan-neurips", "load_melgan", "multi_speaker" 55 | ) 56 | vocoder.mel2wav.eval() 57 | vocoder.mel2wav.to(device) 58 | elif name == "HiFi-GAN": 59 | with open("hifigan/config.json", "r") as f: 60 | config = json.load(f) 61 | config = hifigan.AttrDict(config) 62 | vocoder = hifigan.Generator(config) 63 | if speaker == "LJSpeech": 64 | ckpt = torch.load("hifigan/generator_LJSpeech.pth.tar") 65 | elif speaker == "universal": 66 | ckpt = torch.load("hifigan/generator_universal.pth.tar") 67 | vocoder.load_state_dict(ckpt["generator"]) 68 | vocoder.eval() 69 | vocoder.remove_weight_norm() 70 | vocoder.to(device) 71 | 72 | return vocoder 73 | 74 | 75 | def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None): 76 | name = model_config["vocoder"]["model"] 77 | with torch.no_grad(): 78 | if name == "MelGAN": 79 | wavs = vocoder.inverse(mels / np.log(10)) 80 | elif name == "HiFi-GAN": 81 | wavs = vocoder(mels).squeeze(1) 82 | 83 | wavs = ( 84 | wavs.cpu().numpy() 85 | * 
preprocess_config["preprocessing"]["audio"]["max_wav_value"] 86 | ).astype("int16") 87 | wavs = [wav for wav in wavs] 88 | 89 | for i in range(len(mels)): 90 | if lengths is not None: 91 | wavs[i] = wavs[i][: lengths[i]] 92 | 93 | return wavs 94 | -------------------------------------------------------------------------------- /audio/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import librosa.util as librosa_util 4 | from scipy.signal import get_window 5 | 6 | 7 | def window_sumsquare( 8 | window, 9 | n_frames, 10 | hop_length, 11 | win_length, 12 | n_fft, 13 | dtype=np.float32, 14 | norm=None, 15 | ): 16 | """ 17 | # from librosa 0.6 18 | Compute the sum-square envelope of a window function at a given hop length. 19 | 20 | This is used to estimate modulation effects induced by windowing 21 | observations in short-time fourier transforms. 22 | 23 | Parameters 24 | ---------- 25 | window : string, tuple, number, callable, or list-like 26 | Window specification, as in `get_window` 27 | 28 | n_frames : int > 0 29 | The number of analysis frames 30 | 31 | hop_length : int > 0 32 | The number of samples to advance between frames 33 | 34 | win_length : [optional] 35 | The length of the window function. By default, this matches `n_fft`. 36 | 37 | n_fft : int > 0 38 | The length of each analysis frame. 39 | 40 | dtype : np.dtype 41 | The data type of the output 42 | 43 | Returns 44 | ------- 45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 46 | The sum-squared envelope of the window function 47 | """ 48 | if win_length is None: 49 | win_length = n_fft 50 | 51 | n = n_fft + hop_length * (n_frames - 1) 52 | x = np.zeros(n, dtype=dtype) 53 | 54 | # Compute the squared window at the desired length 55 | win_sq = get_window(window, win_length, fftbins=True) 56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 57 | win_sq = librosa_util.pad_center(win_sq, n_fft) 58 | 59 | # Fill the envelope 60 | for i in range(n_frames): 61 | sample = i * hop_length 62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] 63 | return x 64 | 65 | 66 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 67 | """ 68 | PARAMS 69 | ------ 70 | magnitudes: spectrogram magnitudes 71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 72 | """ 73 | 74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 75 | angles = angles.astype(np.float32) 76 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 78 | 79 | for i in range(n_iters): 80 | _, angles = stft_fn.transform(signal) 81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 82 | return signal 83 | 84 | 85 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 86 | """ 87 | PARAMS 88 | ------ 89 | C: compression factor 90 | """ 91 | return torch.log(torch.clamp(x, min=clip_val) * C) 92 | 93 | 94 | def dynamic_range_decompression(x, C=1): 95 | """ 96 | PARAMS 97 | ------ 98 | C: compression factor used to compress 99 | """ 100 | return torch.exp(x) / C 101 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | "AA", 8 | "AA0", 9 | "AA1", 10 | "AA2", 11 | "AE", 12 | "AE0", 13 | 
"AE1", 14 | "AE2", 15 | "AH", 16 | "AH0", 17 | "AH1", 18 | "AH2", 19 | "AO", 20 | "AO0", 21 | "AO1", 22 | "AO2", 23 | "AW", 24 | "AW0", 25 | "AW1", 26 | "AW2", 27 | "AY", 28 | "AY0", 29 | "AY1", 30 | "AY2", 31 | "B", 32 | "CH", 33 | "D", 34 | "DH", 35 | "EH", 36 | "EH0", 37 | "EH1", 38 | "EH2", 39 | "ER", 40 | "ER0", 41 | "ER1", 42 | "ER2", 43 | "EY", 44 | "EY0", 45 | "EY1", 46 | "EY2", 47 | "F", 48 | "G", 49 | "HH", 50 | "IH", 51 | "IH0", 52 | "IH1", 53 | "IH2", 54 | "IY", 55 | "IY0", 56 | "IY1", 57 | "IY2", 58 | "JH", 59 | "K", 60 | "L", 61 | "M", 62 | "N", 63 | "NG", 64 | "OW", 65 | "OW0", 66 | "OW1", 67 | "OW2", 68 | "OY", 69 | "OY0", 70 | "OY1", 71 | "OY2", 72 | "P", 73 | "R", 74 | "S", 75 | "SH", 76 | "T", 77 | "TH", 78 | "UH", 79 | "UH0", 80 | "UH1", 81 | "UH2", 82 | "UW", 83 | "UW0", 84 | "UW1", 85 | "UW2", 86 | "V", 87 | "W", 88 | "Y", 89 | "Z", 90 | "ZH", 91 | ] 92 | 93 | _valid_symbol_set = set(valid_symbols) 94 | 95 | 96 | class CMUDict: 97 | """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" 98 | 99 | def __init__(self, file_or_path, keep_ambiguous=True): 100 | if isinstance(file_or_path, str): 101 | with open(file_or_path, encoding="latin-1") as f: 102 | entries = _parse_cmudict(f) 103 | else: 104 | entries = _parse_cmudict(file_or_path) 105 | if not keep_ambiguous: 106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 107 | self._entries = entries 108 | 109 | def __len__(self): 110 | return len(self._entries) 111 | 112 | def lookup(self, word): 113 | """Returns list of ARPAbet pronunciations of the given word.""" 114 | return self._entries.get(word.upper()) 115 | 116 | 117 | _alt_re = re.compile(r"\([0-9]+\)") 118 | 119 | 120 | def _parse_cmudict(file): 121 | cmudict = {} 122 | for line in file: 123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): 124 | parts = line.split(" ") 125 | word = re.sub(_alt_re, "", parts[0]) 126 | pronunciation = _get_pronunciation(parts[1]) 127 | if pronunciation: 128 | if word in cmudict: 129 | cmudict[word].append(pronunciation) 130 | else: 131 | cmudict[word] = [pronunciation] 132 | return cmudict 133 | 134 | 135 | def _get_pronunciation(s): 136 | parts = s.strip().split(" ") 137 | for part in parts: 138 | if part not in _valid_symbol_set: 139 | return None 140 | return " ".join(parts) 141 | -------------------------------------------------------------------------------- /transformer/SubLayers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from .Modules import ScaledDotProductAttention 6 | 7 | 8 | class MultiHeadAttention(nn.Module): 9 | """ Multi-Head Attention module """ 10 | 11 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 12 | super().__init__() 13 | 14 | self.n_head = n_head 15 | self.d_k = d_k 16 | self.d_v = d_v 17 | 18 | self.w_qs = nn.Linear(d_model, n_head * d_k) 19 | self.w_ks = nn.Linear(d_model, n_head * d_k) 20 | self.w_vs = nn.Linear(d_model, n_head * d_v) 21 | 22 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 23 | self.layer_norm = nn.LayerNorm(d_model) 24 | 25 | self.fc = nn.Linear(n_head * d_v, d_model) 26 | 27 | self.dropout = nn.Dropout(dropout) 28 | 29 | def forward(self, q, k, v, mask=None): 30 | 31 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 32 | 33 | sz_b, len_q, _ = q.size() 34 | sz_b, len_k, _ = k.size() 35 | sz_b, len_v, _ = v.size() 36 | 37 | 
residual = q 38 | 39 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 40 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 41 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 42 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 43 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 44 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 45 | 46 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 47 | output, attn = self.attention(q, k, v, mask=mask) 48 | attn = attn.view(n_head, sz_b, attn.shape[1], attn.shape[2]) 49 | attn = attn.transpose(0, 1) 50 | 51 | output = output.view(n_head, sz_b, len_q, d_v) 52 | output = ( 53 | output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) 54 | ) # b x lq x (n*dv) 55 | 56 | output = self.dropout(self.fc(output)) 57 | output = self.layer_norm(output + residual) 58 | 59 | return output, attn 60 | 61 | 62 | class PositionwiseFeedForward(nn.Module): 63 | """ A two-feed-forward-layer module """ 64 | 65 | def __init__(self, d_in, d_hid, kernel_size, dropout=0.1): 66 | super().__init__() 67 | 68 | # Use Conv1D 69 | # position-wise 70 | self.w_1 = nn.Conv1d( 71 | d_in, 72 | d_hid, 73 | kernel_size=kernel_size[0], 74 | padding=(kernel_size[0] - 1) // 2, 75 | ) 76 | # position-wise 77 | self.w_2 = nn.Conv1d( 78 | d_hid, 79 | d_in, 80 | kernel_size=kernel_size[1], 81 | padding=(kernel_size[1] - 1) // 2, 82 | ) 83 | 84 | self.layer_norm = nn.LayerNorm(d_in) 85 | self.dropout = nn.Dropout(dropout) 86 | 87 | def forward(self, x): 88 | residual = x 89 | output = x.transpose(1, 2) 90 | output = self.w_2(F.relu(self.w_1(output))) 91 | output = output.transpose(1, 2) 92 | output = self.dropout(output) 93 | output = self.layer_norm(output + residual) 94 | 95 | return output 96 | -------------------------------------------------------------------------------- /text/transformer/SubLayers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from .Modules import ScaledDotProductAttention 6 | 7 | 8 | class MultiHeadAttention(nn.Module): 9 | """ Multi-Head Attention module """ 10 | 11 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 12 | super().__init__() 13 | 14 | self.n_head = n_head 15 | self.d_k = d_k 16 | self.d_v = d_v 17 | 18 | self.w_qs = nn.Linear(d_model, n_head * d_k) 19 | self.w_ks = nn.Linear(d_model, n_head * d_k) 20 | self.w_vs = nn.Linear(d_model, n_head * d_v) 21 | 22 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 23 | self.layer_norm = nn.LayerNorm(d_model) 24 | 25 | self.fc = nn.Linear(n_head * d_v, d_model) 26 | 27 | self.dropout = nn.Dropout(dropout) 28 | 29 | def forward(self, q, k, v, mask=None): 30 | 31 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 32 | 33 | sz_b, len_q, _ = q.size() 34 | sz_b, len_k, _ = k.size() 35 | sz_b, len_v, _ = v.size() 36 | 37 | residual = q 38 | 39 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 40 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 41 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 42 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 43 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 44 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 45 | 46 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 
47 | output, attn = self.attention(q, k, v, mask=mask) 48 | attn = attn.view(n_head, sz_b, attn.shape[1], attn.shape[2]) 49 | attn = attn.transpose(0, 1) 50 | 51 | output = output.view(n_head, sz_b, len_q, d_v) 52 | output = ( 53 | output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) 54 | ) # b x lq x (n*dv) 55 | 56 | output = self.dropout(self.fc(output)) 57 | output = self.layer_norm(output + residual) 58 | 59 | return output, attn 60 | 61 | 62 | class PositionwiseFeedForward(nn.Module): 63 | """ A two-feed-forward-layer module """ 64 | 65 | def __init__(self, d_in, d_hid, kernel_size, dropout=0.1): 66 | super().__init__() 67 | 68 | # Use Conv1D 69 | # position-wise 70 | self.w_1 = nn.Conv1d( 71 | d_in, 72 | d_hid, 73 | kernel_size=kernel_size[0], 74 | padding=(kernel_size[0] - 1) // 2, 75 | ) 76 | # position-wise 77 | self.w_2 = nn.Conv1d( 78 | d_hid, 79 | d_in, 80 | kernel_size=kernel_size[1], 81 | padding=(kernel_size[1] - 1) // 2, 82 | ) 83 | 84 | self.layer_norm = nn.LayerNorm(d_in) 85 | self.dropout = nn.Dropout(dropout) 86 | 87 | def forward(self, x): 88 | residual = x 89 | output = x.transpose(1, 2) 90 | output = self.w_2(F.relu(self.w_1(output))) 91 | output = output.transpose(1, 2) 92 | output = self.dropout(output) 93 | output = self.layer_norm(output + residual) 94 | 95 | return output 96 | -------------------------------------------------------------------------------- /model/fastspeech2_align.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from transformer import TxtEncoder, MelEncoder, MelDecoder, PostNet 9 | from .modules import VarianceAdaptor 10 | from utils.tools import get_mask_from_lengths 11 | 12 | 13 | class FastSpeech2Align(nn.Module): 14 | """ FastSpeech2 """ 15 | 16 | def __init__(self, preprocess_config, model_config): 17 | super(FastSpeech2Align, self).__init__() 18 | self.model_config = model_config 19 | 20 | self.txt_encoder = TxtEncoder(model_config) 21 | self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config) 22 | self.mel_encoder = MelEncoder(model_config) 23 | self.mel_decoder = MelDecoder(model_config) 24 | self.mel_linear = nn.Linear( 25 | model_config["transformer"]["decoder_hidden"], 26 | preprocess_config["preprocessing"]["mel"]["n_mel_channels"], 27 | ) 28 | self.postnet = PostNet() 29 | 30 | def forward( 31 | self, 32 | speakers, 33 | texts, 34 | src_lens, 35 | max_src_len, 36 | mels=None, 37 | mel_lens=None, 38 | max_mel_len=None, 39 | p_targets=None, 40 | e_targets=None, 41 | p_control=1.0, 42 | e_control=1.0, 43 | ): 44 | is_training = False if mel_lens is None else True 45 | 46 | src_masks = get_mask_from_lengths(src_lens, max_src_len) 47 | mel_masks = ( 48 | get_mask_from_lengths(mel_lens, max_mel_len) 49 | if mel_lens is not None 50 | else None 51 | ) 52 | 53 | src_output = self.txt_encoder(texts, src_masks) 54 | 55 | if is_training: 56 | tgt_output, tgt_alignment = self.mel_encoder(src_output, mels, src_masks, mel_masks) 57 | d_targets = torch.stack([self._calculate_duration(attn, src_len, mel_len, max_src_len) 58 | for attn, src_len, mel_len in zip(tgt_alignment[-1].detach(), src_lens, mel_lens)]) 59 | else: 60 | d_targets, tgt_alignment = None, None 61 | 62 | ( 63 | output, 64 | p_predictions, 65 | e_predictions, 66 | log_d_predictions, 67 | d_rounded, 68 | mel_lens, 69 | mel_masks, 70 | ) = self.variance_adaptor( 71 | src_output, 72 | src_masks, 73 | 
mel_masks, 74 | max_mel_len, 75 | p_targets, 76 | e_targets, 77 | d_targets, 78 | p_control, 79 | e_control, 80 | ) 81 | 82 | output, mel_masks = self.mel_decoder(output, mel_masks) 83 | output = self.mel_linear(output) 84 | 85 | postnet_output = self.postnet(output) + output 86 | 87 | return ( 88 | output, 89 | postnet_output, 90 | p_predictions, 91 | e_predictions, 92 | log_d_predictions, 93 | d_rounded, 94 | src_masks, 95 | mel_masks, 96 | src_lens, 97 | mel_lens, 98 | tgt_alignment, 99 | d_targets, 100 | ) 101 | -------------------------------------------------------------------------------- /text/pinyin.py: -------------------------------------------------------------------------------- 1 | initials = [ 2 | "b", 3 | "c", 4 | "ch", 5 | "d", 6 | "f", 7 | "g", 8 | "h", 9 | "j", 10 | "k", 11 | "l", 12 | "m", 13 | "n", 14 | "p", 15 | "q", 16 | "r", 17 | "s", 18 | "sh", 19 | "t", 20 | "w", 21 | "x", 22 | "y", 23 | "z", 24 | "zh", 25 | ] 26 | finals = [ 27 | "a1", 28 | "a2", 29 | "a3", 30 | "a4", 31 | "a5", 32 | "ai1", 33 | "ai2", 34 | "ai3", 35 | "ai4", 36 | "ai5", 37 | "an1", 38 | "an2", 39 | "an3", 40 | "an4", 41 | "an5", 42 | "ang1", 43 | "ang2", 44 | "ang3", 45 | "ang4", 46 | "ang5", 47 | "ao1", 48 | "ao2", 49 | "ao3", 50 | "ao4", 51 | "ao5", 52 | "e1", 53 | "e2", 54 | "e3", 55 | "e4", 56 | "e5", 57 | "ei1", 58 | "ei2", 59 | "ei3", 60 | "ei4", 61 | "ei5", 62 | "en1", 63 | "en2", 64 | "en3", 65 | "en4", 66 | "en5", 67 | "eng1", 68 | "eng2", 69 | "eng3", 70 | "eng4", 71 | "eng5", 72 | "er1", 73 | "er2", 74 | "er3", 75 | "er4", 76 | "er5", 77 | "i1", 78 | "i2", 79 | "i3", 80 | "i4", 81 | "i5", 82 | "ia1", 83 | "ia2", 84 | "ia3", 85 | "ia4", 86 | "ia5", 87 | "ian1", 88 | "ian2", 89 | "ian3", 90 | "ian4", 91 | "ian5", 92 | "iang1", 93 | "iang2", 94 | "iang3", 95 | "iang4", 96 | "iang5", 97 | "iao1", 98 | "iao2", 99 | "iao3", 100 | "iao4", 101 | "iao5", 102 | "ie1", 103 | "ie2", 104 | "ie3", 105 | "ie4", 106 | "ie5", 107 | "ii1", 108 | "ii2", 109 | "ii3", 110 | "ii4", 111 | "ii5", 112 | "iii1", 113 | "iii2", 114 | "iii3", 115 | "iii4", 116 | "iii5", 117 | "in1", 118 | "in2", 119 | "in3", 120 | "in4", 121 | "in5", 122 | "ing1", 123 | "ing2", 124 | "ing3", 125 | "ing4", 126 | "ing5", 127 | "iong1", 128 | "iong2", 129 | "iong3", 130 | "iong4", 131 | "iong5", 132 | "iou1", 133 | "iou2", 134 | "iou3", 135 | "iou4", 136 | "iou5", 137 | "o1", 138 | "o2", 139 | "o3", 140 | "o4", 141 | "o5", 142 | "ong1", 143 | "ong2", 144 | "ong3", 145 | "ong4", 146 | "ong5", 147 | "ou1", 148 | "ou2", 149 | "ou3", 150 | "ou4", 151 | "ou5", 152 | "u1", 153 | "u2", 154 | "u3", 155 | "u4", 156 | "u5", 157 | "ua1", 158 | "ua2", 159 | "ua3", 160 | "ua4", 161 | "ua5", 162 | "uai1", 163 | "uai2", 164 | "uai3", 165 | "uai4", 166 | "uai5", 167 | "uan1", 168 | "uan2", 169 | "uan3", 170 | "uan4", 171 | "uan5", 172 | "uang1", 173 | "uang2", 174 | "uang3", 175 | "uang4", 176 | "uang5", 177 | "uei1", 178 | "uei2", 179 | "uei3", 180 | "uei4", 181 | "uei5", 182 | "uen1", 183 | "uen2", 184 | "uen3", 185 | "uen4", 186 | "uen5", 187 | "uo1", 188 | "uo2", 189 | "uo3", 190 | "uo4", 191 | "uo5", 192 | "v1", 193 | "v2", 194 | "v3", 195 | "v4", 196 | "v5", 197 | "van1", 198 | "van2", 199 | "van3", 200 | "van4", 201 | "van5", 202 | "ve1", 203 | "ve2", 204 | "ve3", 205 | "ve4", 206 | "ve5", 207 | "vn1", 208 | "vn2", 209 | "vn3", 210 | "vn4", 211 | "vn5", 212 | ] 213 | valid_symbols = initials + finals + ["rr"] -------------------------------------------------------------------------------- /synthesize.py: 
-------------------------------------------------------------------------------- 1 | import re 2 | import argparse 3 | from string import punctuation 4 | 5 | import torch 6 | import yaml 7 | import numpy as np 8 | from torch.utils.data import DataLoader 9 | from g2p_en import G2p 10 | from pypinyin import pinyin, Style 11 | 12 | from utils.model import get_model, get_vocoder 13 | from utils.tools import to_device, synth_samples 14 | from dataset import TextDataset 15 | from text import text_to_sequence 16 | 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | 20 | def read_lexicon(lex_path): 21 | lexicon = {} 22 | with open(lex_path) as f: 23 | for line in f: 24 | temp = re.split(r"\s+", line.strip("\n")) 25 | word = temp[0] 26 | phones = temp[1:] 27 | if word.lower() not in lexicon: 28 | lexicon[word.lower()] = phones 29 | return lexicon 30 | 31 | 32 | def preprocess_english(text, preprocess_config): 33 | text = text.rstrip(punctuation) 34 | lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"]) 35 | 36 | g2p = G2p() 37 | phones = [] 38 | words = re.split(r"([,;.\-\?\!\s+])", text) 39 | for w in words: 40 | if w.lower() in lexicon: 41 | phones += lexicon[w.lower()] 42 | else: 43 | phones += list(filter(lambda p: p != " ", g2p(w))) 44 | phones = "{" + "}{".join(phones) + "}" 45 | phones = re.sub(r"\{[^\w\s]?\}", "{sp}", phones) 46 | phones = phones.replace("}{", " ") 47 | 48 | print("Raw Text Sequence: {}".format(text)) 49 | print("Phoneme Sequence: {}".format(phones)) 50 | sequence = np.array( 51 | text_to_sequence( 52 | phones, preprocess_config["preprocessing"]["text"]["text_cleaners"] 53 | ) 54 | ) 55 | 56 | return np.array(sequence) 57 | 58 | 59 | def synthesize(model, step, configs, vocoder, batchs): 60 | preprocess_config, model_config, train_config = configs 61 | 62 | for batch in batchs: 63 | batch = to_device(batch, device) 64 | with torch.no_grad(): 65 | # Forward 66 | output = model( 67 | *(batch[2:]) 68 | ) 69 | synth_samples( 70 | batch, 71 | output, 72 | vocoder, 73 | model_config, 74 | preprocess_config, 75 | train_config["path"]["result_path"], 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument("--restore_step", type=int, required=True) 83 | parser.add_argument( 84 | "--text", 85 | type=str, 86 | default=None, 87 | help="raw text to synthesize, for single-sentence mode only", 88 | ) 89 | parser.add_argument( 90 | "-p", 91 | "--preprocess_config", 92 | type=str, 93 | required=True, 94 | help="path to preprocess.yaml", 95 | ) 96 | parser.add_argument( 97 | "-m", "--model_config", type=str, required=True, help="path to model.yaml" 98 | ) 99 | parser.add_argument( 100 | "-t", "--train_config", type=str, required=True, help="path to train.yaml" 101 | ) 102 | args = parser.parse_args() 103 | 104 | # Check source texts 105 | assert args.text is not None 106 | 107 | # Read Config 108 | preprocess_config = yaml.load( 109 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader 110 | ) 111 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader) 112 | train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader) 113 | configs = (preprocess_config, model_config, train_config) 114 | 115 | # Get model 116 | model = get_model(args, configs, device, train=False) 117 | 118 | # Load vocoder 119 | vocoder = get_vocoder(model_config, device) 120 | 121 | ids = raw_texts = [args.text[:100]] 122 | speakers = np.array([args.speaker_id]) 123 | 
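    # NOTE: `--speaker_id` is never registered with the argument parser above, and
    # `batchs` is only defined inside the "en" branch below, so single-sentence
    # synthesis would fail at runtime as written. A minimal fix, assuming a single
    # default speaker, is to add next to the other arguments:
    #     parser.add_argument("--speaker_id", type=int, default=0)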
if preprocess_config["preprocessing"]["text"]["language"] == "en": 124 | texts = np.array([preprocess_english(args.text, preprocess_config)]) 125 | text_lens = np.array([len(texts[0])]) 126 | batchs = [(ids, raw_texts, speakers, texts, text_lens, max(text_lens))] 127 | 128 | synthesize(model, args.restore_step, configs , vocoder, batchs) 129 | -------------------------------------------------------------------------------- /transformer/Layers.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from torch.nn import functional as F 7 | 8 | from .SubLayers import MultiHeadAttention, PositionwiseFeedForward 9 | 10 | 11 | class Prenet(nn.Module): 12 | """ 13 | Prenet before passing through the network 14 | """ 15 | def __init__(self): 16 | super(Prenet, self).__init__() 17 | 18 | self.w_1 = nn.Linear(80, 256) 19 | self.w_2 = nn.Linear(256, 256) 20 | self.dropout = nn.Dropout(0.2) 21 | 22 | def forward(self, x): 23 | output = F.relu(self.w_2(F.relu(self.w_1(x)))) 24 | output = self.dropout(output) 25 | 26 | return output 27 | 28 | 29 | class FFTBlock(torch.nn.Module): 30 | """FFT Block""" 31 | 32 | def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1): 33 | super(FFTBlock, self).__init__() 34 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 35 | self.pos_ffn = PositionwiseFeedForward( 36 | d_model, d_inner, kernel_size, dropout=dropout 37 | ) 38 | 39 | def forward(self, enc_input, mask=None, slf_attn_mask=None): 40 | enc_output, enc_slf_attn = self.slf_attn( 41 | enc_input, enc_input, enc_input, mask=slf_attn_mask 42 | ) 43 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 44 | 45 | enc_output = self.pos_ffn(enc_output) 46 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 47 | 48 | return enc_output, enc_slf_attn 49 | 50 | 51 | class FFTBlock2(torch.nn.Module): 52 | """FFT Block""" 53 | 54 | def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1): 55 | super(FFTBlock2, self).__init__() 56 | self.crs_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 57 | self.pos_ffn = PositionwiseFeedForward( 58 | d_model, d_inner, kernel_size, dropout=dropout 59 | ) 60 | 61 | def forward(self, src_input, tgt_input, mask=None, crs_attn_mask=None): 62 | enc_output, enc_crs_attn = self.crs_attn( 63 | tgt_input, src_input, src_input, mask=crs_attn_mask 64 | ) 65 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 66 | 67 | enc_output = self.pos_ffn(enc_output) 68 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 69 | 70 | return enc_output, enc_crs_attn 71 | 72 | 73 | class ConvNorm(torch.nn.Module): 74 | def __init__( 75 | self, 76 | in_channels, 77 | out_channels, 78 | kernel_size=1, 79 | stride=1, 80 | padding=None, 81 | dilation=1, 82 | bias=True, 83 | w_init_gain="linear", 84 | ): 85 | super(ConvNorm, self).__init__() 86 | 87 | if padding is None: 88 | assert kernel_size % 2 == 1 89 | padding = int(dilation * (kernel_size - 1) / 2) 90 | 91 | self.conv = torch.nn.Conv1d( 92 | in_channels, 93 | out_channels, 94 | kernel_size=kernel_size, 95 | stride=stride, 96 | padding=padding, 97 | dilation=dilation, 98 | bias=bias, 99 | ) 100 | 101 | def forward(self, signal): 102 | conv_signal = self.conv(signal) 103 | 104 | return conv_signal 105 | 106 | 107 | class PostNet(nn.Module): 108 | """ 109 | PostNet: Five 1-d convolution with 
512 channels and kernel size 5 110 | """ 111 | 112 | def __init__( 113 | self, 114 | n_mel_channels=80, 115 | postnet_embedding_dim=512, 116 | postnet_kernel_size=5, 117 | postnet_n_convolutions=5, 118 | ): 119 | 120 | super(PostNet, self).__init__() 121 | self.convolutions = nn.ModuleList() 122 | 123 | self.convolutions.append( 124 | nn.Sequential( 125 | ConvNorm( 126 | n_mel_channels, 127 | postnet_embedding_dim, 128 | kernel_size=postnet_kernel_size, 129 | stride=1, 130 | padding=int((postnet_kernel_size - 1) / 2), 131 | dilation=1, 132 | w_init_gain="tanh", 133 | ), 134 | nn.BatchNorm1d(postnet_embedding_dim), 135 | ) 136 | ) 137 | 138 | for i in range(1, postnet_n_convolutions - 1): 139 | self.convolutions.append( 140 | nn.Sequential( 141 | ConvNorm( 142 | postnet_embedding_dim, 143 | postnet_embedding_dim, 144 | kernel_size=postnet_kernel_size, 145 | stride=1, 146 | padding=int((postnet_kernel_size - 1) / 2), 147 | dilation=1, 148 | w_init_gain="tanh", 149 | ), 150 | nn.BatchNorm1d(postnet_embedding_dim), 151 | ) 152 | ) 153 | 154 | self.convolutions.append( 155 | nn.Sequential( 156 | ConvNorm( 157 | postnet_embedding_dim, 158 | n_mel_channels, 159 | kernel_size=postnet_kernel_size, 160 | stride=1, 161 | padding=int((postnet_kernel_size - 1) / 2), 162 | dilation=1, 163 | w_init_gain="linear", 164 | ), 165 | nn.BatchNorm1d(n_mel_channels), 166 | ) 167 | ) 168 | 169 | def forward(self, x): 170 | x = x.contiguous().transpose(1, 2) 171 | 172 | for i in range(len(self.convolutions) - 1): 173 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) 174 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training) 175 | 176 | x = x.contiguous().transpose(1, 2) 177 | return x 178 | -------------------------------------------------------------------------------- /text/transformer/Layers.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from torch.nn import functional as F 7 | 8 | from .SubLayers import MultiHeadAttention, PositionwiseFeedForward 9 | 10 | 11 | class Prenet(nn.Module): 12 | """ 13 | Prenet before passing through the network 14 | """ 15 | def __init__(self): 16 | super(Prenet, self).__init__() 17 | 18 | self.w_1 = nn.Linear(80, 256) 19 | self.w_2 = nn.Linear(256, 256) 20 | self.dropout = nn.Dropout(0.2) 21 | 22 | def forward(self, x): 23 | output = F.relu(self.w_2(F.relu(self.w_1(x)))) 24 | output = self.dropout(output) 25 | 26 | return output 27 | 28 | 29 | class FFTBlock(torch.nn.Module): 30 | """FFT Block""" 31 | 32 | def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1): 33 | super(FFTBlock, self).__init__() 34 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 35 | self.pos_ffn = PositionwiseFeedForward( 36 | d_model, d_inner, kernel_size, dropout=dropout 37 | ) 38 | 39 | def forward(self, enc_input, mask=None, slf_attn_mask=None): 40 | enc_output, enc_slf_attn = self.slf_attn( 41 | enc_input, enc_input, enc_input, mask=slf_attn_mask 42 | ) 43 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 44 | 45 | enc_output = self.pos_ffn(enc_output) 46 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 47 | 48 | return enc_output, enc_slf_attn 49 | 50 | 51 | class FFTBlock2(torch.nn.Module): 52 | """FFT Block""" 53 | 54 | def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1): 55 | super(FFTBlock2, self).__init__() 56 | 
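        # cross-attention block: in forward(), queries come from the mel-side
        # input and keys/values from the text-encoder output, so the returned
        # attention map doubles as a text-to-mel alignment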
self.crs_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 57 | self.pos_ffn = PositionwiseFeedForward( 58 | d_model, d_inner, kernel_size, dropout=dropout 59 | ) 60 | 61 | def forward(self, src_input, tgt_input, mask=None, crs_attn_mask=None): 62 | enc_output, enc_crs_attn = self.crs_attn( 63 | tgt_input, src_input, src_input, mask=crs_attn_mask 64 | ) 65 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 66 | 67 | enc_output = self.pos_ffn(enc_output) 68 | enc_output = enc_output.masked_fill(mask.unsqueeze(-1), 0) 69 | 70 | return enc_output, enc_crs_attn 71 | 72 | 73 | class ConvNorm(torch.nn.Module): 74 | def __init__( 75 | self, 76 | in_channels, 77 | out_channels, 78 | kernel_size=1, 79 | stride=1, 80 | padding=None, 81 | dilation=1, 82 | bias=True, 83 | w_init_gain="linear", 84 | ): 85 | super(ConvNorm, self).__init__() 86 | 87 | if padding is None: 88 | assert kernel_size % 2 == 1 89 | padding = int(dilation * (kernel_size - 1) / 2) 90 | 91 | self.conv = torch.nn.Conv1d( 92 | in_channels, 93 | out_channels, 94 | kernel_size=kernel_size, 95 | stride=stride, 96 | padding=padding, 97 | dilation=dilation, 98 | bias=bias, 99 | ) 100 | 101 | def forward(self, signal): 102 | conv_signal = self.conv(signal) 103 | 104 | return conv_signal 105 | 106 | 107 | class PostNet(nn.Module): 108 | """ 109 | PostNet: Five 1-d convolution with 512 channels and kernel size 5 110 | """ 111 | 112 | def __init__( 113 | self, 114 | n_mel_channels=80, 115 | postnet_embedding_dim=512, 116 | postnet_kernel_size=5, 117 | postnet_n_convolutions=5, 118 | ): 119 | 120 | super(PostNet, self).__init__() 121 | self.convolutions = nn.ModuleList() 122 | 123 | self.convolutions.append( 124 | nn.Sequential( 125 | ConvNorm( 126 | n_mel_channels, 127 | postnet_embedding_dim, 128 | kernel_size=postnet_kernel_size, 129 | stride=1, 130 | padding=int((postnet_kernel_size - 1) / 2), 131 | dilation=1, 132 | w_init_gain="tanh", 133 | ), 134 | nn.BatchNorm1d(postnet_embedding_dim), 135 | ) 136 | ) 137 | 138 | for i in range(1, postnet_n_convolutions - 1): 139 | self.convolutions.append( 140 | nn.Sequential( 141 | ConvNorm( 142 | postnet_embedding_dim, 143 | postnet_embedding_dim, 144 | kernel_size=postnet_kernel_size, 145 | stride=1, 146 | padding=int((postnet_kernel_size - 1) / 2), 147 | dilation=1, 148 | w_init_gain="tanh", 149 | ), 150 | nn.BatchNorm1d(postnet_embedding_dim), 151 | ) 152 | ) 153 | 154 | self.convolutions.append( 155 | nn.Sequential( 156 | ConvNorm( 157 | postnet_embedding_dim, 158 | n_mel_channels, 159 | kernel_size=postnet_kernel_size, 160 | stride=1, 161 | padding=int((postnet_kernel_size - 1) / 2), 162 | dilation=1, 163 | w_init_gain="linear", 164 | ), 165 | nn.BatchNorm1d(n_mel_channels), 166 | ) 167 | ) 168 | 169 | def forward(self, x): 170 | x = x.contiguous().transpose(1, 2) 171 | 172 | for i in range(len(self.convolutions) - 1): 173 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) 174 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training) 175 | 176 | x = x.contiguous().transpose(1, 2) 177 | return x 178 | -------------------------------------------------------------------------------- /audio/stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy.signal import get_window 5 | from librosa.util import pad_center, tiny 6 | from librosa.filters import mel as librosa_mel_fn 7 | 8 | from audio.audio_processing import ( 9 | 
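    # dynamic-range compression/decompression are used below by
    # TacotronSTFT.spectral_normalize / spectral_de_normalize; window_sumsquare
    # is used by STFT.inverse to undo the analysis-window modulation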
dynamic_range_compression, 10 | dynamic_range_decompression, 11 | window_sumsquare, 12 | ) 13 | 14 | 15 | class STFT(torch.nn.Module): 16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 17 | 18 | def __init__(self, filter_length, hop_length, win_length, window="hann"): 19 | super(STFT, self).__init__() 20 | self.filter_length = filter_length 21 | self.hop_length = hop_length 22 | self.win_length = win_length 23 | self.window = window 24 | self.forward_transform = None 25 | scale = self.filter_length / self.hop_length 26 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 27 | 28 | cutoff = int((self.filter_length / 2 + 1)) 29 | fourier_basis = np.vstack( 30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] 31 | ) 32 | 33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 34 | inverse_basis = torch.FloatTensor( 35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :] 36 | ) 37 | 38 | if window is not None: 39 | assert filter_length >= win_length 40 | # get window and zero center pad it to filter_length 41 | fft_window = get_window(window, win_length, fftbins=True) 42 | fft_window = pad_center(fft_window, filter_length) 43 | fft_window = torch.from_numpy(fft_window).float() 44 | 45 | # window the bases 46 | forward_basis *= fft_window 47 | inverse_basis *= fft_window 48 | 49 | self.register_buffer("forward_basis", forward_basis.float()) 50 | self.register_buffer("inverse_basis", inverse_basis.float()) 51 | 52 | def transform(self, input_data): 53 | num_batches = input_data.size(0) 54 | num_samples = input_data.size(1) 55 | 56 | self.num_samples = num_samples 57 | 58 | # similar to librosa, reflect-pad the input 59 | input_data = input_data.view(num_batches, 1, num_samples) 60 | input_data = F.pad( 61 | input_data.unsqueeze(1), 62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 63 | mode="reflect", 64 | ) 65 | input_data = input_data.squeeze(1) 66 | 67 | forward_transform = F.conv1d( 68 | input_data.cuda(), 69 | torch.autograd.Variable(self.forward_basis, requires_grad=False).cuda(), 70 | stride=self.hop_length, 71 | padding=0, 72 | ).cpu() 73 | 74 | cutoff = int((self.filter_length / 2) + 1) 75 | real_part = forward_transform[:, :cutoff, :] 76 | imag_part = forward_transform[:, cutoff:, :] 77 | 78 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) 79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) 80 | 81 | return magnitude, phase 82 | 83 | def inverse(self, magnitude, phase): 84 | recombine_magnitude_phase = torch.cat( 85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 86 | ) 87 | 88 | inverse_transform = F.conv_transpose1d( 89 | recombine_magnitude_phase, 90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False), 91 | stride=self.hop_length, 92 | padding=0, 93 | ) 94 | 95 | if self.window is not None: 96 | window_sum = window_sumsquare( 97 | self.window, 98 | magnitude.size(-1), 99 | hop_length=self.hop_length, 100 | win_length=self.win_length, 101 | n_fft=self.filter_length, 102 | dtype=np.float32, 103 | ) 104 | # remove modulation effects 105 | approx_nonzero_indices = torch.from_numpy( 106 | np.where(window_sum > tiny(window_sum))[0] 107 | ) 108 | window_sum = torch.autograd.Variable( 109 | torch.from_numpy(window_sum), requires_grad=False 110 | ) 111 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum 112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ 113 | approx_nonzero_indices 114 | ] 
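            # dividing only where the window envelope is non-negligible avoids
            # numerical blow-ups at frame edges (standard overlap-add normalization)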
115 | 116 | # scale by hop ratio 117 | inverse_transform *= float(self.filter_length) / self.hop_length 118 | 119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] 120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] 121 | 122 | return inverse_transform 123 | 124 | def forward(self, input_data): 125 | self.magnitude, self.phase = self.transform(input_data) 126 | reconstruction = self.inverse(self.magnitude, self.phase) 127 | return reconstruction 128 | 129 | 130 | class TacotronSTFT(torch.nn.Module): 131 | def __init__( 132 | self, 133 | filter_length, 134 | hop_length, 135 | win_length, 136 | n_mel_channels, 137 | sampling_rate, 138 | mel_fmin, 139 | mel_fmax, 140 | ): 141 | super(TacotronSTFT, self).__init__() 142 | self.n_mel_channels = n_mel_channels 143 | self.sampling_rate = sampling_rate 144 | self.stft_fn = STFT(filter_length, hop_length, win_length) 145 | mel_basis = librosa_mel_fn( 146 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax 147 | ) 148 | mel_basis = torch.from_numpy(mel_basis).float() 149 | self.register_buffer("mel_basis", mel_basis) 150 | 151 | def spectral_normalize(self, magnitudes): 152 | output = dynamic_range_compression(magnitudes) 153 | return output 154 | 155 | def spectral_de_normalize(self, magnitudes): 156 | output = dynamic_range_decompression(magnitudes) 157 | return output 158 | 159 | def mel_spectrogram(self, y): 160 | """Computes mel-spectrograms from a batch of waves 161 | PARAMS 162 | ------ 163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 164 | 165 | RETURNS 166 | ------- 167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 168 | """ 169 | assert torch.min(y.data) >= -1 170 | assert torch.max(y.data) <= 1 171 | 172 | magnitudes, phases = self.stft_fn.transform(y) 173 | magnitudes = magnitudes.data 174 | mel_output = torch.matmul(self.mel_basis, magnitudes) 175 | mel_output = self.spectral_normalize(mel_output) 176 | energy = torch.norm(magnitudes, dim=1) 177 | 178 | return mel_output, energy 179 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import yaml 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | from torch.utils.tensorboard import SummaryWriter 9 | from tqdm import tqdm 10 | 11 | from utils.model import get_model, get_vocoder 12 | from utils.tools import to_device, log, synth_one_sample 13 | from model import FastSpeech2Loss 14 | from dataset import Dataset 15 | 16 | from evaluate import evaluate 17 | 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | 21 | def main(args, configs): 22 | print("Prepare training ...") 23 | 24 | preprocess_config, model_config, train_config = configs 25 | 26 | # Get dataset 27 | dataset = Dataset( 28 | "train.txt", preprocess_config, train_config, sort=True, drop_last=True 29 | ) 30 | batch_size = train_config["optimizer"]["batch_size"] 31 | group_size = 4 # Set this larger than 1 to enable sorting in Dataset 32 | assert batch_size * group_size < len(dataset) 33 | loader = DataLoader( 34 | dataset, 35 | batch_size=batch_size * group_size, 36 | shuffle=True, 37 | collate_fn=dataset.collate_fn, 38 | ) 39 | 40 | # Prepare model 41 | model, optimizer = get_model(args, configs, device, train=True) 42 | model = nn.DataParallel(model) 43 | Loss = 
FastSpeech2Loss(preprocess_config, model_config).to(device) 44 | 45 | # Load vocoder 46 | vocoder = get_vocoder(model_config, device) 47 | 48 | # Init logger 49 | for p in train_config["path"].values(): 50 | os.makedirs(p, exist_ok=True) 51 | train_log_path = os.path.join(train_config["path"]["log_path"], "train") 52 | val_log_path = os.path.join(train_config["path"]["log_path"], "val") 53 | os.makedirs(train_log_path, exist_ok=True) 54 | os.makedirs(val_log_path, exist_ok=True) 55 | train_logger = SummaryWriter(train_log_path) 56 | val_logger = SummaryWriter(val_log_path) 57 | 58 | # Training 59 | step = args.restore_step + 1 60 | epoch = 1 61 | grad_acc_step = train_config["optimizer"]["grad_acc_step"] 62 | grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"] 63 | total_step = train_config["step"]["total_step"] 64 | log_step = train_config["step"]["log_step"] 65 | save_step = train_config["step"]["save_step"] 66 | synth_step = train_config["step"]["synth_step"] 67 | val_step = train_config["step"]["val_step"] 68 | 69 | outer_bar = tqdm(total=total_step, desc="Training", position=0) 70 | outer_bar.n = args.restore_step 71 | outer_bar.update() 72 | 73 | while True: 74 | inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1) 75 | for batchs in loader: 76 | for batch in batchs: 77 | batch = to_device(batch, device) 78 | 79 | # Forward 80 | output = model(*(batch[2:])) 81 | 82 | # Cal Loss 83 | losses = Loss(batch, output) 84 | total_loss = losses[0] 85 | 86 | # Backward 87 | total_loss = total_loss / grad_acc_step 88 | total_loss.backward() 89 | if step % grad_acc_step == 0: 90 | # Clipping gradients to avoid gradient explosion 91 | nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh) 92 | 93 | # Update weights 94 | optimizer.step_and_update_lr() 95 | optimizer.zero_grad() 96 | 97 | if step % log_step == 0: 98 | losses = [l.item() for l in losses] 99 | message1 = "Step {}/{}, ".format(step, total_step) 100 | message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Dur Loss: {:.4f}, Attn Loss: {:.4f}".format( 101 | *losses 102 | ) 103 | 104 | with open(os.path.join(train_log_path, "log.txt"), "a") as f: 105 | f.write(message1 + message2 + "\n") 106 | 107 | outer_bar.write(message1 + message2) 108 | 109 | log(train_logger, step, losses=losses) 110 | 111 | if step % synth_step == 0: 112 | fig, wav_reconstruction, wav_prediction, tag = synth_one_sample( 113 | batch, 114 | output, 115 | vocoder, 116 | model_config, 117 | preprocess_config, 118 | ) 119 | log( 120 | train_logger, 121 | fig=fig, 122 | tag="Training/step_{}_{}".format(step, tag), 123 | ) 124 | sampling_rate = preprocess_config["preprocessing"]["audio"][ 125 | "sampling_rate" 126 | ] 127 | log( 128 | train_logger, 129 | audio=wav_reconstruction, 130 | sampling_rate=sampling_rate, 131 | tag="Training/step_{}_{}_reconstructed".format(step, tag), 132 | ) 133 | log( 134 | train_logger, 135 | audio=wav_prediction, 136 | sampling_rate=sampling_rate, 137 | tag="Training/step_{}_{}_synthesized".format(step, tag), 138 | ) 139 | 140 | if step % val_step == 0: 141 | model.eval() 142 | message = evaluate(model, step, configs, val_logger, vocoder) 143 | with open(os.path.join(val_log_path, "log.txt"), "a") as f: 144 | f.write(message + "\n") 145 | outer_bar.write(message) 146 | 147 | model.train() 148 | 149 | if step % save_step == 0: 150 | torch.save( 151 | { 152 | "model": model.module.state_dict(), 153 | "optimizer": 
optimizer._optimizer.state_dict(), 154 | }, 155 | os.path.join( 156 | train_config["path"]["ckpt_path"], 157 | "{}.pth.tar".format(step), 158 | ), 159 | ) 160 | 161 | if step == total_step: 162 | quit() 163 | step += 1 164 | outer_bar.update(1) 165 | 166 | inner_bar.update(1) 167 | epoch += 1 168 | 169 | 170 | if __name__ == "__main__": 171 | parser = argparse.ArgumentParser() 172 | parser.add_argument("--restore_step", type=int, default=0) 173 | parser.add_argument( 174 | "-p", 175 | "--preprocess_config", 176 | type=str, 177 | required=True, 178 | help="path to preprocess.yaml", 179 | ) 180 | parser.add_argument( 181 | "-m", "--model_config", type=str, required=True, help="path to model.yaml" 182 | ) 183 | parser.add_argument( 184 | "-t", "--train_config", type=str, required=True, help="path to train.yaml" 185 | ) 186 | args = parser.parse_args() 187 | 188 | # Read Config 189 | preprocess_config = yaml.load( 190 | open(args.preprocess_config, "r"), Loader=yaml.FullLoader 191 | ) 192 | model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader) 193 | train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader) 194 | configs = (preprocess_config, model_config, train_config) 195 | 196 | main(args, configs) 197 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | 8 | from text import text_to_sequence 9 | from utils.tools import pad_1D, pad_2D 10 | 11 | 12 | class Dataset(Dataset): 13 | def __init__( 14 | self, filename, preprocess_config, train_config, sort=False, drop_last=False 15 | ): 16 | self.dataset_name = preprocess_config["dataset"] 17 | self.preprocessed_path = preprocess_config["path"]["preprocessed_path"] 18 | self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"] 19 | self.batch_size = train_config["optimizer"]["batch_size"] 20 | 21 | self.basename, self.speaker, self.text, self.raw_text = self.process_meta( 22 | filename 23 | ) 24 | with open(os.path.join(self.preprocessed_path, "speakers.json")) as f: 25 | self.speaker_map = json.load(f) 26 | self.sort = sort 27 | self.drop_last = drop_last 28 | 29 | def __len__(self): 30 | return len(self.text) 31 | 32 | def __getitem__(self, idx): 33 | basename = self.basename[idx] 34 | speaker = self.speaker[idx] 35 | speaker_id = self.speaker_map[speaker] 36 | raw_text = self.raw_text[idx] 37 | if "korean_cleaners" in self.cleaners: 38 | self.text[idx] = self.text[idx].replace(" ", "") 39 | self.text[idx] = self.text[idx].replace("sp", ",") 40 | phone = np.array(text_to_sequence(self.text[idx], self.cleaners)) 41 | mel_path = os.path.join( 42 | self.preprocessed_path, 43 | "mel", 44 | "{}-mel-{}.npy".format(speaker, basename), 45 | ) 46 | mel = np.load(mel_path) 47 | pitch_path = os.path.join( 48 | self.preprocessed_path, 49 | "pitch", 50 | "{}-pitch-{}.npy".format(speaker, basename), 51 | ) 52 | pitch = np.load(pitch_path) 53 | energy_path = os.path.join( 54 | self.preprocessed_path, 55 | "energy", 56 | "{}-energy-{}.npy".format(speaker, basename), 57 | ) 58 | energy = np.load(energy_path) 59 | 60 | sample = { 61 | "id": basename, 62 | "speaker": speaker_id, 63 | "text": phone, 64 | "raw_text": raw_text, 65 | "mel": mel, 66 | "pitch": pitch, 67 | "energy": energy, 68 | } 69 | 70 | return sample 71 | 72 | def process_meta(self, filename): 73 | 
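        # each metadata line is pipe-separated: basename|speaker|text|raw_text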
with open( 74 | os.path.join(self.preprocessed_path, filename), "r", encoding="utf-8" 75 | ) as f: 76 | name = [] 77 | speaker = [] 78 | text = [] 79 | raw_text = [] 80 | for line in f.readlines(): 81 | n, s, t, r = line.strip("\n").split("|") 82 | name.append(n) 83 | speaker.append(s) 84 | text.append(t) 85 | raw_text.append(r) 86 | return name, speaker, text, raw_text 87 | 88 | def reprocess(self, data, idxs): 89 | ids = [data[idx]["id"] for idx in idxs] 90 | speakers = [data[idx]["speaker"] for idx in idxs] 91 | texts = [data[idx]["text"] for idx in idxs] 92 | raw_texts = [data[idx]["raw_text"] for idx in idxs] 93 | mels = [data[idx]["mel"] for idx in idxs] 94 | pitches = [data[idx]["pitch"] for idx in idxs] 95 | energies = [data[idx]["energy"] for idx in idxs] 96 | 97 | text_lens = np.array([text.shape[0] for text in texts]) 98 | mel_lens = np.array([mel.shape[0] for mel in mels]) 99 | 100 | speakers = np.array(speakers) 101 | texts = pad_1D(texts) 102 | mels = pad_2D(mels) 103 | pitches = pad_1D(pitches) 104 | energies = pad_1D(energies) 105 | 106 | return ( 107 | ids, 108 | raw_texts, 109 | speakers, 110 | texts, 111 | text_lens, 112 | max(text_lens), 113 | mels, 114 | mel_lens, 115 | max(mel_lens), 116 | pitches, 117 | energies, 118 | ) 119 | 120 | def collate_fn(self, data): 121 | data_size = len(data) 122 | 123 | if self.sort: 124 | len_arr = np.array([d["text"].shape[0] for d in data]) 125 | idx_arr = np.argsort(-len_arr) 126 | else: 127 | idx_arr = np.arange(data_size) 128 | 129 | tail = idx_arr[len(idx_arr) - (len(idx_arr) % self.batch_size) :] 130 | idx_arr = idx_arr[: len(idx_arr) - (len(idx_arr) % self.batch_size)] 131 | idx_arr = idx_arr.reshape((-1, self.batch_size)).tolist() 132 | if not self.drop_last and len(tail) > 0: 133 | idx_arr += [tail.tolist()] 134 | 135 | output = list() 136 | for idx in idx_arr: 137 | output.append(self.reprocess(data, idx)) 138 | 139 | return output 140 | 141 | 142 | class TextDataset(Dataset): 143 | def __init__(self, filepath, preprocess_config): 144 | self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"] 145 | 146 | self.basename, self.speaker, self.text, self.raw_text = self.process_meta( 147 | filepath 148 | ) 149 | with open( 150 | os.path.join( 151 | preprocess_config["path"]["preprocessed_path"], "speakers.json" 152 | ) 153 | ) as f: 154 | self.speaker_map = json.load(f) 155 | 156 | def __len__(self): 157 | return len(self.text) 158 | 159 | def __getitem__(self, idx): 160 | basename = self.basename[idx] 161 | speaker = self.speaker[idx] 162 | speaker_id = self.speaker_map[speaker] 163 | raw_text = self.raw_text[idx] 164 | phone = np.array(text_to_sequence(self.text[idx], self.cleaners)) 165 | 166 | return (basename, speaker_id, phone, raw_text) 167 | 168 | def process_meta(self, filename): 169 | with open(filename, "r", encoding="utf-8") as f: 170 | name = [] 171 | speaker = [] 172 | text = [] 173 | raw_text = [] 174 | for line in f.readlines(): 175 | n, s, t, r = line.strip("\n").split("|") 176 | name.append(n) 177 | speaker.append(s) 178 | text.append(t) 179 | raw_text.append(r) 180 | return name, speaker, text, raw_text 181 | 182 | def collate_fn(self, data): 183 | ids = [d[0] for d in data] 184 | speakers = np.array([d[1] for d in data]) 185 | texts = [d[2] for d in data] 186 | raw_texts = [d[3] for d in data] 187 | text_lens = np.array([text.shape[0] for text in texts]) 188 | 189 | texts = pad_1D(texts) 190 | 191 | return ids, raw_texts, speakers, texts, text_lens, max(text_lens) 192 | 193 | 194 | if __name__ 
== "__main__": 195 | # Test 196 | import torch 197 | import yaml 198 | from torch.utils.data import DataLoader 199 | from utils.utils import to_device 200 | 201 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 202 | preprocess_config = yaml.load( 203 | open("./config/LJSpeech/preprocess.yaml", "r"), Loader=yaml.FullLoader 204 | ) 205 | train_config = yaml.load( 206 | open("./config/LJSpeech/train.yaml", "r"), Loader=yaml.FullLoader 207 | ) 208 | 209 | train_dataset = Dataset( 210 | "train.txt", preprocess_config, train_config, sort=True, drop_last=True 211 | ) 212 | val_dataset = Dataset( 213 | "val.txt", preprocess_config, train_config, sort=False, drop_last=False 214 | ) 215 | train_loader = DataLoader( 216 | train_dataset, 217 | batch_size=train_config["optimizer"]["batch_size"] * 4, 218 | shuffle=True, 219 | collate_fn=train_dataset.collate_fn, 220 | ) 221 | val_loader = DataLoader( 222 | val_dataset, 223 | batch_size=train_config["optimizer"]["batch_size"], 224 | shuffle=False, 225 | collate_fn=val_dataset.collate_fn, 226 | ) 227 | 228 | n_batch = 0 229 | for batchs in train_loader: 230 | for batch in batchs: 231 | to_device(batch, device) 232 | n_batch += 1 233 | print( 234 | "Training set with size {} is composed of {} batches.".format( 235 | len(train_dataset), n_batch 236 | ) 237 | ) 238 | 239 | n_batch = 0 240 | for batchs in val_loader: 241 | for batch in batchs: 242 | to_device(batch, device) 243 | n_batch += 1 244 | print( 245 | "Validation set with size {} is composed of {} batches.".format( 246 | len(val_dataset), n_batch 247 | ) 248 | ) -------------------------------------------------------------------------------- /transformer/Models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | import transformer.Constants as Constants 6 | from .Layers import FFTBlock, FFTBlock2, Prenet 7 | from text.symbols import symbols 8 | 9 | 10 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 11 | """ Sinusoid position encoding table """ 12 | 13 | def cal_angle(position, hid_idx): 14 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 15 | 16 | def get_posi_angle_vec(position): 17 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 18 | 19 | sinusoid_table = np.array( 20 | [get_posi_angle_vec(pos_i) for pos_i in range(n_position)] 21 | ) 22 | 23 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 24 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 25 | 26 | if padding_idx is not None: 27 | # zero vector for padding dimension 28 | sinusoid_table[padding_idx] = 0.0 29 | 30 | return torch.FloatTensor(sinusoid_table) 31 | 32 | 33 | class TxtEncoder(nn.Module): 34 | """ TxtEncoder """ 35 | 36 | def __init__(self, config): 37 | super(TxtEncoder, self).__init__() 38 | 39 | n_position = config["max_seq_len"] + 1 40 | n_src_vocab = len(symbols) + 1 41 | d_word_vec = config["transformer"]["encoder_hidden"] 42 | n_layers = config["transformer"]["encoder_layer"] 43 | n_head = config["transformer"]["encoder_head"] 44 | d_k = d_v = ( 45 | config["transformer"]["encoder_hidden"] 46 | // config["transformer"]["encoder_head"] 47 | ) 48 | d_model = config["transformer"]["encoder_hidden"] 49 | d_inner = config["transformer"]["conv_filter_size"] 50 | kernel_size = config["transformer"]["conv_kernel_size"] 51 | dropout = config["transformer"]["encoder_dropout"] 52 | 53 | self.max_seq_len = 
config["max_seq_len"] 54 | self.d_model = d_model 55 | 56 | self.src_word_emb = nn.Embedding( 57 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD 58 | ) 59 | self.position_enc = nn.Parameter( 60 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 61 | requires_grad=False, 62 | ) 63 | 64 | self.layer_stack = nn.ModuleList( 65 | [ 66 | FFTBlock( 67 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 68 | ) 69 | for _ in range(n_layers) 70 | ] 71 | ) 72 | 73 | def forward(self, src_seq, mask, return_attns=False): 74 | 75 | enc_slf_attn_list = [] 76 | batch_size, max_len = src_seq.shape[0], src_seq.shape[1] 77 | 78 | # -- Prepare masks 79 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 80 | 81 | # -- Forward 82 | if not self.training and src_seq.shape[1] > self.max_seq_len: 83 | enc_output = self.src_word_emb(src_seq) + get_sinusoid_encoding_table( 84 | src_seq.shape[1], self.d_model 85 | )[: src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 86 | src_seq.device 87 | ) 88 | else: 89 | enc_output = self.src_word_emb(src_seq) + self.position_enc[ 90 | :, :max_len, : 91 | ].expand(batch_size, -1, -1) 92 | 93 | for enc_layer in self.layer_stack: 94 | enc_output, enc_slf_attn = enc_layer( 95 | enc_output, mask=mask, slf_attn_mask=slf_attn_mask 96 | ) 97 | if return_attns: 98 | enc_slf_attn_list += [enc_slf_attn] 99 | 100 | return enc_output 101 | 102 | 103 | class MelEncoder(nn.Module): 104 | """ MelEncoder """ 105 | 106 | def __init__(self, config): 107 | super(MelEncoder, self).__init__() 108 | 109 | n_position = config["max_seq_len"] + 1 110 | d_word_vec = config["transformer"]["decoder_hidden"] 111 | n_layers = config["transformer"]["decoder_layer"] 112 | n_head = config["transformer"]["decoder_head"] 113 | d_k = d_v = ( 114 | config["transformer"]["decoder_hidden"] 115 | // config["transformer"]["decoder_head"] 116 | ) 117 | d_model = config["transformer"]["decoder_hidden"] 118 | d_inner = config["transformer"]["conv_filter_size"] 119 | kernel_size = config["transformer"]["conv_kernel_size"] 120 | dropout = config["transformer"]["decoder_dropout"] 121 | 122 | self.max_seq_len = config["max_seq_len"] 123 | self.d_model = d_model 124 | 125 | self.prenet = Prenet() 126 | self.position_enc = nn.Parameter( 127 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 128 | requires_grad=False, 129 | ) 130 | 131 | self.layer_stack = nn.ModuleList( 132 | [ 133 | FFTBlock2( 134 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 135 | ) 136 | for _ in range(n_layers) 137 | ] 138 | ) 139 | 140 | def forward(self, src_seq, tgt_seq, src_mask, tgt_mask, return_attns=True): 141 | 142 | dec_crs_attn_list = [] 143 | batch_size, max_mel_len = tgt_seq.shape[0], tgt_seq.shape[1] 144 | 145 | zero_seq = torch.zeros(tgt_seq.shape[0], 1, tgt_seq.shape[2]).to(tgt_seq.device) 146 | tgt_seq = torch.cat([zero_seq, tgt_seq[:, 1:, :]], dim=1) 147 | 148 | # -- Forward 149 | if not self.training and tgt_seq.shape[1] > self.max_seq_len: 150 | # -- Prepare masks 151 | dec_output = self.prenet(tgt_seq) + get_sinusoid_encoding_table( 152 | tgt_seq.shape[1], self.d_model 153 | )[: tgt_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 154 | tgt_seq.device 155 | ) 156 | crs_attn_mask = src_mask.unsqueeze(1).expand(-1, max_mel_len, -1) 157 | 158 | else: 159 | max_mel_len = min(max_mel_len, self.max_seq_len) 160 | 161 | # -- Prepare masks 162 | dec_output = self.prenet(tgt_seq[:, :max_mel_len, :]) + self.position_enc[ 163 | :, 
:max_mel_len, : 164 | ].expand(batch_size, -1, -1) 165 | tgt_mask = tgt_mask[:, :max_mel_len] 166 | crs_attn_mask = src_mask.unsqueeze(1).expand(-1, max_mel_len, -1) 167 | 168 | for enc_layer in self.layer_stack: 169 | dec_output, dec_crs_attn = enc_layer(src_seq, dec_output, mask=tgt_mask, crs_attn_mask=crs_attn_mask) 170 | if return_attns: 171 | dec_crs_attn_list += [dec_crs_attn] 172 | 173 | return dec_output, dec_crs_attn_list 174 | 175 | 176 | class MelDecoder(nn.Module): 177 | """ MelDecoder """ 178 | 179 | def __init__(self, config): 180 | super(MelDecoder, self).__init__() 181 | 182 | n_position = config["max_seq_len"] + 1 183 | d_word_vec = config["transformer"]["decoder_hidden"] 184 | n_layers = config["transformer"]["decoder_layer"] 185 | n_head = config["transformer"]["decoder_head"] 186 | d_k = d_v = ( 187 | config["transformer"]["decoder_hidden"] 188 | // config["transformer"]["decoder_head"] 189 | ) 190 | d_model = config["transformer"]["decoder_hidden"] 191 | d_inner = config["transformer"]["conv_filter_size"] 192 | kernel_size = config["transformer"]["conv_kernel_size"] 193 | dropout = config["transformer"]["decoder_dropout"] 194 | 195 | self.max_seq_len = config["max_seq_len"] 196 | self.d_model = d_model 197 | 198 | self.position_enc = nn.Parameter( 199 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 200 | requires_grad=False, 201 | ) 202 | 203 | self.layer_stack = nn.ModuleList( 204 | [ 205 | FFTBlock( 206 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 207 | ) 208 | for _ in range(n_layers) 209 | ] 210 | ) 211 | 212 | def forward(self, enc_seq, mask, return_attns=False): 213 | 214 | dec_slf_attn_list = [] 215 | batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1] 216 | 217 | # -- Forward 218 | if not self.training and enc_seq.shape[1] > self.max_seq_len: 219 | # -- Prepare masks 220 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 221 | dec_output = enc_seq + get_sinusoid_encoding_table( 222 | enc_seq.shape[1], self.d_model 223 | )[: enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 224 | enc_seq.device 225 | ) 226 | else: 227 | max_len = min(max_len, self.max_seq_len) 228 | 229 | # -- Prepare masks 230 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 231 | dec_output = enc_seq[:, :max_len, :] + self.position_enc[ 232 | :, :max_len, : 233 | ].expand(batch_size, -1, -1) 234 | mask = mask[:, :max_len] 235 | slf_attn_mask = slf_attn_mask[:, :, :max_len] 236 | 237 | for dec_layer in self.layer_stack: 238 | dec_output, dec_slf_attn = dec_layer( 239 | dec_output, mask=mask, slf_attn_mask=slf_attn_mask 240 | ) 241 | if return_attns: 242 | dec_slf_attn_list += [dec_slf_attn] 243 | 244 | return dec_output, mask 245 | -------------------------------------------------------------------------------- /text/transformer/Models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | import transformer.Constants as Constants 6 | from .Layers import FFTBlock, FFTBlock2, Prenet 7 | from text.symbols import symbols 8 | 9 | 10 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 11 | """ Sinusoid position encoding table """ 12 | 13 | def cal_angle(position, hid_idx): 14 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 15 | 16 | def get_posi_angle_vec(position): 17 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 18 | 19 | sinusoid_table = np.array( 20 | 
[get_posi_angle_vec(pos_i) for pos_i in range(n_position)] 21 | ) 22 | 23 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 24 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 25 | 26 | if padding_idx is not None: 27 | # zero vector for padding dimension 28 | sinusoid_table[padding_idx] = 0.0 29 | 30 | return torch.FloatTensor(sinusoid_table) 31 | 32 | 33 | class TxtEncoder(nn.Module): 34 | """ TxtEncoder """ 35 | 36 | def __init__(self, config): 37 | super(TxtEncoder, self).__init__() 38 | 39 | n_position = config["max_seq_len"] + 1 40 | n_src_vocab = len(symbols) + 1 41 | d_word_vec = config["transformer"]["encoder_hidden"] 42 | n_layers = config["transformer"]["encoder_layer"] 43 | n_head = config["transformer"]["encoder_head"] 44 | d_k = d_v = ( 45 | config["transformer"]["encoder_hidden"] 46 | // config["transformer"]["encoder_head"] 47 | ) 48 | d_model = config["transformer"]["encoder_hidden"] 49 | d_inner = config["transformer"]["conv_filter_size"] 50 | kernel_size = config["transformer"]["conv_kernel_size"] 51 | dropout = config["transformer"]["encoder_dropout"] 52 | 53 | self.max_seq_len = config["max_seq_len"] 54 | self.d_model = d_model 55 | 56 | self.src_word_emb = nn.Embedding( 57 | n_src_vocab, d_word_vec, padding_idx=Constants.PAD 58 | ) 59 | self.position_enc = nn.Parameter( 60 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 61 | requires_grad=False, 62 | ) 63 | 64 | self.layer_stack = nn.ModuleList( 65 | [ 66 | FFTBlock( 67 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 68 | ) 69 | for _ in range(n_layers) 70 | ] 71 | ) 72 | 73 | def forward(self, src_seq, mask, return_attns=False): 74 | 75 | enc_slf_attn_list = [] 76 | batch_size, max_len = src_seq.shape[0], src_seq.shape[1] 77 | 78 | # -- Prepare masks 79 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 80 | 81 | # -- Forward 82 | if not self.training and src_seq.shape[1] > self.max_seq_len: 83 | enc_output = self.src_word_emb(src_seq) + get_sinusoid_encoding_table( 84 | src_seq.shape[1], self.d_model 85 | )[: src_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 86 | src_seq.device 87 | ) 88 | else: 89 | enc_output = self.src_word_emb(src_seq) + self.position_enc[ 90 | :, :max_len, : 91 | ].expand(batch_size, -1, -1) 92 | 93 | for enc_layer in self.layer_stack: 94 | enc_output, enc_slf_attn = enc_layer( 95 | enc_output, mask=mask, slf_attn_mask=slf_attn_mask 96 | ) 97 | if return_attns: 98 | enc_slf_attn_list += [enc_slf_attn] 99 | 100 | return enc_output 101 | 102 | 103 | class MelEncoder(nn.Module): 104 | """ MelEncoder """ 105 | 106 | def __init__(self, config): 107 | super(MelEncoder, self).__init__() 108 | 109 | n_position = config["max_seq_len"] + 1 110 | d_word_vec = config["transformer"]["decoder_hidden"] 111 | n_layers = config["transformer"]["decoder_layer"] 112 | n_head = config["transformer"]["decoder_head"] 113 | d_k = d_v = ( 114 | config["transformer"]["decoder_hidden"] 115 | // config["transformer"]["decoder_head"] 116 | ) 117 | d_model = config["transformer"]["decoder_hidden"] 118 | d_inner = config["transformer"]["conv_filter_size"] 119 | kernel_size = config["transformer"]["conv_kernel_size"] 120 | dropout = config["transformer"]["decoder_dropout"] 121 | 122 | self.max_seq_len = config["max_seq_len"] 123 | self.d_model = d_model 124 | 125 | self.prenet = Prenet() 126 | self.position_enc = nn.Parameter( 127 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 128 | requires_grad=False, 
129 | ) 130 | 131 | self.layer_stack = nn.ModuleList( 132 | [ 133 | FFTBlock2( 134 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 135 | ) 136 | for _ in range(n_layers) 137 | ] 138 | ) 139 | 140 | def forward(self, src_seq, tgt_seq, src_mask, tgt_mask, return_attns=True): 141 | 142 | dec_crs_attn_list = [] 143 | batch_size, max_mel_len = tgt_seq.shape[0], tgt_seq.shape[1] 144 | 145 | zero_seq = torch.zeros(tgt_seq.shape[0], 1, tgt_seq.shape[2]).to(tgt_seq.device) 146 | tgt_seq = torch.cat([zero_seq, tgt_seq[:, 1:, :]], dim=1) 147 | 148 | # -- Forward 149 | if not self.training and tgt_seq.shape[1] > self.max_seq_len: 150 | # -- Prepare masks 151 | dec_output = self.prenet(tgt_seq) + get_sinusoid_encoding_table( 152 | tgt_seq.shape[1], self.d_model 153 | )[: tgt_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 154 | tgt_seq.device 155 | ) 156 | crs_attn_mask = src_mask.unsqueeze(1).expand(-1, max_mel_len, -1) 157 | 158 | else: 159 | max_mel_len = min(max_mel_len, self.max_seq_len) 160 | 161 | # -- Prepare masks 162 | dec_output = self.prenet(tgt_seq[:, :max_mel_len, :]) + self.position_enc[ 163 | :, :max_mel_len, : 164 | ].expand(batch_size, -1, -1) 165 | tgt_mask = tgt_mask[:, :max_mel_len] 166 | crs_attn_mask = src_mask.unsqueeze(1).expand(-1, max_mel_len, -1) 167 | 168 | for enc_layer in self.layer_stack: 169 | dec_output, dec_crs_attn = enc_layer(src_seq, dec_output, mask=tgt_mask, crs_attn_mask=crs_attn_mask) 170 | if return_attns: 171 | dec_crs_attn_list += [dec_crs_attn] 172 | 173 | return dec_output, dec_crs_attn_list 174 | 175 | 176 | class MelDecoder(nn.Module): 177 | """ MelDecoder """ 178 | 179 | def __init__(self, config): 180 | super(MelDecoder, self).__init__() 181 | 182 | n_position = config["max_seq_len"] + 1 183 | d_word_vec = config["transformer"]["decoder_hidden"] 184 | n_layers = config["transformer"]["decoder_layer"] 185 | n_head = config["transformer"]["decoder_head"] 186 | d_k = d_v = ( 187 | config["transformer"]["decoder_hidden"] 188 | // config["transformer"]["decoder_head"] 189 | ) 190 | d_model = config["transformer"]["decoder_hidden"] 191 | d_inner = config["transformer"]["conv_filter_size"] 192 | kernel_size = config["transformer"]["conv_kernel_size"] 193 | dropout = config["transformer"]["decoder_dropout"] 194 | 195 | self.max_seq_len = config["max_seq_len"] 196 | self.d_model = d_model 197 | 198 | self.position_enc = nn.Parameter( 199 | get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0), 200 | requires_grad=False, 201 | ) 202 | 203 | self.layer_stack = nn.ModuleList( 204 | [ 205 | FFTBlock( 206 | d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout 207 | ) 208 | for _ in range(n_layers) 209 | ] 210 | ) 211 | 212 | def forward(self, enc_seq, mask, return_attns=False): 213 | 214 | dec_slf_attn_list = [] 215 | batch_size, max_len = enc_seq.shape[0], enc_seq.shape[1] 216 | 217 | # -- Forward 218 | if not self.training and enc_seq.shape[1] > self.max_seq_len: 219 | # -- Prepare masks 220 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 221 | dec_output = enc_seq + get_sinusoid_encoding_table( 222 | enc_seq.shape[1], self.d_model 223 | )[: enc_seq.shape[1], :].unsqueeze(0).expand(batch_size, -1, -1).to( 224 | enc_seq.device 225 | ) 226 | else: 227 | max_len = min(max_len, self.max_seq_len) 228 | 229 | # -- Prepare masks 230 | slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) 231 | dec_output = enc_seq[:, :max_len, :] + self.position_enc[ 232 | :, :max_len, : 233 | ].expand(batch_size, -1, 
-1) 234 | mask = mask[:, :max_len] 235 | slf_attn_mask = slf_attn_mask[:, :, :max_len] 236 | 237 | for dec_layer in self.layer_stack: 238 | dec_output, dec_slf_attn = dec_layer( 239 | dec_output, mask=mask, slf_attn_mask=slf_attn_mask 240 | ) 241 | if return_attns: 242 | dec_slf_attn_list += [dec_slf_attn] 243 | 244 | return dec_output, mask 245 | -------------------------------------------------------------------------------- /text/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import matplotlib 8 | from scipy.io import wavfile 9 | from matplotlib import pyplot as plt 10 | 11 | 12 | matplotlib.use("Agg") 13 | 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | 18 | def to_device(data, device): 19 | if len(data) == 11: 20 | ( 21 | ids, 22 | raw_texts, 23 | speakers, 24 | texts, 25 | src_lens, 26 | max_src_len, 27 | mels, 28 | mel_lens, 29 | max_mel_len, 30 | pitches, 31 | energies, 32 | ) = data 33 | 34 | speakers = torch.from_numpy(speakers).long().to(device) 35 | texts = torch.from_numpy(texts).long().to(device) 36 | src_lens = torch.from_numpy(src_lens).to(device) 37 | mels = torch.from_numpy(mels).float().to(device) 38 | mel_lens = torch.from_numpy(mel_lens).to(device) 39 | pitches = torch.from_numpy(pitches).float().to(device) 40 | energies = torch.from_numpy(energies).to(device) 41 | 42 | return ( 43 | ids, 44 | raw_texts, 45 | speakers, 46 | texts, 47 | src_lens, 48 | max_src_len, 49 | mels, 50 | mel_lens, 51 | max_mel_len, 52 | pitches, 53 | energies, 54 | ) 55 | 56 | if len(data) == 6: 57 | (ids, raw_texts, speakers, texts, src_lens, max_src_len) = data 58 | 59 | speakers = torch.from_numpy(speakers).long().to(device) 60 | texts = torch.from_numpy(texts).long().to(device) 61 | src_lens = torch.from_numpy(src_lens).to(device) 62 | 63 | return (ids, raw_texts, speakers, texts, src_lens, max_src_len) 64 | 65 | 66 | def log( 67 | logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag="" 68 | ): 69 | if losses is not None: 70 | logger.add_scalar("Loss/total_loss", losses[0], step) 71 | logger.add_scalar("Loss/mel_loss", losses[1], step) 72 | logger.add_scalar("Loss/mel_postnet_loss", losses[2], step) 73 | logger.add_scalar("Loss/pitch_loss", losses[3], step) 74 | logger.add_scalar("Loss/energy_loss", losses[4], step) 75 | logger.add_scalar("Loss/duration_loss", losses[5], step) 76 | 77 | if fig is not None: 78 | logger.add_figure(tag, fig) 79 | 80 | if audio is not None: 81 | logger.add_audio( 82 | tag, 83 | audio / max(abs(audio)), 84 | sample_rate=sampling_rate, 85 | ) 86 | 87 | 88 | def get_mask_from_lengths(lengths, max_len=None): 89 | batch_size = lengths.shape[0] 90 | if max_len is None: 91 | max_len = torch.max(lengths).item() 92 | 93 | ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device) 94 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) 95 | 96 | return mask 97 | 98 | 99 | def expand(values, durations): 100 | out = list() 101 | for value, d in zip(values, durations): 102 | out += [value] * max(0, int(d)) 103 | return np.array(out) 104 | 105 | 106 | def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_config): 107 | 108 | basename = targets[0][0] 109 | src_len = predictions[8][0].item() 110 | mel_len = predictions[9][0].item() 111 | mel_target = targets[6][0, :mel_len].detach().transpose(0, 1) 112 | 
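    # tuple indices follow FastSpeech2Align.forward's return order:
    # predictions[1] is the postnet mel output, predictions[8] / predictions[9]
    # are the source and mel lengths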
mel_prediction = predictions[1][0, :mel_len].detach().transpose(0, 1) 113 | pitch = targets[9][0, :mel_len].detach().cpu().numpy() 114 | energy = targets[10][0, :mel_len].detach().cpu().numpy() 115 | 116 | with open( 117 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 118 | ) as f: 119 | stats = json.load(f) 120 | stats = stats["pitch"] + stats["energy"][:2] 121 | 122 | fig = plot_mel( 123 | [ 124 | (mel_prediction.cpu().numpy(), pitch, energy), 125 | (mel_target.cpu().numpy(), pitch, energy), 126 | ], 127 | stats, 128 | ["Synthetized Spectrogram", "Ground-Truth Spectrogram"], 129 | ) 130 | 131 | if vocoder is not None: 132 | from .model import vocoder_infer 133 | 134 | wav_reconstruction = vocoder_infer( 135 | mel_target.unsqueeze(0), 136 | vocoder, 137 | model_config, 138 | preprocess_config, 139 | )[0] 140 | wav_prediction = vocoder_infer( 141 | mel_prediction.unsqueeze(0), 142 | vocoder, 143 | model_config, 144 | preprocess_config, 145 | )[0] 146 | else: 147 | wav_reconstruction = wav_prediction = None 148 | 149 | return fig, wav_reconstruction, wav_prediction, basename 150 | 151 | 152 | def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path): 153 | 154 | basenames = targets[0] 155 | for i in range(len(predictions[0])): 156 | basename = basenames[i] 157 | src_len = predictions[8][i].item() 158 | mel_len = predictions[9][i].item() 159 | mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1) 160 | duration = predictions[5][i, :src_len].detach().cpu().numpy() 161 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": 162 | pitch = predictions[2][i, :src_len].detach().cpu().numpy() 163 | pitch = expand(pitch, duration) 164 | else: 165 | pitch = predictions[2][i, :mel_len].detach().cpu().numpy() 166 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": 167 | energy = predictions[3][i, :src_len].detach().cpu().numpy() 168 | energy = expand(energy, duration) 169 | else: 170 | energy = predictions[3][i, :mel_len].detach().cpu().numpy() 171 | 172 | with open( 173 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 174 | ) as f: 175 | stats = json.load(f) 176 | stats = stats["pitch"] + stats["energy"][:2] 177 | 178 | fig = plot_mel( 179 | [ 180 | (mel_prediction.cpu().numpy(), pitch, energy), 181 | ], 182 | stats, 183 | ["Synthetized Spectrogram"], 184 | ) 185 | plt.savefig(os.path.join(path, "{}.png".format(basename))) 186 | plt.close() 187 | 188 | from .model import vocoder_infer 189 | 190 | mel_predictions = predictions[1].transpose(1, 2) 191 | lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"] 192 | wav_predictions = vocoder_infer( 193 | mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths 194 | ) 195 | 196 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] 197 | for wav, basename in zip(wav_predictions, basenames): 198 | wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav) 199 | 200 | 201 | def plot_mel(data, stats, titles): 202 | fig, axes = plt.subplots(len(data), 1, squeeze=False) 203 | if titles is None: 204 | titles = [None for i in range(len(data))] 205 | pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats 206 | pitch_min = pitch_min * pitch_std + pitch_mean 207 | pitch_max = pitch_max * pitch_std + pitch_mean 208 | 209 | def add_axis(fig, old_ax): 210 | ax = fig.add_axes(old_ax.get_position(), 
anchor="W") 211 | ax.set_facecolor("None") 212 | return ax 213 | 214 | for i in range(len(data)): 215 | mel, pitch, energy = data[i] 216 | pitch = pitch * pitch_std + pitch_mean 217 | axes[i][0].imshow(mel, origin="lower") 218 | axes[i][0].set_aspect(2.5, adjustable="box") 219 | axes[i][0].set_ylim(0, mel.shape[0]) 220 | axes[i][0].set_title(titles[i], fontsize="medium") 221 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) 222 | axes[i][0].set_anchor("W") 223 | 224 | ax1 = add_axis(fig, axes[i][0]) 225 | ax1.plot(pitch, color="tomato") 226 | ax1.set_xlim(0, mel.shape[1]) 227 | ax1.set_ylim(0, pitch_max) 228 | ax1.set_ylabel("F0", color="tomato") 229 | ax1.tick_params( 230 | labelsize="x-small", colors="tomato", bottom=False, labelbottom=False 231 | ) 232 | 233 | ax2 = add_axis(fig, axes[i][0]) 234 | ax2.plot(energy, color="darkviolet") 235 | ax2.set_xlim(0, mel.shape[1]) 236 | ax2.set_ylim(energy_min, energy_max) 237 | ax2.set_ylabel("Energy", color="darkviolet") 238 | ax2.yaxis.set_label_position("right") 239 | ax2.tick_params( 240 | labelsize="x-small", 241 | colors="darkviolet", 242 | bottom=False, 243 | labelbottom=False, 244 | left=False, 245 | labelleft=False, 246 | right=True, 247 | labelright=True, 248 | ) 249 | 250 | return fig 251 | 252 | 253 | def pad_1D(inputs, PAD=0): 254 | def pad_data(x, length, PAD): 255 | x_padded = np.pad( 256 | x, (0, length - x.shape[0]), mode="constant", constant_values=PAD 257 | ) 258 | return x_padded 259 | 260 | max_len = max((len(x) for x in inputs)) 261 | padded = np.stack([pad_data(x, max_len, PAD) for x in inputs]) 262 | 263 | return padded 264 | 265 | 266 | def pad_2D(inputs, maxlen=None): 267 | def pad(x, max_len): 268 | PAD = 0 269 | if np.shape(x)[0] > max_len: 270 | raise ValueError("not max_len") 271 | 272 | s = np.shape(x)[1] 273 | x_padded = np.pad( 274 | x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD 275 | ) 276 | return x_padded[:, :s] 277 | 278 | if maxlen: 279 | output = np.stack([pad(x, maxlen) for x in inputs]) 280 | else: 281 | max_len = max(np.shape(x)[0] for x in inputs) 282 | output = np.stack([pad(x, max_len) for x in inputs]) 283 | 284 | return output 285 | 286 | 287 | def pad(input_ele, mel_max_length=None): 288 | if mel_max_length: 289 | max_len = mel_max_length 290 | else: 291 | max_len = max([input_ele[i].size(0) for i in range(len(input_ele))]) 292 | 293 | out_list = list() 294 | for i, batch in enumerate(input_ele): 295 | if len(batch.shape) == 1: 296 | one_batch_padded = F.pad( 297 | batch, (0, max_len - batch.size(0)), "constant", 0.0 298 | ) 299 | elif len(batch.shape) == 2: 300 | one_batch_padded = F.pad( 301 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0 302 | ) 303 | out_list.append(one_batch_padded) 304 | out_padded = torch.stack(out_list) 305 | return out_padded 306 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import matplotlib 8 | from scipy.io import wavfile 9 | from matplotlib import pyplot as plt 10 | 11 | 12 | matplotlib.use("Agg") 13 | 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | 18 | def to_device(data, device): 19 | if len(data) == 11: 20 | ( 21 | ids, 22 | raw_texts, 23 | speakers, 24 | texts, 25 | src_lens, 26 | max_src_len, 27 | mels, 28 | mel_lens, 
29 | max_mel_len, 30 | pitches, 31 | energies, 32 | ) = data 33 | 34 | speakers = torch.from_numpy(speakers).long().to(device) 35 | texts = torch.from_numpy(texts).long().to(device) 36 | src_lens = torch.from_numpy(src_lens).to(device) 37 | mels = torch.from_numpy(mels).float().to(device) 38 | mel_lens = torch.from_numpy(mel_lens).to(device) 39 | pitches = torch.from_numpy(pitches).float().to(device) 40 | energies = torch.from_numpy(energies).to(device) 41 | 42 | return ( 43 | ids, 44 | raw_texts, 45 | speakers, 46 | texts, 47 | src_lens, 48 | max_src_len, 49 | mels, 50 | mel_lens, 51 | max_mel_len, 52 | pitches, 53 | energies, 54 | ) 55 | 56 | if len(data) == 6: 57 | (ids, raw_texts, speakers, texts, src_lens, max_src_len) = data 58 | 59 | speakers = torch.from_numpy(speakers).long().to(device) 60 | texts = torch.from_numpy(texts).long().to(device) 61 | src_lens = torch.from_numpy(src_lens).to(device) 62 | 63 | return (ids, raw_texts, speakers, texts, src_lens, max_src_len) 64 | 65 | 66 | def log( 67 | logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag="" 68 | ): 69 | if losses is not None: 70 | logger.add_scalar("Loss/total_loss", losses[0], step) 71 | logger.add_scalar("Loss/mel_loss", losses[1], step) 72 | logger.add_scalar("Loss/mel_postnet_loss", losses[2], step) 73 | logger.add_scalar("Loss/pitch_loss", losses[3], step) 74 | logger.add_scalar("Loss/energy_loss", losses[4], step) 75 | logger.add_scalar("Loss/duration_loss", losses[5], step) 76 | logger.add_scalar("Loss/attention_loss", losses[6], step) 77 | 78 | if fig is not None: 79 | logger.add_figure(tag, fig) 80 | 81 | if audio is not None: 82 | logger.add_audio( 83 | tag, 84 | audio / max(abs(audio)), 85 | sample_rate=sampling_rate, 86 | ) 87 | 88 | 89 | def get_mask_from_lengths(lengths, max_len=None): 90 | batch_size = lengths.shape[0] 91 | if max_len is None: 92 | max_len = torch.max(lengths).item() 93 | 94 | ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device) 95 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) 96 | 97 | return mask 98 | 99 | 100 | def expand(values, durations): 101 | out = list() 102 | for value, d in zip(values, durations): 103 | out += [value] * max(0, int(d)) 104 | return np.array(out) 105 | 106 | 107 | def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_config): 108 | 109 | basename = targets[0][0] 110 | src_len = predictions[8][0].item() 111 | mel_len = predictions[9][0].item() 112 | mel_target = targets[6][0, :mel_len].detach().transpose(0, 1) 113 | mel_prediction = predictions[1][0, :mel_len].detach().transpose(0, 1) 114 | pitch = targets[9][0, :mel_len].detach().cpu().numpy() 115 | energy = targets[10][0, :mel_len].detach().cpu().numpy() 116 | 117 | with open( 118 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 119 | ) as f: 120 | stats = json.load(f) 121 | stats = stats["pitch"] + stats["energy"][:2] 122 | 123 | fig = plot_mel( 124 | [ 125 | (mel_prediction.cpu().numpy(), pitch, energy), 126 | (mel_target.cpu().numpy(), pitch, energy), 127 | ], 128 | stats, 129 | ["Synthesized Spectrogram", "Ground-Truth Spectrogram"], 130 | ) 131 | 132 | if vocoder is not None: 133 | from .model import vocoder_infer 134 | 135 | wav_reconstruction = vocoder_infer( 136 | mel_target.unsqueeze(0), 137 | vocoder, 138 | model_config, 139 | preprocess_config, 140 | )[0] 141 | wav_prediction = vocoder_infer( 142 | mel_prediction.unsqueeze(0), 143 | vocoder, 144 | model_config, 145 | preprocess_config, 146 |
)[0] 147 | else: 148 | wav_reconstruction = wav_prediction = None 149 | 150 | return fig, wav_reconstruction, wav_prediction, basename 151 | 152 | 153 | def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path): 154 | 155 | basenames = targets[0] 156 | for i in range(len(predictions[0])): 157 | basename = basenames[i] 158 | src_len = predictions[8][i].item() 159 | mel_len = predictions[9][i].item() 160 | mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1) 161 | duration = predictions[5][i, :src_len].detach().cpu().numpy() 162 | if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level": 163 | pitch = predictions[2][i, :src_len].detach().cpu().numpy() 164 | pitch = expand(pitch, duration) 165 | else: 166 | pitch = predictions[2][i, :mel_len].detach().cpu().numpy() 167 | if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level": 168 | energy = predictions[3][i, :src_len].detach().cpu().numpy() 169 | energy = expand(energy, duration) 170 | else: 171 | energy = predictions[3][i, :mel_len].detach().cpu().numpy() 172 | 173 | with open( 174 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 175 | ) as f: 176 | stats = json.load(f) 177 | stats = stats["pitch"] + stats["energy"][:2] 178 | 179 | fig = plot_mel( 180 | [ 181 | (mel_prediction.cpu().numpy(), pitch, energy), 182 | ], 183 | stats, 184 | ["Synthesized Spectrogram"], 185 | ) 186 | plt.savefig(os.path.join(path, "{}.png".format(basename))) 187 | plt.close() 188 | 189 | from .model import vocoder_infer 190 | 191 | mel_predictions = predictions[1].transpose(1, 2) 192 | lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"] 193 | wav_predictions = vocoder_infer( 194 | mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths 195 | ) 196 | 197 | sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] 198 | for wav, basename in zip(wav_predictions, basenames): 199 | wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav) 200 | 201 | 202 | def plot_mel(data, stats, titles): 203 | fig, axes = plt.subplots(len(data), 1, squeeze=False) 204 | if titles is None: 205 | titles = [None for i in range(len(data))] 206 | pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats 207 | pitch_min = pitch_min * pitch_std + pitch_mean 208 | pitch_max = pitch_max * pitch_std + pitch_mean 209 | 210 | def add_axis(fig, old_ax): 211 | ax = fig.add_axes(old_ax.get_position(), anchor="W") 212 | ax.set_facecolor("None") 213 | return ax 214 | 215 | for i in range(len(data)): 216 | mel, pitch, energy = data[i] 217 | pitch = pitch * pitch_std + pitch_mean 218 | axes[i][0].imshow(mel, origin="lower") 219 | axes[i][0].set_aspect(2.5, adjustable="box") 220 | axes[i][0].set_ylim(0, mel.shape[0]) 221 | axes[i][0].set_title(titles[i], fontsize="medium") 222 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False) 223 | axes[i][0].set_anchor("W") 224 | 225 | ax1 = add_axis(fig, axes[i][0]) 226 | ax1.plot(pitch, color="tomato") 227 | ax1.set_xlim(0, mel.shape[1]) 228 | ax1.set_ylim(0, pitch_max) 229 | ax1.set_ylabel("F0", color="tomato") 230 | ax1.tick_params( 231 | labelsize="x-small", colors="tomato", bottom=False, labelbottom=False 232 | ) 233 | 234 | ax2 = add_axis(fig, axes[i][0]) 235 | ax2.plot(energy, color="darkviolet") 236 | ax2.set_xlim(0, mel.shape[1]) 237 | ax2.set_ylim(energy_min, energy_max) 238 | ax2.set_ylabel("Energy",
color="darkviolet") 239 | ax2.yaxis.set_label_position("right") 240 | ax2.tick_params( 241 | labelsize="x-small", 242 | colors="darkviolet", 243 | bottom=False, 244 | labelbottom=False, 245 | left=False, 246 | labelleft=False, 247 | right=True, 248 | labelright=True, 249 | ) 250 | 251 | return fig 252 | 253 | 254 | def pad_1D(inputs, PAD=0): 255 | def pad_data(x, length, PAD): 256 | x_padded = np.pad( 257 | x, (0, length - x.shape[0]), mode="constant", constant_values=PAD 258 | ) 259 | return x_padded 260 | 261 | max_len = max((len(x) for x in inputs)) 262 | padded = np.stack([pad_data(x, max_len, PAD) for x in inputs]) 263 | 264 | return padded 265 | 266 | 267 | def pad_2D(inputs, maxlen=None): 268 | def pad(x, max_len): 269 | PAD = 0 270 | if np.shape(x)[0] > max_len: 271 | raise ValueError("not max_len") 272 | 273 | s = np.shape(x)[1] 274 | x_padded = np.pad( 275 | x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD 276 | ) 277 | return x_padded[:, :s] 278 | 279 | if maxlen: 280 | output = np.stack([pad(x, maxlen) for x in inputs]) 281 | else: 282 | max_len = max(np.shape(x)[0] for x in inputs) 283 | output = np.stack([pad(x, max_len) for x in inputs]) 284 | 285 | return output 286 | 287 | 288 | def pad(input_ele, mel_max_length=None): 289 | if mel_max_length: 290 | max_len = mel_max_length 291 | else: 292 | max_len = max([input_ele[i].size(0) for i in range(len(input_ele))]) 293 | 294 | out_list = list() 295 | for i, batch in enumerate(input_ele): 296 | if len(batch.shape) == 1: 297 | one_batch_padded = F.pad( 298 | batch, (0, max_len - batch.size(0)), "constant", 0.0 299 | ) 300 | elif len(batch.shape) == 2: 301 | one_batch_padded = F.pad( 302 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0 303 | ) 304 | out_list.append(one_batch_padded) 305 | out_padded = torch.stack(out_list) 306 | return out_padded 307 | -------------------------------------------------------------------------------- /preprocessor/preprocessor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | 5 | import tgt 6 | import librosa 7 | import numpy as np 8 | import pyworld as pw 9 | from scipy.interpolate import interp1d 10 | from sklearn.preprocessing import StandardScaler 11 | from tqdm import tqdm 12 | 13 | import audio as Audio 14 | 15 | 16 | class Preprocessor: 17 | def __init__(self, config): 18 | self.config = config 19 | self.in_dir = config["path"]["data_path"] 20 | self.out_dir = config["path"]["preprocessed_path"] 21 | self.val_size = config["preprocessing"]["val_size"] 22 | self.sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 23 | self.hop_length = config["preprocessing"]["stft"]["hop_length"] 24 | 25 | assert config["preprocessing"]["pitch"]["feature"] in [ 26 | "phoneme_level", 27 | "frame_level", 28 | ] 29 | assert config["preprocessing"]["energy"]["feature"] in [ 30 | "phoneme_level", 31 | "frame_level", 32 | ] 33 | self.pitch_phoneme_averaging = ( 34 | config["preprocessing"]["pitch"]["feature"] == "phoneme_level" 35 | ) 36 | self.energy_phoneme_averaging = ( 37 | config["preprocessing"]["energy"]["feature"] == "phoneme_level" 38 | ) 39 | 40 | self.pitch_normalization = config["preprocessing"]["pitch"]["normalization"] 41 | self.energy_normalization = config["preprocessing"]["energy"]["normalization"] 42 | 43 | self.STFT = Audio.stft.TacotronSTFT( 44 | config["preprocessing"]["stft"]["filter_length"], 45 | config["preprocessing"]["stft"]["hop_length"], 46 | 
config["preprocessing"]["stft"]["win_length"], 47 | config["preprocessing"]["mel"]["n_mel_channels"], 48 | config["preprocessing"]["audio"]["sampling_rate"], 49 | config["preprocessing"]["mel"]["mel_fmin"], 50 | config["preprocessing"]["mel"]["mel_fmax"], 51 | ) 52 | 53 | def build_from_path(self): 54 | os.makedirs((os.path.join(self.out_dir, "mel")), exist_ok=True) 55 | os.makedirs((os.path.join(self.out_dir, "pitch")), exist_ok=True) 56 | os.makedirs((os.path.join(self.out_dir, "energy")), exist_ok=True) 57 | 58 | print("Processing Data ...") 59 | out = list() 60 | n_frames = 0 61 | pitch_scaler = StandardScaler() 62 | energy_scaler = StandardScaler() 63 | 64 | # Compute pitch, energy, duration, and mel-spectrogram 65 | speakers = {} 66 | for i, speaker in enumerate(tqdm(os.listdir(self.in_dir))): 67 | speakers[speaker] = i 68 | for wav_name in os.listdir(os.path.join(self.in_dir, speaker)): 69 | if ".wav" not in wav_name: 70 | continue 71 | 72 | basename = wav_name.split(".")[0] 73 | tg_path = os.path.join( 74 | self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename) 75 | ) 76 | if os.path.exists(tg_path): 77 | ret = self.process_utterance(speaker, basename) 78 | if ret is None: 79 | continue 80 | else: 81 | info, pitch, energy, n = ret 82 | out.append(info) 83 | 84 | if len(pitch) > 0: 85 | pitch_scaler.partial_fit(pitch.reshape((-1, 1))) 86 | if len(energy) > 0: 87 | energy_scaler.partial_fit(energy.reshape((-1, 1))) 88 | 89 | n_frames += n 90 | 91 | print("Computing statistic quantities ...") 92 | # Perform normalization if necessary 93 | if self.pitch_normalization: 94 | pitch_mean = pitch_scaler.mean_[0] 95 | pitch_std = pitch_scaler.scale_[0] 96 | else: 97 | # A numerical trick to avoid normalization... 98 | pitch_mean = 0 99 | pitch_std = 1 100 | if self.energy_normalization: 101 | energy_mean = energy_scaler.mean_[0] 102 | energy_std = energy_scaler.scale_[0] 103 | else: 104 | energy_mean = 0 105 | energy_std = 1 106 | 107 | pitch_min, pitch_max = self.normalize( 108 | os.path.join(self.out_dir, "pitch"), pitch_mean, pitch_std 109 | ) 110 | energy_min, energy_max = self.normalize( 111 | os.path.join(self.out_dir, "energy"), energy_mean, energy_std 112 | ) 113 | 114 | # Save files 115 | with open(os.path.join(self.out_dir, "speakers.json"), "w") as f: 116 | f.write(json.dumps(speakers)) 117 | 118 | with open(os.path.join(self.out_dir, "stats.json"), "w") as f: 119 | stats = { 120 | "pitch": [ 121 | float(pitch_min), 122 | float(pitch_max), 123 | float(pitch_mean), 124 | float(pitch_std), 125 | ], 126 | "energy": [ 127 | float(energy_min), 128 | float(energy_max), 129 | float(energy_mean), 130 | float(energy_std), 131 | ], 132 | } 133 | f.write(json.dumps(stats)) 134 | 135 | print( 136 | "Total time: {} hours".format( 137 | n_frames * self.hop_length / self.sampling_rate / 3600 138 | ) 139 | ) 140 | 141 | random.shuffle(out) 142 | out = [r for r in out if r is not None] 143 | 144 | # Write metadata 145 | with open(os.path.join(self.out_dir, "train.txt"), "w", encoding="utf-8") as f: 146 | for m in out[self.val_size :]: 147 | f.write(m + "\n") 148 | with open(os.path.join(self.out_dir, "val.txt"), "w", encoding="utf-8") as f: 149 | for m in out[: self.val_size]: 150 | f.write(m + "\n") 151 | 152 | return out 153 | 154 | def process_utterance(self, speaker, basename): 155 | wav_path = os.path.join(self.in_dir, speaker, "{}.wav".format(basename)) 156 | text_path = os.path.join(self.in_dir, speaker, "{}.lab".format(basename)) 157 | tg_path = os.path.join( 158 | self.out_dir, 
"TextGrid", speaker, "{}.TextGrid".format(basename) 159 | ) 160 | 161 | # Get alignments 162 | textgrid = tgt.io.read_textgrid(tg_path) 163 | phone, duration, start, end = self.get_alignment( 164 | textgrid.get_tier_by_name("phones") 165 | ) 166 | text = "{" + " ".join(phone) + "}" 167 | if start >= end: 168 | return None 169 | 170 | # Read and trim wav files 171 | wav, _ = librosa.load(wav_path) 172 | wav = wav[ 173 | int(self.sampling_rate * start) : int(self.sampling_rate * end) 174 | ].astype(np.float32) 175 | 176 | # Read raw text 177 | with open(text_path, "r") as f: 178 | raw_text = f.readline().strip("\n") 179 | 180 | # Compute fundamental frequency 181 | pitch, t = pw.dio( 182 | wav.astype(np.float64), 183 | self.sampling_rate, 184 | frame_period=self.hop_length / self.sampling_rate * 1000, 185 | ) 186 | pitch = pw.stonemask(wav.astype(np.float64), pitch, t, self.sampling_rate) 187 | 188 | pitch = pitch[: sum(duration)] 189 | if np.sum(pitch != 0) <= 1: 190 | return None 191 | 192 | # Compute mel-scale spectrogram and energy 193 | mel_spectrogram, energy = Audio.tools.get_mel_from_wav(wav, self.STFT) 194 | mel_spectrogram = mel_spectrogram[:, : sum(duration)] 195 | energy = energy[: sum(duration)] 196 | 197 | if self.pitch_phoneme_averaging: 198 | # perform linear interpolation 199 | nonzero_ids = np.where(pitch != 0)[0] 200 | interp_fn = interp1d( 201 | nonzero_ids, 202 | pitch[nonzero_ids], 203 | fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]), 204 | bounds_error=False, 205 | ) 206 | pitch = interp_fn(np.arange(0, len(pitch))) 207 | 208 | # Phoneme-level average 209 | pos = 0 210 | for i, d in enumerate(duration): 211 | if d > 0: 212 | pitch[i] = np.mean(pitch[pos : pos + d]) 213 | else: 214 | pitch[i] = 0 215 | pos += d 216 | pitch = pitch[: len(duration)] 217 | 218 | if self.energy_phoneme_averaging: 219 | # Phoneme-level average 220 | pos = 0 221 | for i, d in enumerate(duration): 222 | if d > 0: 223 | energy[i] = np.mean(energy[pos : pos + d]) 224 | else: 225 | energy[i] = 0 226 | pos += d 227 | energy = energy[: len(duration)] 228 | 229 | # Save files 230 | pitch_filename = "{}-pitch-{}.npy".format(speaker, basename) 231 | np.save(os.path.join(self.out_dir, "pitch", pitch_filename), pitch) 232 | 233 | energy_filename = "{}-energy-{}.npy".format(speaker, basename) 234 | np.save(os.path.join(self.out_dir, "energy", energy_filename), energy) 235 | 236 | mel_filename = "{}-mel-{}.npy".format(speaker, basename) 237 | np.save( 238 | os.path.join(self.out_dir, "mel", mel_filename), 239 | mel_spectrogram.T, 240 | ) 241 | 242 | return ( 243 | "|".join([basename, speaker, text, raw_text]), 244 | self.remove_outlier(pitch), 245 | self.remove_outlier(energy), 246 | mel_spectrogram.shape[1], 247 | ) 248 | 249 | def get_alignment(self, tier): 250 | sil_phones = ["sil", "sp", "spn"] 251 | 252 | phones = [] 253 | durations = [] 254 | start_time = 0 255 | end_time = 0 256 | end_idx = 0 257 | for t in tier._objects: 258 | s, e, p = t.start_time, t.end_time, t.text 259 | 260 | # Trim leading silences 261 | if phones == []: 262 | if p in sil_phones: 263 | continue 264 | else: 265 | start_time = s 266 | 267 | if p not in sil_phones: 268 | # For ordinary phones 269 | phones.append(p) 270 | end_time = e 271 | end_idx = len(phones) 272 | else: 273 | # For silent phones 274 | phones.append(p) 275 | 276 | durations.append( 277 | int( 278 | np.round(e * self.sampling_rate / self.hop_length) 279 | - np.round(s * self.sampling_rate / self.hop_length) 280 | ) 281 | ) 282 | 283 | # Trim 
tailing silences 284 | phones = phones[:end_idx] 285 | durations = durations[:end_idx] 286 | 287 | return phones, durations, start_time, end_time 288 | 289 | def remove_outlier(self, values): 290 | values = np.array(values) 291 | p25 = np.percentile(values, 25) 292 | p75 = np.percentile(values, 75) 293 | lower = p25 - 1.5 * (p75 - p25) 294 | upper = p75 + 1.5 * (p75 - p25) 295 | normal_indices = np.logical_and(values > lower, values < upper) 296 | 297 | return values[normal_indices] 298 | 299 | def normalize(self, in_dir, mean, std): 300 | max_value = np.finfo(np.float64).min 301 | min_value = np.finfo(np.float64).max 302 | for filename in os.listdir(in_dir): 303 | filename = os.path.join(in_dir, filename) 304 | values = (np.load(filename) - mean) / std 305 | np.save(filename, values) 306 | 307 | max_value = max(max_value, max(values)) 308 | min_value = min(min_value, min(values)) 309 | 310 | return min_value, max_value -------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import math 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | 11 | from utils.tools import get_mask_from_lengths, pad 12 | 13 | 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | 17 | class VarianceAdaptor(nn.Module): 18 | """Variance Adaptor""" 19 | 20 | def __init__(self, preprocess_config, model_config): 21 | super(VarianceAdaptor, self).__init__() 22 | self.length_regulator = LengthRegulator() 23 | self.duration_predictor = VariancePredictor(model_config) 24 | self.pitch_predictor = VariancePredictor(model_config) 25 | self.energy_predictor = VariancePredictor(model_config) 26 | 27 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][ 28 | "feature" 29 | ] 30 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][ 31 | "feature" 32 | ] 33 | assert self.pitch_feature_level in ["phoneme_level", "frame_level"] 34 | assert self.energy_feature_level in ["phoneme_level", "frame_level"] 35 | 36 | pitch_quantization = model_config["variance_embedding"]["pitch_quantization"] 37 | energy_quantization = model_config["variance_embedding"]["energy_quantization"] 38 | n_bins = model_config["variance_embedding"]["n_bins"] 39 | assert pitch_quantization in ["linear", "log"] 40 | assert energy_quantization in ["linear", "log"] 41 | with open( 42 | os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json") 43 | ) as f: 44 | stats = json.load(f) 45 | pitch_min, pitch_max = stats["pitch"][:2] 46 | energy_min, energy_max = stats["energy"][:2] 47 | 48 | if pitch_quantization == "log": 49 | self.pitch_bins = nn.Parameter( 50 | torch.exp( 51 | torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1) 52 | ), 53 | requires_grad=False, 54 | ) 55 | else: 56 | self.pitch_bins = nn.Parameter( 57 | torch.linspace(pitch_min, pitch_max, n_bins - 1), 58 | requires_grad=False, 59 | ) 60 | if energy_quantization == "log": 61 | self.energy_bins = nn.Parameter( 62 | torch.exp( 63 | torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1) 64 | ), 65 | requires_grad=False, 66 | ) 67 | else: 68 | self.energy_bins = nn.Parameter( 69 | torch.linspace(energy_min, energy_max, n_bins - 1), 70 | requires_grad=False, 71 | ) 72 | 73 | self.pitch_embedding = nn.Embedding( 74 | n_bins, 
model_config["transformer"]["encoder_hidden"] 75 | ) 76 | self.energy_embedding = nn.Embedding( 77 | n_bins, model_config["transformer"]["encoder_hidden"] 78 | ) 79 | 80 | def get_pitch_embedding(self, x, target, mask, control): 81 | prediction = self.pitch_predictor(x, mask) 82 | if target is not None: 83 | embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins)) 84 | else: 85 | prediction = prediction * control 86 | embedding = self.pitch_embedding( 87 | torch.bucketize(prediction, self.pitch_bins) 88 | ) 89 | return prediction, embedding 90 | 91 | def get_energy_embedding(self, x, target, mask, control): 92 | prediction = self.energy_predictor(x, mask) 93 | if target is not None: 94 | embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins)) 95 | else: 96 | prediction = prediction * control 97 | embedding = self.energy_embedding( 98 | torch.bucketize(prediction, self.energy_bins) 99 | ) 100 | return prediction, embedding 101 | 102 | def forward( 103 | self, 104 | x, 105 | src_mask, 106 | mel_mask=None, 107 | max_len=None, 108 | pitch_target=None, 109 | energy_target=None, 110 | duration_target=None, 111 | p_control=1.0, 112 | e_control=1.0, 113 | d_control=1.0, 114 | ): 115 | 116 | log_duration_prediction = self.duration_predictor(x, src_mask) 117 | if self.pitch_feature_level == "phoneme_level": 118 | pitch_prediction, pitch_embedding = self.get_pitch_embedding( 119 | x, pitch_target, src_mask, p_control 120 | ) 121 | x = x + pitch_embedding 122 | if self.energy_feature_level == "phoneme_level": 123 | energy_prediction, energy_embedding = self.get_energy_embedding( 124 | x, energy_target, src_mask, e_control 125 | ) 126 | x = x + energy_embedding 127 | 128 | if duration_target is not None: 129 | x, mel_len = self.length_regulator(x, duration_target, max_len) 130 | duration_rounded = duration_target 131 | else: 132 | duration_rounded = torch.clamp( 133 | (torch.round(torch.exp(log_duration_prediction) - 1) * d_control), 134 | min=0, 135 | ) 136 | x, mel_len = self.length_regulator(x, duration_rounded, max_len) 137 | mel_mask = get_mask_from_lengths(mel_len) 138 | 139 | if self.pitch_feature_level == "frame_level": 140 | pitch_prediction, pitch_embedding = self.get_pitch_embedding( 141 | x, pitch_target, mel_mask, p_control 142 | ) 143 | x = x + pitch_embedding 144 | if self.energy_feature_level == "frame_level": 145 | 146 | energy_prediction, energy_embedding = self.get_energy_embedding( 147 | x, energy_target, mel_mask, e_control 148 | ) 149 | x = x + energy_embedding 150 | 151 | return ( 152 | x, 153 | pitch_prediction, 154 | energy_prediction, 155 | log_duration_prediction, 156 | duration_rounded, 157 | mel_len, 158 | mel_mask, 159 | ) 160 | 161 | 162 | class GaussianUpsampling(nn.Module): 163 | def __init__(self): 164 | super(GaussianUpsampling, self).__init__() 165 | 166 | def forward(self, x, durations, range_outputs, max_len): 167 | range_outputs = range_outputs.unsqueeze(-1) 168 | s = torch.sum(durations, dim=-1, keepdim=True) # [B, 1] 169 | e = torch.cumsum(durations, dim=-1).float() # [B, L] 170 | c = (e - 0.5 * durations).unsqueeze(-1) # [B, L, 1] 171 | t = torch.arange(0, torch.max( 172 | s)).unsqueeze(0).unsqueeze(1).to(device) # [1, 1, T] 173 | 174 | # [B, L, T] 175 | range_outputs = 10.0 176 | w_1 = torch.exp(-(range_outputs ** -2) * ((t - c) ** 2)) 177 | 178 | w_2 = torch.sum(torch.exp(-((range_outputs) ** -2) * 179 | ((t - c) ** 2)), dim=1, keepdim=True) + 1e-20 180 | w = w_1 / w_2 # [B, L, T] 181 | # w[w != w] = 0 # n?an_to_num 182 
| w_np = w.detach().cpu().numpy() 183 | 184 | output = torch.matmul(w.transpose( 185 | 1, 2), x) # [B, T, ENC_DIM] 186 | 187 | if max_len is not None: 188 | output = pad(output, max_len) 189 | else: 190 | output = pad(output) 191 | 192 | return output, s, w 193 | 194 | 195 | class LengthRegulator(nn.Module): 196 | """Length Regulator""" 197 | 198 | def __init__(self): 199 | super(LengthRegulator, self).__init__() 200 | 201 | def LR(self, x, duration, max_len): 202 | output = list() 203 | mel_len = list() 204 | 205 | 206 | for batch, expand_target in zip(x, duration): 207 | expanded = self.expand(batch, expand_target) 208 | output.append(expanded) 209 | mel_len.append(expanded.shape[0]) 210 | 211 | if max_len is not None: 212 | output = pad(output, max_len) 213 | else: 214 | output = pad(output) 215 | 216 | return output, torch.LongTensor(mel_len).to(device) 217 | 218 | def expand(self, batch, predicted): 219 | out = list() 220 | 221 | for i, vec in enumerate(batch): 222 | expand_size = predicted[i].item() 223 | out.append(vec.expand(max(int(expand_size), 0), -1)) 224 | out = torch.cat(out, 0) 225 | 226 | return out 227 | 228 | def forward(self, x, duration, max_len): 229 | output, mel_len = self.LR(x, duration, max_len) 230 | return output, mel_len 231 | 232 | 233 | class VariancePredictor(nn.Module): 234 | """Duration, Pitch and Energy Predictor""" 235 | 236 | def __init__(self, model_config): 237 | super(VariancePredictor, self).__init__() 238 | 239 | self.input_size = model_config["transformer"]["encoder_hidden"] 240 | self.filter_size = model_config["variance_predictor"]["filter_size"] 241 | self.kernel = model_config["variance_predictor"]["kernel_size"] 242 | self.conv_output_size = model_config["variance_predictor"]["filter_size"] 243 | self.dropout = model_config["variance_predictor"]["dropout"] 244 | 245 | self.conv_layer = nn.Sequential( 246 | OrderedDict( 247 | [ 248 | ( 249 | "conv1d_1", 250 | Conv( 251 | self.input_size, 252 | self.filter_size, 253 | kernel_size=self.kernel, 254 | padding=(self.kernel - 1) // 2, 255 | ), 256 | ), 257 | ("relu_1", nn.ReLU()), 258 | ("layer_norm_1", nn.LayerNorm(self.filter_size)), 259 | ("dropout_1", nn.Dropout(self.dropout)), 260 | ( 261 | "conv1d_2", 262 | Conv( 263 | self.filter_size, 264 | self.filter_size, 265 | kernel_size=self.kernel, 266 | padding=1, 267 | ), 268 | ), 269 | ("relu_2", nn.ReLU()), 270 | ("layer_norm_2", nn.LayerNorm(self.filter_size)), 271 | ("dropout_2", nn.Dropout(self.dropout)), 272 | ] 273 | ) 274 | ) 275 | 276 | self.linear_layer = nn.Linear(self.conv_output_size, 1) 277 | 278 | def forward(self, encoder_output, mask): 279 | out = self.conv_layer(encoder_output) 280 | out = self.linear_layer(out) 281 | out = out.squeeze(-1) 282 | 283 | if mask is not None: 284 | out = out.masked_fill(mask, 0.0) 285 | 286 | return out 287 | 288 | 289 | class Conv(nn.Module): 290 | """ 291 | Convolution Module 292 | """ 293 | 294 | def __init__( 295 | self, 296 | in_channels, 297 | out_channels, 298 | kernel_size=1, 299 | stride=1, 300 | padding=0, 301 | dilation=1, 302 | bias=True, 303 | w_init="linear", 304 | ): 305 | """ 306 | :param in_channels: dimension of input 307 | :param out_channels: dimension of output 308 | :param kernel_size: size of kernel 309 | :param stride: size of stride 310 | :param padding: size of padding 311 | :param dilation: dilation rate 312 | :param bias: boolean. if True, bias is included. 313 | :param w_init: str. weight inits with xavier initialization. 
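Note: in this implementation the w_init argument is not applied; the convolution keeps nn.Conv1d's default weight initialization and the parameter is retained for interface compatibility only.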
314 | """ 315 | super(Conv, self).__init__() 316 | 317 | self.conv = nn.Conv1d( 318 | in_channels, 319 | out_channels, 320 | kernel_size=kernel_size, 321 | stride=stride, 322 | padding=padding, 323 | dilation=dilation, 324 | bias=bias, 325 | ) 326 | 327 | def forward(self, x): 328 | x = x.contiguous().transpose(1, 2) 329 | x = self.conv(x) 330 | x = x.contiguous().transpose(1, 2) 331 | 332 | return x 333 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class GuidedAttentionLoss(torch.nn.Module): 6 | """Guided attention loss function module. 7 | 8 | This module calculates the guided attention loss described 9 | in `Efficiently Trainable Text-to-Speech System Based 10 | on Deep Convolutional Networks with Guided Attention`_, 11 | which forces the attention to be diagonal. 12 | 13 | .. _`Efficiently Trainable Text-to-Speech System 14 | Based on Deep Convolutional Networks with Guided Attention`: 15 | https://arxiv.org/abs/1710.08969 16 | 17 | """ 18 | 19 | def __init__(self, sigma=0.2, alpha=10.0, reset_always=True): 20 | """Initialize guided attention loss module. 21 | 22 | Args: 23 | sigma (float, optional): Standard deviation to control 24 | how close attention to a diagonal. 25 | alpha (float, optional): Scaling coefficient (lambda). 26 | reset_always (bool, optional): Whether to always reset masks. 27 | 28 | """ 29 | super(GuidedAttentionLoss, self).__init__() 30 | self.sigma = sigma 31 | self.alpha = alpha 32 | self.reset_always = reset_always 33 | self.guided_attn_masks = None 34 | self.masks = None 35 | 36 | def _reset_masks(self): 37 | self.guided_attn_masks = None 38 | self.masks = None 39 | 40 | def forward(self, att_ws, ilens, olens): 41 | """Calculate forward propagation. 42 | 43 | Args: 44 | att_ws (Tensor): Batch of attention weights (B, T_max_out, T_max_in). 45 | ilens (LongTensor): Batch of input lenghts (B,). 46 | olens (LongTensor): Batch of output lenghts (B,). 47 | 48 | Returns: 49 | Tensor: Guided attention loss value. 50 | 51 | """ 52 | if self.guided_attn_masks is None: 53 | self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens).to( 54 | att_ws.device 55 | ) 56 | 57 | if self.masks is None: 58 | self.masks = self._make_masks(ilens, olens).to(att_ws.device) 59 | 60 | losses = self.guided_attn_masks * att_ws 61 | loss = torch.mean(losses.masked_select(self.masks)) 62 | 63 | if self.reset_always: 64 | self._reset_masks() 65 | return self.alpha * loss 66 | 67 | def _make_guided_attention_masks(self, ilens, olens): 68 | n_batches = len(ilens) 69 | max_ilen = max(ilens) 70 | max_olen = max(olens) 71 | guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen)) 72 | for idx, (ilen, olen) in enumerate(zip(ilens, olens)): 73 | guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask( 74 | ilen, olen, self.sigma 75 | ) 76 | return guided_attn_masks 77 | 78 | @staticmethod 79 | def _make_guided_attention_mask(ilen, olen, sigma): 80 | """Make guided attention mask. 
81 | 82 | Examples: 83 | >>> guided_attn_mask =_make_guided_attention(5, 5, 0.4) 84 | >>> guided_attn_mask.shape 85 | torch.Size([5, 5]) 86 | >>> guided_attn_mask 87 | tensor([[0.0000, 0.1175, 0.3935, 0.6753, 0.8647], 88 | [0.1175, 0.0000, 0.1175, 0.3935, 0.6753], 89 | [0.3935, 0.1175, 0.0000, 0.1175, 0.3935], 90 | [0.6753, 0.3935, 0.1175, 0.0000, 0.1175], 91 | [0.8647, 0.6753, 0.3935, 0.1175, 0.0000]]) 92 | >>> guided_attn_mask =_make_guided_attention(3, 6, 0.4) 93 | >>> guided_attn_mask.shape 94 | torch.Size([6, 3]) 95 | >>> guided_attn_mask 96 | tensor([[0.0000, 0.2934, 0.7506], 97 | [0.0831, 0.0831, 0.5422], 98 | [0.2934, 0.0000, 0.2934], 99 | [0.5422, 0.0831, 0.0831], 100 | [0.7506, 0.2934, 0.0000], 101 | [0.8858, 0.5422, 0.0831]]) 102 | 103 | """ 104 | grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen)) 105 | grid_x, grid_y = grid_x.float().to(olen.device), grid_y.float().to(ilen.device) 106 | return 1.0 - torch.exp( 107 | -((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2)) 108 | ) 109 | 110 | @staticmethod 111 | def _make_masks(ilens, olens): 112 | """Make masks indicating non-padded part. 113 | 114 | Args: 115 | ilens (LongTensor or List): Batch of lengths (B,). 116 | olens (LongTensor or List): Batch of lengths (B,). 117 | 118 | Returns: 119 | Tensor: Mask tensor indicating non-padded part. 120 | dtype=torch.uint8 in PyTorch 1.2- 121 | dtype=torch.bool in PyTorch 1.2+ (including 1.2) 122 | 123 | Examples: 124 | >>> ilens, olens = [5, 2], [8, 5] 125 | >>> _make_mask(ilens, olens) 126 | tensor([[[1, 1, 1, 1, 1], 127 | [1, 1, 1, 1, 1], 128 | [1, 1, 1, 1, 1], 129 | [1, 1, 1, 1, 1], 130 | [1, 1, 1, 1, 1], 131 | [1, 1, 1, 1, 1], 132 | [1, 1, 1, 1, 1], 133 | [1, 1, 1, 1, 1]], 134 | [[1, 1, 0, 0, 0], 135 | [1, 1, 0, 0, 0], 136 | [1, 1, 0, 0, 0], 137 | [1, 1, 0, 0, 0], 138 | [1, 1, 0, 0, 0], 139 | [0, 0, 0, 0, 0], 140 | [0, 0, 0, 0, 0], 141 | [0, 0, 0, 0, 0]]], dtype=torch.uint8) 142 | 143 | """ 144 | in_masks = make_non_pad_mask(ilens) # (B, T_in) 145 | out_masks = make_non_pad_mask(olens) # (B, T_out) 146 | return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) # (B, T_out, T_in) 147 | 148 | 149 | class FastSpeech2Loss(nn.Module): 150 | """ FastSpeech2 Loss """ 151 | 152 | def __init__(self, preprocess_config, model_config): 153 | super(FastSpeech2Loss, self).__init__() 154 | self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][ 155 | "feature" 156 | ] 157 | self.energy_feature_level = preprocess_config["preprocessing"]["energy"][ 158 | "feature" 159 | ] 160 | self.mse_loss = nn.MSELoss() 161 | self.mae_loss = nn.L1Loss() 162 | self.attn_loss = GuidedAttentionLoss() 163 | 164 | def forward(self, inputs, predictions): 165 | ( 166 | src_lens, 167 | _, 168 | mel_targets, 169 | mel_lens, 170 | _, 171 | pitch_targets, 172 | energy_targets, 173 | ) = inputs[4:] 174 | ( 175 | mel_predictions, 176 | postnet_mel_predictions, 177 | pitch_predictions, 178 | energy_predictions, 179 | log_duration_predictions, 180 | _, 181 | src_masks, 182 | mel_masks, 183 | _, 184 | _, 185 | attn, 186 | duration_targets 187 | ) = predictions 188 | src_masks = ~src_masks 189 | mel_masks = ~mel_masks 190 | log_duration_targets = torch.log(duration_targets.float() + 1) 191 | mel_targets = mel_targets[:, : mel_masks.shape[1], :] 192 | mel_masks = mel_masks[:, :mel_masks.shape[1]] 193 | 194 | log_duration_targets.requires_grad = False 195 | pitch_targets.requires_grad = False 196 | energy_targets.requires_grad = False 197 | mel_targets.requires_grad = False 198 | 199 | if 
self.pitch_feature_level == "phoneme_level": 200 | pitch_predictions = pitch_predictions.masked_select(src_masks) 201 | pitch_targets = pitch_targets.masked_select(src_masks) 202 | elif self.pitch_feature_level == "frame_level": 203 | pitch_predictions = pitch_predictions.masked_select(mel_masks) 204 | pitch_targets = pitch_targets.masked_select(mel_masks) 205 | 206 | if self.energy_feature_level == "phoneme_level": 207 | energy_predictions = energy_predictions.masked_select(src_masks) 208 | energy_targets = energy_targets.masked_select(src_masks) 209 | if self.energy_feature_level == "frame_level": 210 | energy_predictions = energy_predictions.masked_select(mel_masks) 211 | energy_targets = energy_targets.masked_select(mel_masks) 212 | 213 | log_duration_predictions = log_duration_predictions.masked_select(src_masks) 214 | if log_duration_targets.size() != src_masks.size(): 215 | print(log_duration_targets.size(), src_masks.size()) 216 | log_duration_targets = log_duration_targets[:, :src_masks.size(1)] 217 | log_duration_targets = log_duration_targets.masked_select(src_masks) 218 | 219 | mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1)) 220 | postnet_mel_predictions = postnet_mel_predictions.masked_select( 221 | mel_masks.unsqueeze(-1) 222 | 223 | ) 224 | mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1)) 225 | 226 | mel_loss = self.mae_loss(mel_predictions, mel_targets) 227 | postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets) 228 | 229 | pitch_loss = self.mse_loss(pitch_predictions, pitch_targets) 230 | energy_loss = self.mse_loss(energy_predictions, energy_targets) 231 | duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets.detach()) 232 | 233 | attn_loss = self.attn_loss(attn[0][:, 0], src_lens, mel_lens) 234 | attn_loss += self.attn_loss(attn[1][:, 0], src_lens, mel_lens) 235 | attn_loss += self.attn_loss(attn[2][:, 0], src_lens, mel_lens) 236 | attn_loss += self.attn_loss(attn[3][:, 0], src_lens, mel_lens) 237 | 238 | total_loss = ( 239 | mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss + attn_loss 240 | ) 241 | 242 | return ( 243 | total_loss, 244 | mel_loss, 245 | postnet_mel_loss, 246 | pitch_loss, 247 | energy_loss, 248 | duration_loss, 249 | attn_loss, 250 | ) 251 | 252 | 253 | def make_pad_mask(lengths, xs=None, length_dim=-1): 254 | """Make mask tensor containing indices of padded part. 255 | Args: 256 | lengths (LongTensor or List): Batch of lengths (B,). 257 | xs (Tensor, optional): The reference tensor. 258 | If set, masks will be the same shape as this tensor. 259 | length_dim (int, optional): Dimension indicator of the above tensor. 260 | See the example. 261 | Returns: 262 | Tensor: Mask tensor containing indices of padded part. 263 | dtype=torch.uint8 in PyTorch 1.2- 264 | dtype=torch.bool in PyTorch 1.2+ (including 1.2) 265 | Examples: 266 | With only lengths. 267 | >>> lengths = [5, 3, 2] 268 | >>> make_non_pad_mask(lengths) 269 | masks = [[0, 0, 0, 0 ,0], 270 | [0, 0, 0, 1, 1], 271 | [0, 0, 1, 1, 1]] 272 | With the reference tensor. 
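When the reference tensor xs is given, the returned mask is expanded to the same shape as xs, and length_dim selects which axis of xs the lengths refer to.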
273 | >>> xs = torch.zeros((3, 2, 4)) 274 | >>> make_pad_mask(lengths, xs) 275 | tensor([[[0, 0, 0, 0], 276 | [0, 0, 0, 0]], 277 | [[0, 0, 0, 1], 278 | [0, 0, 0, 1]], 279 | [[0, 0, 1, 1], 280 | [0, 0, 1, 1]]], dtype=torch.uint8) 281 | >>> xs = torch.zeros((3, 2, 6)) 282 | >>> make_pad_mask(lengths, xs) 283 | tensor([[[0, 0, 0, 0, 0, 1], 284 | [0, 0, 0, 0, 0, 1]], 285 | [[0, 0, 0, 1, 1, 1], 286 | [0, 0, 0, 1, 1, 1]], 287 | [[0, 0, 1, 1, 1, 1], 288 | [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) 289 | With the reference tensor and dimension indicator. 290 | >>> xs = torch.zeros((3, 6, 6)) 291 | >>> make_pad_mask(lengths, xs, 1) 292 | tensor([[[0, 0, 0, 0, 0, 0], 293 | [0, 0, 0, 0, 0, 0], 294 | [0, 0, 0, 0, 0, 0], 295 | [0, 0, 0, 0, 0, 0], 296 | [0, 0, 0, 0, 0, 0], 297 | [1, 1, 1, 1, 1, 1]], 298 | [[0, 0, 0, 0, 0, 0], 299 | [0, 0, 0, 0, 0, 0], 300 | [0, 0, 0, 0, 0, 0], 301 | [1, 1, 1, 1, 1, 1], 302 | [1, 1, 1, 1, 1, 1], 303 | [1, 1, 1, 1, 1, 1]], 304 | [[0, 0, 0, 0, 0, 0], 305 | [0, 0, 0, 0, 0, 0], 306 | [1, 1, 1, 1, 1, 1], 307 | [1, 1, 1, 1, 1, 1], 308 | [1, 1, 1, 1, 1, 1], 309 | [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) 310 | >>> make_pad_mask(lengths, xs, 2) 311 | tensor([[[0, 0, 0, 0, 0, 1], 312 | [0, 0, 0, 0, 0, 1], 313 | [0, 0, 0, 0, 0, 1], 314 | [0, 0, 0, 0, 0, 1], 315 | [0, 0, 0, 0, 0, 1], 316 | [0, 0, 0, 0, 0, 1]], 317 | [[0, 0, 0, 1, 1, 1], 318 | [0, 0, 0, 1, 1, 1], 319 | [0, 0, 0, 1, 1, 1], 320 | [0, 0, 0, 1, 1, 1], 321 | [0, 0, 0, 1, 1, 1], 322 | [0, 0, 0, 1, 1, 1]], 323 | [[0, 0, 1, 1, 1, 1], 324 | [0, 0, 1, 1, 1, 1], 325 | [0, 0, 1, 1, 1, 1], 326 | [0, 0, 1, 1, 1, 1], 327 | [0, 0, 1, 1, 1, 1], 328 | [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) 329 | """ 330 | if length_dim == 0: 331 | raise ValueError("length_dim cannot be 0: {}".format(length_dim)) 332 | 333 | if not isinstance(lengths, list): 334 | lengths = lengths.tolist() 335 | bs = int(len(lengths)) 336 | if xs is None: 337 | maxlen = int(max(lengths)) 338 | else: 339 | maxlen = xs.size(length_dim) 340 | 341 | seq_range = torch.arange(0, maxlen, dtype=torch.int64) 342 | seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) 343 | seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) 344 | mask = seq_range_expand >= seq_length_expand 345 | 346 | if xs is not None: 347 | assert xs.size(0) == bs, (xs.size(0), bs) 348 | 349 | if length_dim < 0: 350 | length_dim = xs.dim() + length_dim 351 | # ind = (:, None, ..., None, :, , None, ..., None) 352 | ind = tuple( 353 | slice(None) if i in (0, length_dim) else None for i in range(xs.dim()) 354 | ) 355 | mask = mask[ind].expand_as(xs).to(xs.device) 356 | return mask 357 | 358 | 359 | def make_non_pad_mask(lengths, xs=None, length_dim=-1): 360 | return ~make_pad_mask(lengths, xs, length_dim) 361 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. 
By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 
76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. 
However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 
196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 
256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 
309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 
360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 
421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 
476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year>  <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program>  Copyright (C) <year>  <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>. 675 | --------------------------------------------------------------------------------