├── data └── .gitkeep ├── config ├── .gitkeep ├── hifi │ ├── config_v3.json │ ├── config_v1.json │ └── config_v2.json └── glow │ ├── base.json │ └── base_blank.json ├── logs ├── glow │ └── .gitkeep └── hifi │ └── .gitkeep ├── results └── api │ └── .gitkeep ├── tts_infer ├── __init__.py ├── requirements.txt ├── example_inference.py └── tts.py ├── checkpoints ├── glow │ └── .gitkeep └── hifi │ └── .gitkeep ├── src ├── glow_tts │ ├── hifi │ │ ├── __init__.py │ │ ├── env.py │ │ ├── utils.py │ │ └── models.py │ ├── monotonic_align │ │ ├── pyproject.toml │ │ ├── monotonic_align │ │ │ ├── __init__.py │ │ │ ├── core.pyx │ │ │ └── mas.py │ │ └── setup.py │ ├── t2s_gradio.py │ ├── t2s_fastapi.py │ ├── text │ │ ├── cleaners.py │ │ ├── numbers.py │ │ └── __init__.py │ ├── init.py │ ├── generate_mels.py │ ├── audio_processing.py │ ├── texttospeech.py │ ├── stft.py │ ├── commons.py │ ├── utils.py │ ├── modules.py │ ├── train.py │ ├── data_utils.py │ └── attentions.py └── hifi_gan │ ├── env.py │ ├── utils.py │ ├── inference_e2e.py │ ├── inference.py │ ├── meldataset.py │ └── models.py ├── install.sh ├── scripts ├── inference │ ├── api.sh │ ├── gradio.sh │ ├── infer.sh │ └── advanced_infer.sh ├── data │ ├── duration.sh │ └── resample.sh ├── hifi │ ├── prepare_data.sh │ └── train_hifi.sh └── glow │ ├── prepare_data.sh │ └── train_glow.sh ├── requirements.txt ├── utils ├── data │ ├── duration.py │ └── resample.py ├── inference │ ├── api.py │ ├── run_gradio.py │ ├── advanced_tts.py │ └── tts.py ├── hifi │ └── prepare_iitm_data_hifi.py └── glow │ ├── prepare_iitm_data_glow.py │ └── prepare_iitm_data_glow_en.py ├── LICENSE.md ├── setup.py ├── .gitignore └── README.md /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /logs/glow/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /logs/hifi/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /results/api/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tts_infer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoints/glow/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoints/hifi/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/glow_tts/hifi/__init__.py: -------------------------------------------------------------------------------- 1 | from .env import AttrDict 2 | from .models import Generator 3 | 4 | if __name__ == "__main__": 5 | pass 6 | -------------------------------------------------------------------------------- /tts_infer/requirements.txt: 
-------------------------------------------------------------------------------- 1 | # will be installed with main setup.py, no need to reinstall 2 | 3 | ai4bharat-transliteration==0.5.0.3 4 | numpy==1.19.5 5 | pandas 6 | pydload -------------------------------------------------------------------------------- /src/glow_tts/monotonic_align/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "wheel", 4 | "setuptools", 5 | "cython>=0.24.0", 6 | "numpy -------------------------------------------------------------------------------- /utils/data/duration.py: -------------------------------------------------------------------------------- 1 | # python duration.py /src/folder/path 2 | 3 | import soundfile as sf 4 | import sys 5 | import os 6 | from glob import glob 7 | from joblib import Parallel, delayed 8 | from tqdm import tqdm 9 | 10 | 11 | def get_duration(fpath): 12 | w = sf.SoundFile(fpath) 13 | sr = w.samplerate 14 | assert 22050 == sr, "Sample rate is not 22050" 15 | return len(w) / sr 16 | 17 | 18 | def main(folder, ext="wav"): 19 | file_list = glob(folder + "/**/*." + ext, recursive=True) 20 | print(f"\n\tTotal number of wav files {len(file_list)}") 21 | duration_list = Parallel(n_jobs=1)( 22 | delayed(get_duration)(i) for i in tqdm(file_list) 23 | ) 24 | print( 25 | f"\n\tMin Duration {min(duration_list):.2f} Max Duration {max(duration_list):.2f} in secs" 26 | ) 27 | print(f"\n\tTotal Duration {sum(duration_list)/3600:.2f} in hours") 28 | 29 | 30 | if __name__ == "__main__": 31 | folder = sys.argv[1] 32 | folder = os.path.abspath(folder) 33 | main(folder) 34 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Open-Speech-EkStep 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /src/glow_tts/monotonic_align/monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | cimport numpy as np 3 | cimport cython 4 | from cython.parallel import prange 5 | 6 | 7 | @cython.boundscheck(False) 8 | @cython.wraparound(False) 9 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: 10 | cdef int x 11 | cdef int y 12 | cdef float v_prev 13 | cdef float v_cur 14 | cdef float tmp 15 | cdef int index = t_x - 1 16 | 17 | for y in range(t_y): 18 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 19 | if x == y: 20 | v_cur = max_neg_val 21 | else: 22 | v_cur = value[x, y-1] 23 | if x == 0: 24 | if y == 0: 25 | v_prev = 0. 26 | else: 27 | v_prev = max_neg_val 28 | else: 29 | v_prev = value[x-1, y-1] 30 | value[x, y] = max(v_cur, v_prev) + value[x, y] 31 | 32 | for y in range(t_y - 1, -1, -1): 33 | path[index, y] = 1 34 | if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): 35 | index = index - 1 36 | 37 | 38 | @cython.boundscheck(False) 39 | @cython.wraparound(False) 40 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: 41 | cdef int b = values.shape[0] 42 | 43 | cdef int i 44 | for i in prange(b, nogil=True): 45 | maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) 46 | -------------------------------------------------------------------------------- /config/glow/base.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "use_cuda": true, 4 | "log_interval": 20, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 1e0, 8 | "betas": [0.9, 0.98], 9 | "eps": 1e-9, 10 | "warmup_steps": 4000, 11 | "scheduler": "noam", 12 | "batch_size": 16, 13 | "ddi": true, 14 | "fp16_run": true, 15 | "save_epoch": 1 16 | }, 17 | "data": { 18 | "load_mel_from_disk": false, 19 | "training_files":"../data/training/train.txt", 20 | "validation_files":"../data/training/valid.txt", 21 | "chars":"", 22 | "punc":"", 23 | "text_cleaners":["basic_indic_cleaners"], 24 | "max_wav_value": 32768.0, 25 | "sampling_rate": 22050, 26 | "filter_length": 1024, 27 | "hop_length": 256, 28 | "win_length": 1024, 29 | "n_mel_channels": 80, 30 | "mel_fmin": 80.0, 31 | "mel_fmax": 7600.0, 32 | "add_noise": true 33 | }, 34 | "model": { 35 | "hidden_channels": 192, 36 | "filter_channels": 768, 37 | "filter_channels_dp": 256, 38 | "kernel_size": 3, 39 | "p_dropout": 0.1, 40 | "n_blocks_dec": 12, 41 | "n_layers_enc": 6, 42 | "n_heads": 2, 43 | "p_dropout_dec": 0.05, 44 | "dilation_rate": 1, 45 | "kernel_size_dec": 5, 46 | "n_block_layers": 4, 47 | "n_sqz": 2, 48 | "prenet": true, 49 | "mean_only": true, 50 | "hidden_channels_enc": 192, 51 | "hidden_channels_dec": 192, 52 | "window_size": 4 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /config/glow/base_blank.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "use_cuda": true, 4 | "log_interval": 20, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 1e0, 8 | "betas": [0.9, 0.98], 9 | "eps": 1e-9, 10 | "warmup_steps": 4000, 11 | "scheduler": "noam", 12 | "batch_size": 16, 13 | "ddi": true, 14 | "fp16_run": true, 15 | "save_epoch": 1 16 | }, 17 | "data": { 18 | 
"load_mel_from_disk": false, 19 | "training_files":"../data/training/train.txt", 20 | "validation_files":"../data/training/valid.txt", 21 | "chars":"", 22 | "punc":"", 23 | "text_cleaners":["basic_indic_cleaners"], 24 | "max_wav_value": 32768.0, 25 | "sampling_rate": 22050, 26 | "filter_length": 1024, 27 | "hop_length": 256, 28 | "win_length": 1024, 29 | "n_mel_channels": 80, 30 | "mel_fmin": 80.0, 31 | "mel_fmax": 7600.0, 32 | "add_noise": true, 33 | "add_blank": true 34 | }, 35 | "model": { 36 | "hidden_channels": 192, 37 | "filter_channels": 768, 38 | "filter_channels_dp": 256, 39 | "kernel_size": 3, 40 | "p_dropout": 0.1, 41 | "n_blocks_dec": 12, 42 | "n_layers_enc": 6, 43 | "n_heads": 2, 44 | "p_dropout_dec": 0.05, 45 | "dilation_rate": 1, 46 | "kernel_size_dec": 5, 47 | "n_block_layers": 4, 48 | "n_sqz": 2, 49 | "prenet": true, 50 | "mean_only": true, 51 | "hidden_channels_enc": 192, 52 | "hidden_channels_dec": 192, 53 | "window_size": 4 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/hifi_gan/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import matplotlib 4 | import torch 5 | from torch.nn.utils import weight_norm 6 | 7 | matplotlib.use("Agg") 8 | import matplotlib.pylab as plt 9 | 10 | 11 | def plot_spectrogram(spectrogram): 12 | fig, ax = plt.subplots(figsize=(10, 2)) 13 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 14 | plt.colorbar(im, ax=ax) 15 | 16 | fig.canvas.draw() 17 | plt.close() 18 | 19 | return fig 20 | 21 | 22 | def init_weights(m, mean=0.0, std=0.01): 23 | classname = m.__class__.__name__ 24 | if classname.find("Conv") != -1: 25 | m.weight.data.normal_(mean, std) 26 | 27 | 28 | def apply_weight_norm(m): 29 | classname = m.__class__.__name__ 30 | if classname.find("Conv") != -1: 31 | weight_norm(m) 32 | 33 | 34 | def get_padding(kernel_size, dilation=1): 35 | return int((kernel_size * dilation - dilation) / 2) 36 | 37 | 38 | def load_checkpoint(filepath, device): 39 | assert os.path.isfile(filepath) 40 | print("Loading '{}'".format(filepath)) 41 | checkpoint_dict = torch.load(filepath, map_location=device) 42 | print("Complete.") 43 | return checkpoint_dict 44 | 45 | 46 | def save_checkpoint(filepath, obj): 47 | print("Saving checkpoint to {}".format(filepath)) 48 | torch.save(obj, filepath) 49 | print("Complete.") 50 | 51 | 52 | def scan_checkpoint(cp_dir, prefix): 53 | pattern = os.path.join(cp_dir, prefix + "????????") 54 | cp_list = glob.glob(pattern) 55 | if len(cp_list) == 0: 56 | return None 57 | return sorted(cp_list)[-1] 58 | -------------------------------------------------------------------------------- /src/glow_tts/hifi/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import matplotlib 4 | import torch 5 | from torch.nn.utils import weight_norm 6 | 7 | matplotlib.use("Agg") 8 | import matplotlib.pylab as plt 9 | 10 | 11 | def plot_spectrogram(spectrogram): 12 | fig, ax = plt.subplots(figsize=(10, 2)) 13 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 14 | plt.colorbar(im, ax=ax) 15 | 16 | fig.canvas.draw() 17 | plt.close() 18 | 19 | return fig 20 | 21 | 22 | def init_weights(m, mean=0.0, std=0.01): 23 | classname = m.__class__.__name__ 24 | if classname.find("Conv") != -1: 25 | m.weight.data.normal_(mean, std) 26 | 27 | 28 | def apply_weight_norm(m): 29 | classname = 
m.__class__.__name__ 30 | if classname.find("Conv") != -1: 31 | weight_norm(m) 32 | 33 | 34 | def get_padding(kernel_size, dilation=1): 35 | return int((kernel_size * dilation - dilation) / 2) 36 | 37 | 38 | def load_checkpoint(filepath, device): 39 | assert os.path.isfile(filepath) 40 | print("Loading '{}'".format(filepath)) 41 | checkpoint_dict = torch.load(filepath, map_location=device) 42 | print("Complete.") 43 | return checkpoint_dict 44 | 45 | 46 | def save_checkpoint(filepath, obj): 47 | print("Saving checkpoint to {}".format(filepath)) 48 | torch.save(obj, filepath) 49 | print("Complete.") 50 | 51 | 52 | def scan_checkpoint(cp_dir, prefix): 53 | pattern = os.path.join(cp_dir, prefix + "????????") 54 | cp_list = glob.glob(pattern) 55 | if len(cp_list) == 0: 56 | return None 57 | return sorted(cp_list)[-1] 58 | -------------------------------------------------------------------------------- /src/glow_tts/monotonic_align/monotonic_align/mas.py: -------------------------------------------------------------------------------- 1 | from typing import overload 2 | import numpy as np 3 | import torch 4 | from monotonic_align.core import maximum_path_c 5 | 6 | 7 | def mask_from_len(lens: torch.Tensor, max_len=None): 8 | """ 9 | Make a `mask` from lens. 10 | 11 | :param inputs: (B, T, D) 12 | :param lens: (B) 13 | 14 | :return: 15 | `mask`: (B, T) 16 | """ 17 | if max_len is None: 18 | max_len = lens.max() 19 | index = torch.arange(max_len).to(lens).view(1, -1) 20 | return index < lens.unsqueeze(1) # (B, T) 21 | 22 | 23 | def mask_from_lens( 24 | similarity: torch.Tensor, 25 | symbol_lens: torch.Tensor, 26 | mel_lens: torch.Tensor, 27 | ): 28 | """ 29 | :param similarity: (B, S, T) 30 | :param symbol_lens: (B,) 31 | :param mel_lens: (B,) 32 | """ 33 | _, S, T = similarity.size() 34 | mask_S = mask_from_len(symbol_lens, S) 35 | mask_T = mask_from_len(mel_lens, T) 36 | mask_ST = mask_S.unsqueeze(2) * mask_T.unsqueeze(1) 37 | return mask_ST.to(similarity) 38 | 39 | 40 | def maximum_path(value, mask=None): 41 | """Cython optimised version. 
42 | value: [b, t_x, t_y] 43 | mask: [b, t_x, t_y] 44 | """ 45 | if mask is None: 46 | mask = torch.zeros_like(value) 47 | 48 | value = value * mask 49 | device = value.device 50 | dtype = value.dtype 51 | value = value.data.cpu().numpy().astype(np.float32) 52 | path = np.zeros_like(value).astype(np.int32) 53 | mask = mask.data.cpu().numpy() 54 | t_x_max = mask.sum(1)[:, 0].astype(np.int32) 55 | t_y_max = mask.sum(2)[:, 0].astype(np.int32) 56 | maximum_path_c(path, value, t_x_max, t_y_max) 57 | return torch.from_numpy(path).to(device=device, dtype=dtype) 58 | -------------------------------------------------------------------------------- /src/glow_tts/t2s_fastapi.py: -------------------------------------------------------------------------------- 1 | from starlette.responses import StreamingResponse 2 | from texttospeech import MelToWav, TextToMel 3 | from typing import Optional 4 | from pydantic import BaseModel 5 | from fastapi import FastAPI, HTTPException 6 | import uvicorn 7 | import base64 8 | 9 | app = FastAPI() 10 | 11 | 12 | class TextJson(BaseModel): 13 | text: str 14 | lang: Optional[str] = "hi" 15 | gender: Optional[str] = "male" 16 | 17 | 18 | glow_hi_male = TextToMel(glow_model_dir="", device="") 19 | glow_hi_female = TextToMel(glow_model_dir="", device="") 20 | hifi_hi = MelToWav(hifi_model_dir="", device="") 21 | 22 | 23 | available_choice = { 24 | "hi_male": [glow_hi_male, hifi_hi], 25 | "hi_female": [glow_hi_female, hifi_hi], 26 | } 27 | 28 | 29 | @app.post("/TTS/") 30 | async def tts(input: TextJson): 31 | text = input.text 32 | lang = input.lang 33 | gender = input.gender 34 | 35 | choice = lang + "_" + gender 36 | if choice in available_choice.keys(): 37 | t2s = available_choice[choice] 38 | else: 39 | raise HTTPException( 40 | status_code=400, detail={"error": "Requested model not found"} 41 | ) 42 | 43 | if text: 44 | mel = t2s[0].generate_mel(text) 45 | data, sr = t2s[1].generate_wav(mel) 46 | t2s[1].save_audio("out.wav", data, sr) 47 | else: 48 | raise HTTPException(status_code=400, detail={"error": "No text"}) 49 | 50 | ## to return output as a file 51 | # audio = open('out.wav', mode='rb') 52 | # return StreamingResponse(audio, media_type="audio/wav") 53 | 54 | with open("out.wav", "rb") as audio_file: 55 | encoded_bytes = base64.b64encode(audio_file.read()) 56 | encoded_string = encoded_bytes.decode() 57 | return {"encoding": "base64", "data": encoded_string, "sr": sr} 58 | 59 | 60 | if __name__ == "__main__": 61 | uvicorn.run( 62 | "t2s_fastapi:app", host="127.0.0.1", port=5000, log_level="info", reload=True 63 | ) 64 | -------------------------------------------------------------------------------- /src/glow_tts/text/cleaners.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from unidecode import unidecode 4 | from .numbers import normalize_numbers 5 | 6 | 7 | 8 | 9 | # Regular expression matching whitespace: 10 | _whitespace_re = re.compile(r"\s+") 11 | 12 | def lowercase(text): 13 | return text.lower() 14 | 15 | def collapse_whitespace(text): 16 | return re.sub(_whitespace_re, " ", text) 17 | 18 | def basic_indic_cleaners(text): 19 | """Basic pipeline that collapses whitespace without transliteration.""" 20 | text = collapse_whitespace(text) 21 | return text 22 | 23 | 24 | def english_cleaner(text): 25 | text = text.lower().replace('‘','\'').replace('’','\'') 26 | return text 27 | 28 | 29 | def lowercase(text): 30 | return text.lower() 31 | 32 | def convert_to_ascii(text): 33 | return unidecode(text) 34 
| 35 | def expand_numbers(text): 36 | return normalize_numbers(text) 37 | 38 | def expand_abbreviations(text): 39 | for regex, replacement in _abbreviations: 40 | text = re.sub(regex, replacement, text) 41 | return text 42 | 43 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 44 | ('mrs', 'missus'), 45 | ('mr', 'mister'), 46 | ('dr', 'doctor'), 47 | ('st', 'saint'), 48 | ('co', 'company'), 49 | ('jr', 'junior'), 50 | ('maj', 'major'), 51 | ('gen', 'general'), 52 | ('drs', 'doctors'), 53 | ('rev', 'reverend'), 54 | ('lt', 'lieutenant'), 55 | ('hon', 'honorable'), 56 | ('sgt', 'sergeant'), 57 | ('capt', 'captain'), 58 | ('esq', 'esquire'), 59 | ('ltd', 'limited'), 60 | ('col', 'colonel'), 61 | ('ft', 'fort'), 62 | ('pvt', 'private'), 63 | ('rs', 'Rupees') 64 | ]] 65 | 66 | 67 | 68 | 69 | 70 | 71 | def english_cleaners(text): 72 | '''Pipeline for English text, including number and abbreviation expansion.''' 73 | text = convert_to_ascii(text) 74 | text = lowercase(text) 75 | text = expand_numbers(text) 76 | text = expand_abbreviations(text) 77 | text = collapse_whitespace(text) 78 | return text 79 | -------------------------------------------------------------------------------- /utils/inference/api.py: -------------------------------------------------------------------------------- 1 | from starlette.responses import StreamingResponse 2 | from tts import MelToWav, TextToMel 3 | from advanced_tts import load_all_models, run_tts_paragraph 4 | from typing import Optional 5 | from pydantic import BaseModel 6 | from fastapi import FastAPI, HTTPException 7 | import uvicorn 8 | import base64 9 | import argparse 10 | import json 11 | import time 12 | from argparse import Namespace 13 | 14 | app = FastAPI() 15 | 16 | 17 | class TextJson(BaseModel): 18 | text: str 19 | lang: Optional[str] = "hi" 20 | noise_scale: Optional[float]=0.667 21 | length_scale: Optional[float]=1.0 22 | transliteration: Optional[int]=1 23 | number_conversion: Optional[int]=1 24 | split_sentences: Optional[int]=1 25 | 26 | 27 | 28 | 29 | @app.post("/TTS/") 30 | async def tts(input: TextJson): 31 | text = input.text 32 | lang = input.lang 33 | 34 | args = Namespace(**input.dict()) 35 | 36 | args.wav = '../../results/api/'+str(int(time.time())) + '.wav' 37 | 38 | if text: 39 | sr, audio = run_tts_paragraph(args) 40 | else: 41 | raise HTTPException(status_code=400, detail={"error": "No text"}) 42 | 43 | ## to return outpur as a file 44 | audio = open(args.wav, mode='rb') 45 | return StreamingResponse(audio, media_type="audio/wav") 46 | 47 | # with open(args.wav, "rb") as audio_file: 48 | # encoded_bytes = base64.b64encode(audio_file.read()) 49 | # encoded_string = encoded_bytes.decode() 50 | # return {"encoding": "base64", "data": encoded_string, "sr": sr} 51 | 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("-a", "--acoustic", required=True, type=str) 56 | parser.add_argument("-v", "--vocoder", required=True, type=str) 57 | parser.add_argument("-d", "--device", type=str, default="cpu") 58 | parser.add_argument("-L", "--lang", type=str, required=True) 59 | 60 | args = parser.parse_args() 61 | 62 | load_all_models(args) 63 | 64 | uvicorn.run( 65 | "api:app", host="0.0.0.0", port=6006, log_level="debug" 66 | ) 67 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with 
open("README.md", "r") as f: 4 | long_description = f.read() 5 | 6 | setup( 7 | name="vakyansh-tts", 8 | version="0.0.5", 9 | description="Text to speech for Indic languages", 10 | long_description=long_description, 11 | long_description_content_type="text/markdown", 12 | url="https://github.com/Open-Speech-EkStep/vakyansh-tts.git", 13 | keywords="nlp, tts, Indic languages, deep learning, text to speech", 14 | # package_dir={'': 'src'}, 15 | # packages=find_packages(where='src'), 16 | packages=["tts_infer"], 17 | python_requires=">=3.7, <4", 18 | install_requires=[ 19 | "Cython==0.29.24", 20 | "layers==0.1.5", 21 | "librosa==0.8.1", 22 | "matplotlib==3.3.4", 23 | "numpy==1.20.2", 24 | "scipy==1.5.4", 25 | "tensorboardX==2.4", 26 | "tensorboard==2.7.0", 27 | "tqdm==4.62.3", 28 | "fastapi==0.70.0", 29 | "uvicorn==0.15.0", 30 | "gradio==2.5.2", 31 | "wavio==0.0.4", 32 | "pydload==1.0.9", 33 | "mosestokenizer==1.2.1", 34 | "indic-nlp-library==0.81" 35 | ], 36 | classifiers=[ 37 | # How mature is this project? Common values are 38 | # 3 - Alpha 39 | # 4 - Beta 40 | # 5 - Production/Stable 41 | "Development Status :: 3 - Alpha", 42 | # Indicate who your project is intended for 43 | "Intended Audience :: Developers", 44 | "Intended Audience :: Education", 45 | "Intended Audience :: Science/Research", 46 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 47 | "Topic :: Text Processing :: Linguistic", 48 | # Pick your license as you wish (should match "license" above) 49 | "License :: OSI Approved :: MIT License", 50 | # Specify the Python versions you support here. In particular, ensure 51 | # that you indicate whether you support Python 2, Python 3 or both. 52 | "Programming Language :: Python :: 3.7", 53 | ], 54 | include_package_data=True, 55 | ) 56 | -------------------------------------------------------------------------------- /utils/hifi/prepare_iitm_data_hifi.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import random 4 | import sys 5 | import os 6 | import argparse 7 | 8 | 9 | 10 | 11 | def process_data(args): 12 | 13 | path = args.input_path 14 | valid_files = args.valid_files 15 | test_files = args.test_files 16 | dest_path = args.dest_path 17 | 18 | list_paths = path.split(',') 19 | 20 | valid_set = [] 21 | training_set = [] 22 | test_set = [] 23 | 24 | for local_path in list_paths: 25 | files = glob.glob(local_path+'/*.wav') 26 | print(f"Total files: {len(files)}") 27 | 28 | valid_set_local = random.sample(files, valid_files) 29 | 30 | test_set_local = random.sample(valid_set_local, test_files) 31 | valid_set.extend(list(set(valid_set_local) - set(test_set_local))) 32 | test_set.extend(test_set_local) 33 | 34 | print(len(valid_set_local)) 35 | 36 | training_set_local = set(files) - set(valid_set_local) 37 | print(len(training_set_local)) 38 | training_set.extend(training_set_local) 39 | 40 | 41 | valid_set = random.sample(valid_set, len(valid_set)) 42 | test_set = random.sample(test_set, len(test_set)) 43 | training_set = random.sample(training_set, len(training_set)) 44 | 45 | with open(os.path.join(dest_path , 'valid.txt'), mode = 'w+') as file: 46 | file.write("\n".join(list(valid_set))) 47 | 48 | with open(os.path.join(dest_path , 'train.txt'), mode = 'w+') as file: 49 | file.write("\n".join(list(training_set))) 50 | 51 | with open(os.path.join(dest_path , 'test.txt'), mode = 'w+') as file: 52 | file.write("\n".join(list(test_set))) 53 | 54 | 55 | if __name__ == "__main__": 56 | parser = 
argparse.ArgumentParser() 57 | parser.add_argument('-i','--input-path',type=str,help='path to input wav files') 58 | parser.add_argument('-v','--valid-files',type=int,help='number of valid files') 59 | parser.add_argument('-t','--test-files',type=int,help='number of test files') 60 | parser.add_argument('-d','--dest-path',type=str,help='destination path to output filelists') 61 | 62 | args = parser.parse_args() 63 | 64 | process_data(args) -------------------------------------------------------------------------------- /utils/data/resample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import librosa 3 | import numpy as np 4 | import os 5 | import scipy 6 | import scipy.io.wavfile 7 | import sys 8 | 9 | from glob import glob 10 | from tqdm import tqdm 11 | from joblib import Parallel, delayed 12 | 13 | 14 | def check_directories(dir_input, dir_output): 15 | if not os.path.exists(dir_input): 16 | sys.exit("Error: Input directory does not exist: {}".format(dir_input)) 17 | if not os.path.exists(dir_output): 18 | sys.exit("Error: Output directory does not exist: {}".format(dir_output)) 19 | abs_a = os.path.abspath(dir_input) 20 | abs_b = os.path.abspath(dir_output) 21 | if abs_a == abs_b: 22 | sys.exit("Error: Paths are the same: {}".format(abs_a)) 23 | 24 | 25 | def resample_file(input_filename, output_filename, sample_rate): 26 | mono = ( 27 | True # librosa converts signal to mono by default, so I'm just surfacing this 28 | ) 29 | audio, existing_rate = librosa.load(input_filename, sr=sample_rate, mono=mono) 30 | audio /= 1.414 # Scale to [-1.0, 1.0] 31 | audio *= 32767 # Scale to int16 32 | audio = audio.astype(np.int16) 33 | scipy.io.wavfile.write(output_filename, sample_rate, audio) 34 | 35 | 36 | def downsample_wav_files(input_dir, output_dir, output_sample_rate): 37 | check_directories(input_dir, output_dir) 38 | inp_wav_paths = glob(input_dir + "/*.wav") 39 | out_wav_paths = [ 40 | os.path.join(output_dir, os.path.basename(p)) for p in inp_wav_paths 41 | ] 42 | _ = Parallel(n_jobs=-1)( 43 | delayed(resample_file)(i, o, output_sample_rate) 44 | for i, o in tqdm(zip(inp_wav_paths, out_wav_paths)) 45 | ) 46 | 47 | 48 | def parse_args(): 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--input_dir", "-i", type=str, required=True) 51 | parser.add_argument("--output_dir", "-o", type=str, required=True) 52 | parser.add_argument("--output_sample_rate", "-s", type=int, required=True) 53 | return parser.parse_args() 54 | 55 | 56 | if __name__ == "__main__": 57 | args = parse_args() 58 | downsample_wav_files(args.input_dir, args.output_dir, args.output_sample_rate) 59 | print(f"\n\tCompleted") 60 | -------------------------------------------------------------------------------- /src/glow_tts/text/numbers.py: -------------------------------------------------------------------------------- 1 | import inflect 2 | import re 3 | 4 | 5 | _inflect = inflect.engine() 6 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 7 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 8 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 9 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 10 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 11 | _number_re = re.compile(r'[0-9]+') 12 | 13 | 14 | def _remove_commas(m): 15 | return m.group(1).replace(',', '') 16 | 17 | 18 | def _expand_decimal_point(m): 19 | return m.group(1).replace('.', ' point ') 20 | 21 | 22 | def _expand_dollars(m): 23 | match = m.group(1) 24 | parts = 
match.split('.') 25 | if len(parts) > 2: 26 | return match + ' dollars' # Unexpected format 27 | dollars = int(parts[0]) if parts[0] else 0 28 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 29 | if dollars and cents: 30 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 31 | cent_unit = 'cent' if cents == 1 else 'cents' 32 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 33 | elif dollars: 34 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 35 | return '%s %s' % (dollars, dollar_unit) 36 | elif cents: 37 | cent_unit = 'cent' if cents == 1 else 'cents' 38 | return '%s %s' % (cents, cent_unit) 39 | else: 40 | return 'zero dollars' 41 | 42 | 43 | def _expand_ordinal(m): 44 | return _inflect.number_to_words(m.group(0)) 45 | 46 | 47 | def _expand_number(m): 48 | num = int(m.group(0)) 49 | if num > 1000 and num < 3000: 50 | if num == 2000: 51 | return 'two thousand' 52 | elif num > 2000 and num < 2010: 53 | return 'two thousand ' + _inflect.number_to_words(num % 100) 54 | elif num % 100 == 0: 55 | return _inflect.number_to_words(num // 100) + ' hundred' 56 | else: 57 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 58 | else: 59 | return _inflect.number_to_words(num, andword='') 60 | 61 | 62 | def normalize_numbers(text): 63 | text = re.sub(_comma_number_re, _remove_commas, text) 64 | text = re.sub(_pounds_re, r'\1 pounds', text) 65 | text = re.sub(_dollars_re, _expand_dollars, text) 66 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 67 | text = re.sub(_ordinal_re, _expand_ordinal, text) 68 | text = re.sub(_number_re, _expand_number, text) 69 | return text -------------------------------------------------------------------------------- /src/glow_tts/init.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import math 5 | import torch 6 | from torch import nn, optim 7 | from torch.nn import functional as F 8 | from torch.utils.data import DataLoader 9 | 10 | from data_utils import TextMelLoader, TextMelCollate 11 | import models 12 | import commons 13 | import utils 14 | 15 | 16 | class FlowGenerator_DDI(models.FlowGenerator): 17 | """A helper for Data-dependent Initialization""" 18 | 19 | def __init__(self, *args, **kwargs): 20 | super().__init__(*args, **kwargs) 21 | for f in self.decoder.flows: 22 | if getattr(f, "set_ddi", False): 23 | f.set_ddi(True) 24 | 25 | 26 | def main(): 27 | hps = utils.get_hparams() 28 | logger = utils.get_logger(hps.log_dir) 29 | logger.info(hps) 30 | utils.check_git_hash(hps.log_dir) 31 | 32 | torch.manual_seed(hps.train.seed) 33 | 34 | train_dataset = TextMelLoader(hps.data.training_files, hps.data) 35 | collate_fn = TextMelCollate(1) 36 | train_loader = DataLoader( 37 | train_dataset, 38 | num_workers=8, 39 | shuffle=True, 40 | batch_size=hps.train.batch_size, 41 | pin_memory=True, 42 | drop_last=True, 43 | collate_fn=collate_fn, 44 | ) 45 | symbols = hps.data.punc + hps.data.chars 46 | generator = FlowGenerator_DDI( 47 | len(symbols) + getattr(hps.data, "add_blank", False), 48 | out_channels=hps.data.n_mel_channels, 49 | **hps.model 50 | ).cuda() 51 | optimizer_g = commons.Adam( 52 | generator.parameters(), 53 | scheduler=hps.train.scheduler, 54 | dim_model=hps.model.hidden_channels, 55 | warmup_steps=hps.train.warmup_steps, 56 | lr=hps.train.learning_rate, 57 | betas=hps.train.betas, 58 | eps=hps.train.eps, 59 | ) 60 | 61 | generator.train() 62 | for batch_idx, (x, 
x_lengths, y, y_lengths) in enumerate(train_loader): 63 | x, x_lengths = x.cuda(), x_lengths.cuda() 64 | y, y_lengths = y.cuda(), y_lengths.cuda() 65 | 66 | _ = generator(x, x_lengths, y, y_lengths, gen=False) 67 | break 68 | 69 | utils.save_checkpoint( 70 | generator, 71 | optimizer_g, 72 | hps.train.learning_rate, 73 | 0, 74 | os.path.join(hps.model_dir, "ddi_G.pth"), 75 | ) 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /utils/inference/run_gradio.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import argparse 3 | import numpy as np 4 | from argparse import Namespace 5 | from advanced_tts import load_all_models, run_tts_paragraph 6 | 7 | 8 | def hit_tts(textbox, slider_noise_scale, slider_length_sclae, choice_transliteration, choice_number_conversion, choice_split_sentences): 9 | inputs_to_gradio = {'text' : textbox, 10 | 'noise_scale': slider_noise_scale, 11 | 'length_scale': slider_length_sclae, 12 | 'transliteration' : 1 if choice_transliteration else 0, 13 | 'number_conversion' : 1 if choice_number_conversion else 0, 14 | 'split_sentences' : 1 if choice_split_sentences else 0 15 | } 16 | 17 | args = Namespace(**inputs_to_gradio) 18 | args.wav = None 19 | args.lang = lang 20 | 21 | if args.text: 22 | sr, audio = run_tts_paragraph(args) 23 | return (sr, audio) 24 | 25 | 26 | if __name__ == "__main__": 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("-a", "--acoustic", required=True, type=str) 29 | parser.add_argument("-v", "--vocoder", required=True, type=str) 30 | parser.add_argument("-d", "--device", type=str, default="cpu") 31 | parser.add_argument("-L", "--lang", type=str, required=True) 32 | 33 | global lang 34 | 35 | args = parser.parse_args() 36 | lang = args.lang 37 | load_all_models(args) 38 | 39 | textbox = gr.inputs.Textbox(placeholder="Enter Text to run", default="", label="TTS") 40 | slider_noise_scale = gr.inputs.Slider(minimum=0, maximum=1.0, step=0.001, default=0.667, label='Enter Noise Scale') 41 | slider_length_sclae = gr.inputs.Slider(minimum=0, maximum=2.0, step=0.1, default=1.0, label='Enter Slider Scale') 42 | 43 | choice_transliteration = gr.inputs.Checkbox(default=True, label="Transliteration") 44 | choice_number_conversion = gr.inputs.Checkbox(default=True, label="Number Conversion") 45 | choice_split_sentences = gr.inputs.Checkbox(default=True, label="Split Sentences") 46 | 47 | 48 | 49 | op = gr.outputs.Audio(type="numpy", label=None) 50 | 51 | inputs_to_gradio = [textbox, slider_noise_scale, slider_length_sclae, choice_transliteration, choice_number_conversion, choice_split_sentences] 52 | iface = gr.Interface(fn=hit_tts, inputs=inputs_to_gradio, outputs=op, theme='huggingface', title='Run TTS example') 53 | iface.launch(share=True, enable_queue=True) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | .DS_Store 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 
32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | .idea/ 133 | -------------------------------------------------------------------------------- /src/glow_tts/generate_mels.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import commons 5 | 6 | import models 7 | import utils 8 | from argparse import ArgumentParser 9 | from tqdm import tqdm 10 | from text import text_to_sequence 11 | 12 | if __name__ == "__main__": 13 | parser = ArgumentParser() 14 | parser.add_argument("-m", "--model_dir", required=True, type=str) 15 | parser.add_argument("-s", "--mels_dir", required=True, type=str) 16 | args = parser.parse_args() 17 | MODEL_DIR = args.model_dir # path to model dir 18 | SAVE_MELS_DIR = args.mels_dir # path to save generated mels 19 | 20 | if not os.path.exists(SAVE_MELS_DIR): 21 | os.makedirs(SAVE_MELS_DIR) 22 | 23 | hps = utils.get_hparams_from_dir(MODEL_DIR) 24 | symbols = list(hps.data.punc) + list(hps.data.chars) 25 | checkpoint_path = utils.latest_checkpoint_path(MODEL_DIR) 26 | cleaner = hps.data.text_cleaners 27 | 28 | model = models.FlowGenerator( 29 | len(symbols) + getattr(hps.data, "add_blank", False), 30 | out_channels=hps.data.n_mel_channels, 31 | **hps.model 32 | ).to("cuda") 33 | 34 | utils.load_checkpoint(checkpoint_path, model) 35 | model.decoder.store_inverse() # do not calcuate jacobians for fast decoding 36 | _ = model.eval() 37 | 38 | def get_mel(text, fpath): 39 | if getattr(hps.data, "add_blank", False): 40 | text_norm = text_to_sequence(text, symbols, cleaner) 41 | text_norm 
= commons.intersperse(text_norm, len(symbols)) 42 | else: # If not using "add_blank" option during training, adding spaces at the beginning and the end of utterance improves quality 43 | text = " " + text.strip() + " " 44 | text_norm = text_to_sequence(text, symbols, cleaner) 45 | 46 | sequence = np.array(text_norm)[None, :] 47 | 48 | x_tst = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long() 49 | x_tst_lengths = torch.tensor([x_tst.shape[1]]).cuda() 50 | 51 | with torch.no_grad(): 52 | noise_scale = 0.667 53 | length_scale = 1.0 54 | (y_gen_tst, *_), *_, (attn_gen, *_) = model( 55 | x_tst, 56 | x_tst_lengths, 57 | gen=True, 58 | noise_scale=noise_scale, 59 | length_scale=length_scale, 60 | ) 61 | 62 | np.save(os.path.join(SAVE_MELS_DIR, fpath), y_gen_tst.cpu().detach().numpy()) 63 | 64 | for f in [hps.data.training_files, hps.data.validation_files]: 65 | file_lines = open(f).read().splitlines() 66 | 67 | for line in tqdm(file_lines): 68 | fname, text = line.split("|") 69 | fname = os.path.basename(fname).replace(".wav", ".npy") 70 | get_mel(text, fname) 71 | -------------------------------------------------------------------------------- /src/hifi_gan/inference_e2e.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import glob 4 | import os 5 | import numpy as np 6 | import argparse 7 | import json 8 | import torch 9 | from scipy.io.wavfile import write 10 | from env import AttrDict 11 | from meldataset import MAX_WAV_VALUE 12 | from models import Generator 13 | 14 | h = None 15 | device = None 16 | 17 | 18 | def load_checkpoint(filepath, device): 19 | assert os.path.isfile(filepath) 20 | print("Loading '{}'".format(filepath)) 21 | checkpoint_dict = torch.load(filepath, map_location=device) 22 | print("Complete.") 23 | return checkpoint_dict 24 | 25 | 26 | def scan_checkpoint(cp_dir, prefix): 27 | pattern = os.path.join(cp_dir, prefix + "*") 28 | cp_list = glob.glob(pattern) 29 | if len(cp_list) == 0: 30 | return "" 31 | return sorted(cp_list)[-1] 32 | 33 | 34 | def inference(a): 35 | generator = Generator(h).to(device) 36 | 37 | state_dict_g = load_checkpoint(a.checkpoint_file, device) 38 | generator.load_state_dict(state_dict_g["generator"]) 39 | 40 | filelist = os.listdir(a.input_mels_dir) 41 | 42 | os.makedirs(a.output_dir, exist_ok=True) 43 | 44 | generator.eval() 45 | generator.remove_weight_norm() 46 | with torch.no_grad(): 47 | for i, filname in enumerate(filelist): 48 | x = np.load(os.path.join(a.input_mels_dir, filname)) 49 | x = torch.FloatTensor(x).to(device) 50 | y_g_hat = generator(x) 51 | audio = y_g_hat.squeeze() 52 | audio = audio * MAX_WAV_VALUE 53 | audio = audio.cpu().numpy().astype("int16") 54 | 55 | output_file = os.path.join( 56 | a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav" 57 | ) 58 | write(output_file, h.sampling_rate, audio) 59 | print(output_file) 60 | 61 | 62 | def main(): 63 | print("Initializing Inference Process..") 64 | 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--input_mels_dir", default="test_mel_files") 67 | parser.add_argument("--output_dir", default="generated_files_from_mel") 68 | parser.add_argument("--checkpoint_file", required=True) 69 | a = parser.parse_args() 70 | 71 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json") 72 | with open(config_file) as f: 73 | data = f.read() 74 | 75 | global h 76 | json_config = json.loads(data) 77 | h 
= AttrDict(json_config) 78 | 79 | torch.manual_seed(h.seed) 80 | global device 81 | if torch.cuda.is_available(): 82 | torch.cuda.manual_seed(h.seed) 83 | device = torch.device("cuda") 84 | else: 85 | device = torch.device("cpu") 86 | 87 | inference(a) 88 | 89 | 90 | if __name__ == "__main__": 91 | main() 92 | -------------------------------------------------------------------------------- /src/glow_tts/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | 5 | # Regular expression matching text enclosed in curly braces: 6 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 7 | 8 | 9 | def get_arpabet(word, dictionary): 10 | word_arpabet = dictionary.lookup(word) 11 | if word_arpabet is not None: 12 | return "{" + word_arpabet[0] + "}" 13 | else: 14 | return word 15 | 16 | 17 | def text_to_sequence(text, symbols, cleaner_names, dictionary=None): 18 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 19 | 20 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 21 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 22 | 23 | Args: 24 | text: string to convert to a sequence 25 | cleaner_names: names of the cleaner functions to run the text through 26 | dictionary: arpabet class with arpabet dictionary 27 | 28 | Returns: 29 | List of integers corresponding to the symbols in the text 30 | ''' 31 | # Mappings from symbol to numeric ID and vice versa: 32 | global _id_to_symbol, _symbol_to_id 33 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 34 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 35 | 36 | sequence = [] 37 | 38 | space = _symbols_to_sequence(' ') 39 | # Check for curly braces and treat their contents as ARPAbet: 40 | while len(text): 41 | m = _curly_re.match(text) 42 | if not m: 43 | clean_text = _clean_text(text, cleaner_names) 44 | if dictionary is not None: 45 | clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")] 46 | for i in range(len(clean_text)): 47 | t = clean_text[i] 48 | if t.startswith("{"): 49 | sequence += _arpabet_to_sequence(t[1:-1]) 50 | else: 51 | sequence += _symbols_to_sequence(t) 52 | sequence += space 53 | else: 54 | sequence += _symbols_to_sequence(clean_text) 55 | break 56 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 57 | sequence += _arpabet_to_sequence(m.group(2)) 58 | text = m.group(3) 59 | 60 | # remove trailing space 61 | if dictionary is not None: 62 | sequence = sequence[:-1] if sequence[-1] == space[0] else sequence 63 | return sequence 64 | 65 | 66 | def _clean_text(text, cleaner_names): 67 | for name in cleaner_names: 68 | cleaner = getattr(cleaners, name) 69 | if not cleaner: 70 | raise Exception('Unknown cleaner: %s' % name) 71 | text = cleaner(text) 72 | return text 73 | 74 | 75 | def _symbols_to_sequence(symbols): 76 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 77 | 78 | 79 | def _arpabet_to_sequence(text): 80 | return _symbols_to_sequence(['@' + s for s in text.split()]) 81 | 82 | 83 | def _should_keep_symbol(s): 84 | return s in _symbol_to_id and s is not '_' and s is not '~' -------------------------------------------------------------------------------- /src/hifi_gan/inference.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, 
unicode_literals 2 | 3 | import glob 4 | import os 5 | import argparse 6 | import json 7 | import torch 8 | from scipy.io.wavfile import write 9 | from env import AttrDict 10 | from meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav 11 | from models import Generator 12 | 13 | h = None 14 | device = None 15 | 16 | 17 | def load_checkpoint(filepath, device): 18 | assert os.path.isfile(filepath) 19 | print("Loading '{}'".format(filepath)) 20 | checkpoint_dict = torch.load(filepath, map_location=device) 21 | print("Complete.") 22 | return checkpoint_dict 23 | 24 | 25 | def get_mel(x): 26 | return mel_spectrogram( 27 | x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax 28 | ) 29 | 30 | 31 | def scan_checkpoint(cp_dir, prefix): 32 | pattern = os.path.join(cp_dir, prefix + "*") 33 | cp_list = glob.glob(pattern) 34 | if len(cp_list) == 0: 35 | return "" 36 | return sorted(cp_list)[-1] 37 | 38 | 39 | def inference(a): 40 | generator = Generator(h).to(device) 41 | 42 | state_dict_g = load_checkpoint(a.checkpoint_file, device) 43 | generator.load_state_dict(state_dict_g["generator"]) 44 | 45 | filelist = os.listdir(a.input_wavs_dir) 46 | 47 | os.makedirs(a.output_dir, exist_ok=True) 48 | 49 | generator.eval() 50 | generator.remove_weight_norm() 51 | with torch.no_grad(): 52 | for i, filname in enumerate(filelist): 53 | wav, sr = load_wav(os.path.join(a.input_wavs_dir, filname)) 54 | wav = wav / MAX_WAV_VALUE 55 | wav = torch.FloatTensor(wav).to(device) 56 | x = get_mel(wav.unsqueeze(0)) 57 | y_g_hat = generator(x) 58 | audio = y_g_hat.squeeze() 59 | audio = audio * MAX_WAV_VALUE 60 | audio = audio.cpu().numpy().astype("int16") 61 | 62 | output_file = os.path.join( 63 | a.output_dir, os.path.splitext(filname)[0] + "_generated.wav" 64 | ) 65 | write(output_file, h.sampling_rate, audio) 66 | print(output_file) 67 | 68 | 69 | def main(): 70 | print("Initializing Inference Process..") 71 | 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument("--input_wavs_dir", default="test_files") 74 | parser.add_argument("--output_dir", default="generated_files") 75 | parser.add_argument("--checkpoint_file", required=True) 76 | a = parser.parse_args() 77 | 78 | config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json") 79 | with open(config_file) as f: 80 | data = f.read() 81 | 82 | global h 83 | json_config = json.loads(data) 84 | h = AttrDict(json_config) 85 | 86 | torch.manual_seed(h.seed) 87 | global device 88 | if torch.cuda.is_available(): 89 | torch.cuda.manual_seed(h.seed) 90 | device = torch.device("cuda") 91 | else: 92 | device = torch.device("cpu") 93 | 94 | inference(a) 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /src/glow_tts/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.signal import get_window 4 | import librosa.util as librosa_util 5 | 6 | 7 | def window_sumsquare( 8 | window, 9 | n_frames, 10 | hop_length=200, 11 | win_length=800, 12 | n_fft=800, 13 | dtype=np.float32, 14 | norm=None, 15 | ): 16 | """ 17 | # from librosa 0.6 18 | Compute the sum-square envelope of a window function at a given hop length. 19 | 20 | This is used to estimate modulation effects induced by windowing 21 | observations in short-time fourier transforms. 
22 | 23 | Parameters 24 | ---------- 25 | window : string, tuple, number, callable, or list-like 26 | Window specification, as in `get_window` 27 | 28 | n_frames : int > 0 29 | The number of analysis frames 30 | 31 | hop_length : int > 0 32 | The number of samples to advance between frames 33 | 34 | win_length : [optional] 35 | The length of the window function. By default, this matches `n_fft`. 36 | 37 | n_fft : int > 0 38 | The length of each analysis frame. 39 | 40 | dtype : np.dtype 41 | The data type of the output 42 | 43 | Returns 44 | ------- 45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 46 | The sum-squared envelope of the window function 47 | """ 48 | if win_length is None: 49 | win_length = n_fft 50 | 51 | n = n_fft + hop_length * (n_frames - 1) 52 | x = np.zeros(n, dtype=dtype) 53 | 54 | # Compute the squared window at the desired length 55 | win_sq = get_window(window, win_length, fftbins=True) 56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 57 | win_sq = librosa_util.pad_center(win_sq, n_fft) 58 | 59 | # Fill the envelope 60 | for i in range(n_frames): 61 | sample = i * hop_length 62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] 63 | return x 64 | 65 | 66 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 67 | """ 68 | PARAMS 69 | ------ 70 | magnitudes: spectrogram magnitudes 71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 72 | """ 73 | 74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 75 | angles = angles.astype(np.float32) 76 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 78 | 79 | for i in range(n_iters): 80 | _, angles = stft_fn.transform(signal) 81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 82 | return signal 83 | 84 | 85 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 86 | """ 87 | PARAMS 88 | ------ 89 | C: compression factor 90 | """ 91 | return torch.log(torch.clamp(x, min=clip_val) * C) 92 | 93 | 94 | def dynamic_range_decompression(x, C=1): 95 | """ 96 | PARAMS 97 | ------ 98 | C: compression factor used to compress 99 | """ 100 | return torch.exp(x) / C 101 | -------------------------------------------------------------------------------- /tts_infer/example_inference.py: -------------------------------------------------------------------------------- 1 | ''' Example file to test tts_infer after installing it. Refer to section 1.1 in README.md for steps of installation. 
''' 2 | 3 | from tts_infer.tts import TextToMel, MelToWav 4 | from tts_infer.transliterate import XlitEngine 5 | from tts_infer.num_to_word_on_sent import normalize_nums 6 | 7 | import re 8 | import numpy as np 9 | from scipy.io.wavfile import write 10 | 11 | from mosestokenizer import * 12 | from indicnlp.tokenize import sentence_tokenize 13 | 14 | INDIC = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"] 15 | 16 | def split_sentences(paragraph, language): 17 | if language == "en": 18 | with MosesSentenceSplitter(language) as splitter: 19 | return splitter([paragraph]) 20 | elif language in INDIC: 21 | return sentence_tokenize.sentence_split(paragraph, lang=language) 22 | 23 | 24 | device='cpu' 25 | text_to_mel = TextToMel(glow_model_dir='/path/to/glow_ckp', device=device) 26 | mel_to_wav = MelToWav(hifi_model_dir='/path/to/hifi_ckp', device=device) 27 | 28 | lang='hi' # transliteration from En to Hi 29 | engine = XlitEngine(lang) # loading translit model globally 30 | 31 | def translit(text, lang): 32 | reg = re.compile(r'[a-zA-Z]') 33 | words = [engine.translit_word(word, topk=1)[lang][0] if reg.match(word) else word for word in text.split()] 34 | updated_sent = ' '.join(words) 35 | return updated_sent 36 | 37 | def run_tts(text, lang): 38 | text = text.replace('।', '.') # only for hindi models 39 | text_num_to_word = normalize_nums(text, lang) # converting numbers to words in lang 40 | text_num_to_word_and_transliterated = translit(text_num_to_word, lang) # transliterating english words to lang 41 | final_text = ' ' + text_num_to_word_and_transliterated 42 | 43 | mel = text_to_mel.generate_mel(final_text) 44 | audio, sr = mel_to_wav.generate_wav(mel) 45 | write(filename='temp.wav', rate=sr, data=audio) # for saving wav file, if needed 46 | return (sr, audio) 47 | 48 | def run_tts_paragraph(text, lang): 49 | audio_list = [] 50 | split_sentences_list = split_sentences(text, language='hi') 51 | 52 | for sent in split_sentences_list: 53 | sr, audio = run_tts(sent, lang) 54 | audio_list.append(audio) 55 | 56 | concatenated_audio = np.concatenate([i for i in audio_list]) 57 | write(filename='temp_long.wav', rate=sr, data=concatenated_audio) 58 | return (sr, concatenated_audio) 59 | 60 | if __name__ == "__main__": 61 | _, audio = run_tts('mera naam neeraj hai', 'hi') 62 | 63 | para = ''' 64 | भारत मेरा देश है और मुझे भारतीय होने पर गर्व है। ये विश्व का सातवाँ सबसे बड़ा और विश्व में दूसरा सबसे अधिक जनसंख्या वाला देश है। 65 | इसे भारत, हिन्दुस्तान और आर्यव्रत के नाम से भी जाना जाता है। ये एक प्रायद्वीप है जो पूरब में बंगाल की खाड़ी, 66 | पश्चिम में अरेबियन सागर और दक्षिण में भारतीय महासागर जैसे तीन महासगरों से घिरा हुआ है। 67 | भारत का राष्ट्रीय पशु चीता, राष्ट्रीय पक्षी मोर, राष्ट्रीय फूल कमल, और राष्ट्रीय फल आम है। 68 | भारत मेरा देश है और मुझे भारतीय होने पर गर्व है। ये विश्व का सातवाँ सबसे बड़ा और विश्व में दूसरा सबसे अधिक जनसंख्या वाला देश है। 69 | इसे भारत, हिन्दुस्तान और आर्यव्रत के नाम से भी जाना जाता है। ये एक प्रायद्वीप है जो पूरब में बंगाल की खाड़ी, 70 | पश्चिम में अरेबियन सागर और दक्षिण में भारतीय महासागर जैसे तीन महासगरों से घिरा हुआ है। 71 | भारत का राष्ट्रीय पशु चीता, राष्ट्रीय पक्षी मोर, राष्ट्रीय फूल कमल, और राष्ट्रीय फल आम है। 72 | भारत मेरा देश है और मुझे भारतीय होने पर गर्व है। ये विश्व का सातवाँ सबसे बड़ा और विश्व में दूसरा सबसे अधिक जनसंख्या वाला देश है। 73 | इसे भारत, हिन्दुस्तान और आर्यव्रत के नाम से भी जाना जाता है। ये एक प्रायद्वीप है जो पूरब में बंगाल की खाड़ी, 74 | पश्चिम में अरेबियन सागर और दक्षिण में भारतीय महासागर जैसे तीन महासगरों से घिरा हुआ है। 
75 | भारत का राष्ट्रीय पशु चीता, राष्ट्रीय पक्षी मोर, राष्ट्रीय फूल कमल, और राष्ट्रीय फल आम है। 76 | ''' 77 | 78 | print('Num chars in paragraph: ', len(para)) 79 | _, audio_long = run_tts_paragraph(para, 'hi') 80 | -------------------------------------------------------------------------------- /utils/inference/advanced_tts.py: -------------------------------------------------------------------------------- 1 | 2 | from tts import TextToMel, MelToWav 3 | from transliterate import XlitEngine 4 | from num_to_word_on_sent import normalize_nums 5 | 6 | import re 7 | import numpy as np 8 | from scipy.io.wavfile import write 9 | 10 | from mosestokenizer import * 11 | from indicnlp.tokenize import sentence_tokenize 12 | import argparse 13 | 14 | _INDIC = ["as", "bn", "gu", "hi", "kn", "ml", "mr", "or", "pa", "ta", "te"] 15 | _PURAM_VIRAM_LANGUAGES = ["hi", "or", "bn", "as"] 16 | _TRANSLITERATION_NOT_AVAILABLE_IN = ["en","or"] 17 | #_NUM2WORDS_NOT_AVAILABLE_IN = [] 18 | 19 | def normalize_text(text, lang): 20 | if lang in _PURAM_VIRAM_LANGUAGES: 21 | text = text.replace('|', '।') 22 | text = text.replace('.', '।') 23 | return text 24 | 25 | def split_sentences(paragraph, language): 26 | if language == "en": 27 | with MosesSentenceSplitter(language) as splitter: 28 | return splitter([paragraph]) 29 | elif language in _INDIC: 30 | return sentence_tokenize.sentence_split(paragraph, lang=language) 31 | 32 | 33 | 34 | def load_models(acoustic, vocoder, device): 35 | text_to_mel = TextToMel(glow_model_dir=acoustic, device=device) 36 | mel_to_wav = MelToWav(hifi_model_dir=vocoder, device=device) 37 | return text_to_mel, mel_to_wav 38 | 39 | 40 | def translit(text, lang): 41 | reg = re.compile(r'[a-zA-Z]') 42 | words = [engine.translit_word(word, topk=1)[lang][0] if reg.match(word) else word for word in text.split()] 43 | updated_sent = ' '.join(words) 44 | return updated_sent 45 | 46 | 47 | 48 | def run_tts(text, lang, args): 49 | if lang == 'hi': 50 | text = text.replace('।', '.') # only for hindi models 51 | 52 | if lang == 'en' and text[-1] != '.': 53 | text = text + '. 
' 54 | 55 | if args.number_conversion == 1 and lang!='en': 56 | print("Doing number conversion") 57 | text_num_to_word = normalize_nums(text, lang) # converting numbers to words in lang 58 | else: 59 | text_num_to_word = text 60 | 61 | 62 | if args.transliteration == 1 and lang not in _TRANSLITERATION_NOT_AVAILABLE_IN: 63 | print("Doing transliteration") 64 | text_num_to_word_and_transliterated = translit(text_num_to_word, lang) # transliterating english words to lang 65 | else: 66 | text_num_to_word_and_transliterated = text_num_to_word 67 | 68 | final_text = ' ' + text_num_to_word_and_transliterated 69 | 70 | mel = text_to_mel.generate_mel(final_text, args.noise_scale, args.length_scale) 71 | audio, sr = mel_to_wav.generate_wav(mel) 72 | return sr, audio 73 | 74 | def run_tts_paragraph(args): 75 | audio_list = [] 76 | if args.split_sentences == 1: 77 | text = normalize_text(args.text, args.lang) 78 | split_sentences_list = split_sentences(text, args.lang) 79 | 80 | for sent in split_sentences_list: 81 | sr, audio = run_tts(sent, args.lang, args) 82 | audio_list.append(audio) 83 | 84 | concatenated_audio = np.concatenate([i for i in audio_list]) 85 | if args.wav: 86 | write(filename=args.wav, rate=sr, data=concatenated_audio) 87 | return (sr, concatenated_audio) 88 | else: 89 | sr, audio = run_tts(args.text, args.lang, args) 90 | if args.wav: 91 | write(filename=args.wav, rate=sr, data=audio) 92 | return (sr, audio) 93 | 94 | 95 | def load_all_models(args): 96 | global engine 97 | if args.lang not in _TRANSLITERATION_NOT_AVAILABLE_IN: 98 | engine = XlitEngine(args.lang) # loading translit model globally 99 | 100 | global text_to_mel 101 | global mel_to_wav 102 | 103 | text_to_mel, mel_to_wav = load_models(args.acoustic, args.vocoder, args.device) 104 | 105 | try: 106 | args.noise_scale = float(args.noise_scale) 107 | args.length_scale = float(args.length_scale) 108 | except: 109 | pass 110 | 111 | print(args) 112 | 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument("-a", "--acoustic", required=True, type=str) 118 | parser.add_argument("-v", "--vocoder", required=True, type=str) 119 | parser.add_argument("-d", "--device", type=str, default="cpu") 120 | parser.add_argument("-t", "--text", type=str, required=True) 121 | parser.add_argument("-w", "--wav", type=str, required=True) 122 | parser.add_argument("-n", "--noise-scale", default='0.667', type=str ) 123 | parser.add_argument("-l", "--length-scale", default='1.0', type=str) 124 | 125 | parser.add_argument("-T", "--transliteration", default=1, type=int) 126 | parser.add_argument("-N", "--number-conversion", default=1, type=int) 127 | parser.add_argument("-S", "--split-sentences", default=1, type=int) 128 | parser.add_argument("-L", "--lang", type=str, required=True) 129 | 130 | args = parser.parse_args() 131 | 132 | load_all_models(args) 133 | run_tts_paragraph(args) 134 | 135 | 136 | -------------------------------------------------------------------------------- /utils/glow/prepare_iitm_data_glow.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import re 4 | import string 5 | import argparse 6 | 7 | import random 8 | random.seed(42) 9 | 10 | def replace_extra_chars(line): 11 | line = line.replace("(", "").replace( 12 | ")", "" 13 | ) # .replace('\u200d', ' ').replace('\ufeff', ' ').replace('\u200c', ' ').replace('\u200e', ' ') 14 | # line = line.replace('“', ' ').replace('”', ' ').replace(':', ' ') 
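# Note on this data-prep script: replace_extra_chars() above strips stray parentheses from each
# annotation line before it is split into a filename and a sentence. Further down, the script
# reads ../../config/glow/base_blank.json with json.load() and writes a new config with
# json.dump(); unlike prepare_iitm_data_glow_en.py, its import block does not include
# `import json`, so that import needs to be added for the config-writing step to run.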
15 | 16 | return line.strip() 17 | 18 | 19 | def write_txt(content, filename): 20 | with open(filename, "w+", encoding="utf-8") as f: 21 | f.write(content) 22 | 23 | 24 | def save_train_test_valid_split(annotations_txt, num_samples_valid, num_samples_test): 25 | with open(annotations_txt, encoding="utf-8") as f: 26 | all_lines = [line.strip() for line in f.readlines()] 27 | test_val_indices = random.sample( 28 | range(len(all_lines)), num_samples_valid + num_samples_test 29 | ) 30 | valid_ix = test_val_indices[:num_samples_valid] 31 | test_ix = test_val_indices[num_samples_valid:] 32 | train = [line for i, line in enumerate(all_lines) if i not in test_val_indices] 33 | valid = [line for i, line in enumerate(all_lines) if i in valid_ix] 34 | test = [line for i, line in enumerate(all_lines) if i in test_ix] 35 | 36 | print(f"Num samples in train: {len(train)}") 37 | print(f"Num samples in valid: {len(valid)}") 38 | print(f"Num samples in test: {len(test)}") 39 | 40 | out_dir_path = "/".join(annotations_txt.split("/")[:-1]) 41 | with open(os.path.join(out_dir_path, "train.txt"), "w+", encoding="utf-8") as f: 42 | for line in train: 43 | print(line, file=f) 44 | with open(os.path.join(out_dir_path, "valid.txt"), "w+", encoding="utf-8") as f: 45 | for line in valid: 46 | print(line, file=f) 47 | with open(os.path.join(out_dir_path, "test.txt"), "w+", encoding="utf-8") as f: 48 | for line in test: 49 | print(line, file=f) 50 | print(f"train, test and valid txts saved in {out_dir_path}") 51 | 52 | 53 | def save_txts_from_txt_done_data( 54 | text_path, 55 | wav_path_for_annotations_txt, 56 | out_path_for_txts, 57 | num_samples_valid, 58 | num_samples_test, 59 | ): 60 | outfile = os.path.join(out_path_for_txts, "annotations.txt") 61 | with open(text_path) as file: 62 | file_lines = file.readlines() 63 | 64 | # print(file_lines[0]) 65 | 66 | file_lines = [replace_extra_chars(line) for line in file_lines] 67 | # print(file_lines[0]) 68 | 69 | fnames, ftexts = [], [] 70 | for line in file_lines: 71 | elems = line.split('"') 72 | fnames.append(elems[0].strip()) 73 | ftexts.append(elems[1].strip()) 74 | 75 | all_chars = list(set("".join(ftexts))) 76 | punct_with_space = [i for i in all_chars if i in list(string.punctuation)] + [" "] 77 | chars = [i for i in all_chars if i not in punct_with_space if i.strip()] 78 | chars = "".join(chars) 79 | punct_with_space = "".join(punct_with_space) 80 | 81 | with open('../../config/glow/base_blank.json', 'r') as jfile: 82 | json_config = json.load(jfile) 83 | 84 | json_config["data"]["chars"] = chars 85 | json_config["data"]["punc"] = punct_with_space 86 | json_config["data"]["training_files"]=out_path_for_txts + '/train.txt' 87 | json_config["data"]["validation_files"] = out_path_for_txts + '/valid.txt' 88 | new_config_name = out_path_for_txts.split('/')[-1] 89 | with open(f'../../config/glow/{new_config_name}.json','w+') as jfile: 90 | json.dump(json_config, jfile) 91 | 92 | print(f"Characters: {chars}") 93 | print(f"Punctuation: {punct_with_space}") 94 | print(f"Config file is stored at ../../config/glow/{new_config_name}.json") 95 | 96 | outfile_f = open(outfile, "w+", encoding="utf-8") 97 | for f, t in zip(fnames, ftexts): 98 | print( 99 | os.path.join(wav_path_for_annotations_txt, f) + ".wav", 100 | t, 101 | sep="|", 102 | file=outfile_f, 103 | ) 104 | outfile_f.close() 105 | write_txt(punct_with_space, os.path.join(out_path_for_txts, "punc.txt")) 106 | write_txt(chars, os.path.join(out_path_for_txts, "chars.txt")) 107 | 108 | save_train_test_valid_split( 109 
| annotations_txt=outfile, 110 | num_samples_valid=num_samples_valid, 111 | num_samples_test=num_samples_test, 112 | ) 113 | 114 | 115 | 116 | 117 | if __name__ == "__main__": 118 | 119 | 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument("-i", "--text-path", type=str, required=True) 122 | parser.add_argument("-o", "--output-path", type=str, required=True) 123 | parser.add_argument("-w", "--wav-path", type=str, required=True) 124 | parser.add_argument("-v", "--valid-samples", type=int, default = 100) 125 | parser.add_argument("-t", "--test-samples", type=int, default = 10) 126 | args = parser.parse_args() 127 | 128 | save_txts_from_txt_done_data( 129 | args.text_path, 130 | args.wav_path, 131 | args.output_path, 132 | args.valid_samples, 133 | args.test_samples, 134 | ) 135 | -------------------------------------------------------------------------------- /utils/glow/prepare_iitm_data_glow_en.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import re 4 | import string 5 | import argparse 6 | import json 7 | import random 8 | random.seed(42) 9 | 10 | def replace_extra_chars(line): 11 | line = line.replace("(", "").replace( 12 | ")", "" 13 | ) # .replace('\u200d', ' ').replace('\ufeff', ' ').replace('\u200c', ' ').replace('\u200e', ' ') 14 | # line = line.replace('“', ' ').replace('”', ' ').replace(':', ' ') 15 | 16 | return line.strip() 17 | 18 | 19 | def write_txt(content, filename): 20 | with open(filename, "w+", encoding="utf-8") as f: 21 | f.write(content) 22 | 23 | 24 | def save_train_test_valid_split(annotations_txt, num_samples_valid, num_samples_test): 25 | with open(annotations_txt, encoding="utf-8") as f: 26 | all_lines = [line.strip() for line in f.readlines()] 27 | test_val_indices = random.sample( 28 | range(len(all_lines)), num_samples_valid + num_samples_test 29 | ) 30 | valid_ix = test_val_indices[:num_samples_valid] 31 | test_ix = test_val_indices[num_samples_valid:] 32 | train = [line for i, line in enumerate(all_lines) if i not in test_val_indices] 33 | valid = [line for i, line in enumerate(all_lines) if i in valid_ix] 34 | test = [line for i, line in enumerate(all_lines) if i in test_ix] 35 | 36 | print(f"Num samples in train: {len(train)}") 37 | print(f"Num samples in valid: {len(valid)}") 38 | print(f"Num samples in test: {len(test)}") 39 | 40 | out_dir_path = "/".join(annotations_txt.split("/")[:-1]) 41 | with open(os.path.join(out_dir_path, "train.txt"), "w+", encoding="utf-8") as f: 42 | for line in train: 43 | print(line, file=f) 44 | with open(os.path.join(out_dir_path, "valid.txt"), "w+", encoding="utf-8") as f: 45 | for line in valid: 46 | print(line, file=f) 47 | with open(os.path.join(out_dir_path, "test.txt"), "w+", encoding="utf-8") as f: 48 | for line in test: 49 | print(line, file=f) 50 | print(f"train, test and valid txts saved in {out_dir_path}") 51 | 52 | 53 | def save_txts_from_txt_done_data( 54 | text_path, 55 | wav_path_for_annotations_txt, 56 | out_path_for_txts, 57 | num_samples_valid, 58 | num_samples_test, 59 | ): 60 | outfile = os.path.join(out_path_for_txts, "annotations.txt") 61 | with open(text_path) as file: 62 | file_lines = file.readlines() 63 | 64 | # print(file_lines[0]) 65 | 66 | file_lines = [replace_extra_chars(line) for line in file_lines] 67 | # print(file_lines[0]) 68 | 69 | fnames, ftexts = [], [] 70 | for line in file_lines: 71 | elems = line.split('"') 72 | fnames.append(elems[0].strip()) 73 | 
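# After replace_extra_chars() has removed the parentheses, each annotation line looks like:
#   audio1.wav "Sentence1."
# so splitting on the double quote leaves the filename in elems[0] and the sentence in elems[1].
# This English variant additionally lowercases the sentence and normalises curly apostrophes to
# straight ones before the character and punctuation vocabularies are collected.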
ftexts.append(elems[1].strip().lower().replace('‘','\'').replace('’','\'')) 74 | 75 | all_chars = list(set("".join(ftexts))) 76 | punct_with_space = [i for i in all_chars if i in list(string.punctuation)] + [" "] 77 | chars = [i for i in all_chars if i not in punct_with_space if i.strip()] 78 | chars = "".join(chars) 79 | punct_with_space = "".join(punct_with_space)#.replace("'",r"\'") 80 | 81 | with open('../../config/glow/base_blank.json', 'r') as jfile: 82 | json_config = json.load(jfile) 83 | 84 | json_config["data"]["chars"] = chars 85 | json_config["data"]["punc"] = punct_with_space 86 | json_config["data"]["training_files"]=out_path_for_txts + '/train.txt' 87 | json_config["data"]["validation_files"] = out_path_for_txts + '/valid.txt' 88 | new_config_name = out_path_for_txts.split('/')[-1] 89 | with open(f'../../config/glow/{new_config_name}.json','w+') as jfile: 90 | json.dump(json_config, jfile) 91 | 92 | print(f"Characters: {chars}") 93 | print(f"Len of vocab: {len(chars)}") 94 | print(f"Punctuation: {punct_with_space}") 95 | print(f"Config file is stored at ../../config/glow/{new_config_name}.json") 96 | 97 | outfile_f = open(outfile, "w+", encoding="utf-8") 98 | for f, t in zip(fnames, ftexts): 99 | print( 100 | os.path.join(wav_path_for_annotations_txt, f) + ".wav", 101 | t, 102 | sep="|", 103 | file=outfile_f, 104 | ) 105 | outfile_f.close() 106 | write_txt(punct_with_space, os.path.join(out_path_for_txts, "punc.txt")) 107 | write_txt(chars, os.path.join(out_path_for_txts, "chars.txt")) 108 | 109 | save_train_test_valid_split( 110 | annotations_txt=outfile, 111 | num_samples_valid=num_samples_valid, 112 | num_samples_test=num_samples_test, 113 | ) 114 | 115 | 116 | 117 | 118 | if __name__ == "__main__": 119 | 120 | 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument("-i", "--text-path", type=str, required=True) 123 | parser.add_argument("-o", "--output-path", type=str, required=True) 124 | parser.add_argument("-w", "--wav-path", type=str, required=True) 125 | parser.add_argument("-v", "--valid-samples", type=int, default = 100) 126 | parser.add_argument("-t", "--test-samples", type=int, default = 10) 127 | args = parser.parse_args() 128 | 129 | save_txts_from_txt_done_data( 130 | args.text_path, 131 | args.wav_path, 132 | args.output_path, 133 | args.valid_samples, 134 | args.test_samples, 135 | ) 136 | -------------------------------------------------------------------------------- /src/glow_tts/texttospeech.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | from typing import Tuple 3 | 4 | from scipy.io.wavfile import write 5 | from hifi.env import AttrDict 6 | from hifi.models import Generator 7 | 8 | import numpy as np 9 | import os 10 | import json 11 | 12 | import torch 13 | from text import text_to_sequence 14 | import commons 15 | import models 16 | import utils 17 | import sys 18 | from argparse import ArgumentParser 19 | 20 | 21 | def check_directory(dir): 22 | if not os.path.exists(dir): 23 | sys.exit("Error: {} directory does not exist".format(dir)) 24 | 25 | 26 | class TextToMel: 27 | def __init__(self, glow_model_dir, device="cuda"): 28 | self.glow_model_dir = glow_model_dir 29 | check_directory(self.glow_model_dir) 30 | self.device = device 31 | self.hps, self.glow_tts_model = self.load_glow_tts() 32 | pass 33 | 34 | def load_glow_tts(self): 35 | hps = utils.get_hparams_from_dir(self.glow_model_dir) 36 | checkpoint_path = 
utils.latest_checkpoint_path(self.glow_model_dir) 37 | symbols = list(hps.data.punc) + list(hps.data.chars) 38 | glow_tts_model = models.FlowGenerator( 39 | len(symbols) + getattr(hps.data, "add_blank", False), 40 | out_channels=hps.data.n_mel_channels, 41 | **hps.model 42 | ) # .to(self.device) 43 | 44 | if self.device == "cuda": 45 | glow_tts_model.to("cuda") 46 | 47 | utils.load_checkpoint(checkpoint_path, glow_tts_model) 48 | glow_tts_model.decoder.store_inverse() 49 | _ = glow_tts_model.eval() 50 | 51 | return hps, glow_tts_model 52 | 53 | def generate_mel(self, text, noise_scale=0.667, length_scale=1.0): 54 | symbols = list(self.hps.data.punc) + list(self.hps.data.chars) 55 | cleaner = self.hps.data.text_cleaners 56 | if getattr(self.hps.data, "add_blank", False): 57 | text_norm = text_to_sequence(text, symbols, cleaner) 58 | text_norm = commons.intersperse(text_norm, len(symbols)) 59 | else: # If not using "add_blank" option during training, adding spaces at the beginning and the end of utterance improves quality 60 | text = " " + text.strip() + " " 61 | text_norm = text_to_sequence(text, symbols, cleaner) 62 | 63 | sequence = np.array(text_norm)[None, :] 64 | 65 | if self.device == "cuda": 66 | x_tst = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long() 67 | x_tst_lengths = torch.tensor([x_tst.shape[1]]).cuda() 68 | else: 69 | x_tst = torch.autograd.Variable(torch.from_numpy(sequence)).long() 70 | x_tst_lengths = torch.tensor([x_tst.shape[1]]) 71 | 72 | with torch.no_grad(): 73 | (y_gen_tst, *_), *_, (attn_gen, *_) = self.glow_tts_model( 74 | x_tst, 75 | x_tst_lengths, 76 | gen=True, 77 | noise_scale=noise_scale, 78 | length_scale=length_scale, 79 | ) 80 | 81 | return y_gen_tst 82 | #return y_gen_tst.cpu().detach().numpy() 83 | 84 | 85 | class MelToWav: 86 | def __init__(self, hifi_model_dir, device="cuda"): 87 | self.hifi_model_dir = hifi_model_dir 88 | check_directory(self.hifi_model_dir) 89 | self.device = device 90 | self.h, self.hifi_gan_generator = self.load_hifi_gan() 91 | pass 92 | 93 | def load_hifi_gan(self): 94 | checkpoint_path = utils.latest_checkpoint_path(self.hifi_model_dir, regex="g_*") 95 | config_file = os.path.join(self.hifi_model_dir, "config.json") 96 | data = open(config_file).read() 97 | json_config = json.loads(data) 98 | h = AttrDict(json_config) 99 | torch.manual_seed(h.seed) 100 | 101 | generator = Generator(h).to(self.device) 102 | 103 | assert os.path.isfile(checkpoint_path) 104 | print("Loading '{}'".format(checkpoint_path)) 105 | state_dict_g = torch.load(checkpoint_path, map_location=self.device) 106 | print("Complete.") 107 | 108 | generator.load_state_dict(state_dict_g["generator"]) 109 | 110 | generator.eval() 111 | generator.remove_weight_norm() 112 | 113 | return h, generator 114 | 115 | def generate_wav(self, mel): 116 | #mel = torch.FloatTensor(mel).to(self.device) 117 | 118 | y_g_hat = self.hifi_gan_generator(mel.to(self.device)) # passing through vocoder 119 | audio = y_g_hat.squeeze() 120 | audio = audio * 32768.0 121 | audio = audio.cpu().detach().numpy().astype("int16") 122 | 123 | return audio, self.h.sampling_rate 124 | 125 | 126 | 127 | 128 | 129 | if __name__ == "__main__": 130 | 131 | parser = ArgumentParser() 132 | parser.add_argument("-m", "--model", required=True, type=str) 133 | parser.add_argument("-g", "--gan", required=True, type=str) 134 | parser.add_argument("-d", "--device", type=str, default="cpu") 135 | parser.add_argument("-t", "--text", type=str, required=True) 136 | parser.add_argument("-w", "--wav", 
type=str, required=True) 137 | 138 | args = parser.parse_args() 139 | 140 | text_to_mel = TextToMel(glow_model_dir=args.model, device=args.device) 141 | mel_to_wav = MelToWav(hifi_model_dir=args.gan, device=args.device) 142 | 143 | mel = text_to_mel.generate_mel(args.text) 144 | audio, sr = mel_to_wav.generate_wav(mel) 145 | 146 | write(filename=args.wav, rate=sr, data=audio) -------------------------------------------------------------------------------- /tts_infer/tts.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | from typing import Tuple 3 | import sys 4 | from argparse import ArgumentParser 5 | 6 | import torch 7 | import numpy as np 8 | import os 9 | import json 10 | import torch 11 | 12 | sys.path.append(os.path.join(os.path.dirname(__file__), "../src/glow_tts")) 13 | 14 | from scipy.io.wavfile import write 15 | from hifi.env import AttrDict 16 | from hifi.models import Generator 17 | 18 | 19 | from text import text_to_sequence 20 | import commons 21 | import models 22 | import utils 23 | 24 | 25 | def check_directory(dir): 26 | if not os.path.exists(dir): 27 | sys.exit("Error: {} directory does not exist".format(dir)) 28 | 29 | 30 | class TextToMel: 31 | def __init__(self, glow_model_dir, device="cuda"): 32 | self.glow_model_dir = glow_model_dir 33 | check_directory(self.glow_model_dir) 34 | self.device = device 35 | self.hps, self.glow_tts_model = self.load_glow_tts() 36 | pass 37 | 38 | def load_glow_tts(self): 39 | hps = utils.get_hparams_from_dir(self.glow_model_dir) 40 | checkpoint_path = utils.latest_checkpoint_path(self.glow_model_dir) 41 | symbols = list(hps.data.punc) + list(hps.data.chars) 42 | glow_tts_model = models.FlowGenerator( 43 | len(symbols) + getattr(hps.data, "add_blank", False), 44 | out_channels=hps.data.n_mel_channels, 45 | **hps.model 46 | ) # .to(self.device) 47 | 48 | if self.device == "cuda": 49 | glow_tts_model.to("cuda") 50 | 51 | utils.load_checkpoint(checkpoint_path, glow_tts_model) 52 | glow_tts_model.decoder.store_inverse() 53 | _ = glow_tts_model.eval() 54 | 55 | return hps, glow_tts_model 56 | 57 | def generate_mel(self, text, noise_scale=0.667, length_scale=1.0): 58 | symbols = list(self.hps.data.punc) + list(self.hps.data.chars) 59 | cleaner = self.hps.data.text_cleaners 60 | if getattr(self.hps.data, "add_blank", False): 61 | text_norm = text_to_sequence(text, symbols, cleaner) 62 | text_norm = commons.intersperse(text_norm, len(symbols)) 63 | else: # If not using "add_blank" option during training, adding spaces at the beginning and the end of utterance improves quality 64 | text = " " + text.strip() + " " 65 | text_norm = text_to_sequence(text, symbols, cleaner) 66 | 67 | sequence = np.array(text_norm)[None, :] 68 | 69 | del symbols 70 | del cleaner 71 | del text 72 | del text_norm 73 | 74 | if self.device == "cuda": 75 | x_tst = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long() 76 | x_tst_lengths = torch.tensor([x_tst.shape[1]]).cuda() 77 | else: 78 | x_tst = torch.autograd.Variable(torch.from_numpy(sequence)).long() 79 | x_tst_lengths = torch.tensor([x_tst.shape[1]]) 80 | 81 | with torch.no_grad(): 82 | (y_gen_tst, *_), *_, (attn_gen, *_) = self.glow_tts_model( 83 | x_tst, 84 | x_tst_lengths, 85 | gen=True, 86 | noise_scale=noise_scale, 87 | length_scale=length_scale, 88 | ) 89 | del x_tst 90 | del x_tst_lengths 91 | torch.cuda.empty_cache() 92 | return y_gen_tst 93 | #return 
y_gen_tst.cpu().detach().numpy() 94 | 95 | 96 | class MelToWav: 97 | def __init__(self, hifi_model_dir, device="cuda"): 98 | self.hifi_model_dir = hifi_model_dir 99 | check_directory(self.hifi_model_dir) 100 | self.device = device 101 | self.h, self.hifi_gan_generator = self.load_hifi_gan() 102 | pass 103 | 104 | def load_hifi_gan(self): 105 | checkpoint_path = utils.latest_checkpoint_path(self.hifi_model_dir, regex="g_*") 106 | config_file = os.path.join(self.hifi_model_dir, "config.json") 107 | data = open(config_file).read() 108 | json_config = json.loads(data) 109 | h = AttrDict(json_config) 110 | torch.manual_seed(h.seed) 111 | 112 | generator = Generator(h).to(self.device) 113 | 114 | assert os.path.isfile(checkpoint_path) 115 | print("Loading '{}'".format(checkpoint_path)) 116 | state_dict_g = torch.load(checkpoint_path, map_location=self.device) 117 | print("Complete.") 118 | 119 | generator.load_state_dict(state_dict_g["generator"]) 120 | 121 | generator.eval() 122 | generator.remove_weight_norm() 123 | 124 | return h, generator 125 | 126 | def generate_wav(self, mel): 127 | #mel = torch.FloatTensor(mel).to(self.device) 128 | 129 | y_g_hat = self.hifi_gan_generator(mel.to(self.device)) # passing through vocoder 130 | audio = y_g_hat.squeeze() 131 | audio = audio * 32768.0 132 | audio = audio.cpu().detach().numpy().astype("int16") 133 | 134 | del y_g_hat 135 | del mel 136 | torch.cuda.empty_cache() 137 | return audio, self.h.sampling_rate 138 | 139 | 140 | if __name__ == "__main__": 141 | 142 | parser = ArgumentParser() 143 | parser.add_argument("-m", "--model", required=True, type=str) 144 | parser.add_argument("-g", "--gan", required=True, type=str) 145 | parser.add_argument("-d", "--device", type=str, default="cpu") 146 | parser.add_argument("-t", "--text", type=str, required=True) 147 | parser.add_argument("-w", "--wav", type=str, required=True) 148 | args = parser.parse_args() 149 | 150 | text_to_mel = TextToMel(glow_model_dir=args.model, device=args.device) 151 | mel_to_wav = MelToWav(hifi_model_dir=args.gan, device=args.device) 152 | 153 | mel = text_to_mel.generate_mel(args.text) 154 | audio, sr = mel_to_wav.generate_wav(mel) 155 | 156 | write(filename=args.wav, rate=sr, data=audio) 157 | 158 | pass 159 | -------------------------------------------------------------------------------- /utils/inference/tts.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | from typing import Tuple 3 | import sys 4 | from argparse import ArgumentParser 5 | 6 | import torch 7 | import numpy as np 8 | import os 9 | import json 10 | import torch 11 | 12 | sys.path.append(os.path.join(os.path.dirname(__file__), "../../src/glow_tts")) 13 | 14 | from scipy.io.wavfile import write 15 | from hifi.env import AttrDict 16 | from hifi.models import Generator 17 | 18 | 19 | from text import text_to_sequence 20 | import commons 21 | import models 22 | import utils 23 | 24 | 25 | def check_directory(dir): 26 | if not os.path.exists(dir): 27 | sys.exit("Error: {} directory does not exist".format(dir)) 28 | 29 | 30 | class TextToMel: 31 | def __init__(self, glow_model_dir, device="cuda"): 32 | self.glow_model_dir = glow_model_dir 33 | check_directory(self.glow_model_dir) 34 | self.device = device 35 | self.hps, self.glow_tts_model = self.load_glow_tts() 36 | 37 | def load_glow_tts(self): 38 | hps = utils.get_hparams_from_dir(self.glow_model_dir) 39 | checkpoint_path = 
utils.latest_checkpoint_path(self.glow_model_dir) 40 | symbols = list(hps.data.punc) + list(hps.data.chars) 41 | glow_tts_model = models.FlowGenerator( 42 | len(symbols) + getattr(hps.data, "add_blank", False), 43 | out_channels=hps.data.n_mel_channels, 44 | **hps.model 45 | ) # .to(self.device) 46 | 47 | if self.device == "cuda": 48 | glow_tts_model.to("cuda") 49 | 50 | utils.load_checkpoint(checkpoint_path, glow_tts_model) 51 | glow_tts_model.decoder.store_inverse() 52 | _ = glow_tts_model.eval() 53 | 54 | return hps, glow_tts_model 55 | 56 | def generate_mel(self, text, noise_scale=0.667, length_scale=1.0): 57 | print(f"Noise scale: {noise_scale} and Length scale: {length_scale}") 58 | symbols = list(self.hps.data.punc) + list(self.hps.data.chars) 59 | cleaner = self.hps.data.text_cleaners 60 | if getattr(self.hps.data, "add_blank", False): 61 | text_norm = text_to_sequence(text, symbols, cleaner) 62 | text_norm = commons.intersperse(text_norm, len(symbols)) 63 | else: # If not using "add_blank" option during training, adding spaces at the beginning and the end of utterance improves quality 64 | text = " " + text.strip() + " " 65 | text_norm = text_to_sequence(text, symbols, cleaner) 66 | 67 | sequence = np.array(text_norm)[None, :] 68 | 69 | del symbols 70 | del cleaner 71 | del text 72 | del text_norm 73 | 74 | if self.device == "cuda": 75 | x_tst = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long() 76 | x_tst_lengths = torch.tensor([x_tst.shape[1]]).cuda() 77 | else: 78 | x_tst = torch.autograd.Variable(torch.from_numpy(sequence)).long() 79 | x_tst_lengths = torch.tensor([x_tst.shape[1]]) 80 | 81 | with torch.no_grad(): 82 | (y_gen_tst, *_), *_, (attn_gen, *_) = self.glow_tts_model( 83 | x_tst, 84 | x_tst_lengths, 85 | gen=True, 86 | noise_scale=noise_scale, 87 | length_scale=length_scale, 88 | ) 89 | del x_tst 90 | del x_tst_lengths 91 | torch.cuda.empty_cache() 92 | return y_gen_tst.cpu().detach().numpy() 93 | 94 | 95 | class MelToWav: 96 | def __init__(self, hifi_model_dir, device="cuda"): 97 | self.hifi_model_dir = hifi_model_dir 98 | check_directory(self.hifi_model_dir) 99 | self.device = device 100 | self.h, self.hifi_gan_generator = self.load_hifi_gan() 101 | 102 | def load_hifi_gan(self): 103 | checkpoint_path = utils.latest_checkpoint_path(self.hifi_model_dir, regex="g_*") 104 | config_file = os.path.join(self.hifi_model_dir, "config.json") 105 | data = open(config_file).read() 106 | json_config = json.loads(data) 107 | h = AttrDict(json_config) 108 | torch.manual_seed(h.seed) 109 | 110 | generator = Generator(h).to(self.device) 111 | 112 | assert os.path.isfile(checkpoint_path) 113 | print("Loading '{}'".format(checkpoint_path)) 114 | state_dict_g = torch.load(checkpoint_path, map_location=self.device) 115 | print("Complete.") 116 | 117 | generator.load_state_dict(state_dict_g["generator"]) 118 | 119 | generator.eval() 120 | generator.remove_weight_norm() 121 | 122 | return h, generator 123 | 124 | def generate_wav(self, mel): 125 | mel = torch.FloatTensor(mel).to(self.device) 126 | 127 | y_g_hat = self.hifi_gan_generator(mel) # passing through vocoder 128 | audio = y_g_hat.squeeze() 129 | audio = audio * 32768.0 130 | audio = audio.cpu().detach().numpy().astype("int16") 131 | 132 | del y_g_hat 133 | del mel 134 | torch.cuda.empty_cache() 135 | return audio, self.h.sampling_rate 136 | 137 | def restricted_float(x): 138 | try: 139 | x = float(x) 140 | except ValueError: 141 | raise argparse.ArgumentTypeError("%r not a floating-point literal" % (x,)) 142 | 143 
| if x < 0.0 or x > 1.0: 144 | raise argparse.ArgumentTypeError("%r not in range [0.0, 1.0]"%(x,)) 145 | return x 146 | 147 | 148 | if __name__ == "__main__": 149 | parser = ArgumentParser() 150 | parser.add_argument("-a", "--acoustic", required=True, type=str) 151 | parser.add_argument("-v", "--vocoder", required=True, type=str) 152 | parser.add_argument("-d", "--device", type=str, default="cpu") 153 | parser.add_argument("-t", "--text", type=str, required=True) 154 | parser.add_argument("-w", "--wav", type=str, required=True) 155 | parser.add_argument("-n", "--noise-scale", default=0.667, type=restricted_float ) 156 | parser.add_argument("-l", "--length-scale", default=1.0, type=float) 157 | 158 | args = parser.parse_args() 159 | 160 | text_to_mel = TextToMel(glow_model_dir=args.acoustic, device=args.device) 161 | mel_to_wav = MelToWav(hifi_model_dir=args.vocoder, device=args.device) 162 | 163 | mel = text_to_mel.generate_mel(args.text, args.noise_scale, args.length_scale) 164 | audio, sr = mel_to_wav.generate_wav(mel) 165 | 166 | write(filename=args.wav, rate=sr, data=audio) 167 | 168 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # vakyansh-tts 2 | Text to Speech for Indic languages 3 | 4 | ## 1. Installation and Setup for training 5 | 6 | Clone the repo 7 | Note : for multispeaker glow-tts training use branch [multispeaker](https://github.com/Open-Speech-EkStep/vakyansh-tts/tree/multispeaker) 8 | ``` 9 | git clone https://github.com/Open-Speech-EkStep/vakyansh-tts 10 | ``` 11 | Build a conda virtual environment 12 | ``` 13 | cd ./vakyansh-tts 14 | conda create --name <env_name> python=3.7 15 | conda activate <env_name> 16 | pip install -r requirements.txt 17 | ``` 18 | Install [apex](https://github.com/NVIDIA/apex) (commit 37cdaf4) for mixed-precision training 19 | 20 | Note : used only for glow-tts 21 | ``` 22 | cd .. 23 | git clone https://github.com/NVIDIA/apex 24 | cd apex 25 | git checkout 37cdaf4 26 | pip install -v --disable-pip-version-check --no-cache-dir ./ 27 | cd ../vakyansh-tts 28 | ``` 29 | Build the Monotonic Alignment Search code (Cython) 30 | 31 | Note : used only for glow-tts 32 | ``` 33 | bash install.sh 34 | ``` 35 | 36 | ## 2. Data Resampling 37 | 38 | The data should be organized as a folder containing all the .wav files and a text file listing each filename with its sentence. 39 | 40 | Directory structure: 41 | 42 | language_folder_name 43 | ``` 44 | language_folder_name 45 | |-- ./wav/*.wav 46 | |-- ./text_file_name.txt 47 | ``` 48 | The format of text_file_name.txt (the text file is only needed for glow-tts training): 49 | 50 | ``` 51 | ( audio1.wav "Sentence1." ) 52 | ( audio2.wav "Sentence2." ) 53 | ``` 54 | 55 | To resample the .wav files to a 22050 Hz sample rate, change the following parameters in the vakyansh-tts/scripts/data/resample.sh file: 56 | 57 | ``` 58 | input_wav_path : absolute path to wav file folder in vakyansh_tts/data/ 59 | output_wav_path : absolute path to vakyansh_tts/data/resampled_wav_folder_name 60 | output_sample_rate : 22050 (or any other desired sample rate) 61 | ``` 62 | 63 | To run: 64 | ```bash 65 | cd scripts/data/ 66 | bash resample.sh 67 | ``` 68 | 69 | 70 | ## 3. 
Spectrogram Training (glow-tts) 71 | 72 | ### 3.1 Data Preparation 73 | 74 | 75 | To prepare the data, edit the vakyansh-tts/scripts/glow/prepare_data.sh file and change the following parameters: 76 | ``` 77 | input_text_path : absolute path to vakyansh_tts/data/text_file_name.txt 78 | input_wav_path : absolute path to vakyansh_tts/data/resampled_wav_folder_name 79 | gender : female or male voice 80 | ``` 81 | To run: 82 | ```bash 83 | cd scripts/glow/ 84 | bash prepare_data.sh 85 | ``` 86 | ### 3.2 Training glow-tts 87 | 88 | To start the spectrogram training, edit the vakyansh-tts/scripts/glow/train_glow.sh file and change the following parameter: 89 | ``` 90 | gender : female or male voice 91 | ``` 92 | Make sure that the gender is the same as in the prepare_data.sh file 93 | 94 | To start the training, run: 95 | ```bash 96 | cd scripts/glow/ 97 | bash train_glow.sh 98 | ``` 99 | ## 4. Vocoder Training (hifi-gan) 100 | 101 | ### 4.1 Data Preparation 102 | 103 | To prepare the data, edit the vakyansh-tts/scripts/hifi/prepare_data.sh file and change the following parameters: 104 | ``` 105 | input_wav_path : absolute path to vakyansh_tts/data/resampled_wav_folder_name 106 | gender : female or male voice 107 | ``` 108 | To run: 109 | ```bash 110 | cd scripts/hifi/ 111 | bash prepare_data.sh 112 | ``` 113 | ### 4.2 Training hifi-gan 114 | 115 | To start the vocoder training, edit the vakyansh-tts/scripts/hifi/train_hifi.sh file and change the following parameter: 116 | ``` 117 | gender : female or male voice 118 | ``` 119 | Make sure that the gender is the same as in the prepare_data.sh file 120 | 121 | To start the training, run: 122 | ```bash 123 | cd scripts/hifi/ 124 | bash train_hifi.sh 125 | ``` 126 | 127 | ## 5. Inference 128 | 129 | ### 5.1 Using Gradio 130 | 131 | To use the Gradio link, edit the following parameters in the vakyansh-tts/scripts/inference/gradio.sh file: 132 | ``` 133 | gender : female or male voice 134 | device : cpu or cuda 135 | lang : language code 136 | ``` 137 | 138 | To run: 139 | ```bash 140 | cd scripts/inference/ 141 | bash gradio.sh 142 | ``` 143 | ### 5.2 Using FastAPI 144 | To use the FastAPI server, edit the parameters in the vakyansh-tts/scripts/inference/api.sh file as in section 5.1 145 | 146 | To run: 147 | ```bash 148 | cd scripts/inference/ 149 | bash api.sh 150 | ``` 151 | 152 | ### 5.3 Direct Inference using text 153 | To infer, edit the parameters in the vakyansh-tts/scripts/inference/infer.sh file as in section 5.1 and set the input text in the text variable 154 | 155 | To run: 156 | ```bash 157 | cd scripts/inference/ 158 | bash infer.sh 159 | ``` 160 | 161 | To configure other parameters, there is a version that runs advanced inference as well. Additional parameters: 162 | ``` 163 | noise_scale : can vary from 0 to 1; controls the sampling noise factor 164 | length_scale : can vary from 0 to 2; changes the speed of the generated audio 165 | transliteration : whether to switch transliteration on/off. 1: ON, 0: OFF 166 | number_conversion : whether to switch number-to-words conversion on/off. 1: ON, 0: OFF 167 | split_sentences : whether to switch sentence splitting on/off. 1: ON, 0: OFF 168 | ``` 169 | To run: 170 | ```bash 171 | cd scripts/inference/ 172 | bash advanced_infer.sh 173 | ``` 174 | 175 | ### 5.4 Installation of the tts_infer package 176 | 177 | In the tts_infer package, we currently have two components: 178 | 179 | 1. 
Transliteration (AI4bharat's open sourced models) (Languages supported: {'hi', 'gu', 'mr', 'bn', 'te', 'ta', 'kn', 'pa', 'gom', 'mai', 'ml', 'sd', 'si', 'ur'} ) 180 | 181 | 2. Num to Word (Languages supported: {'en', 'hi', 'gu', 'mr', 'bn', 'te', 'ta', 'kn', 'or', 'pa'} ) 182 | ``` 183 | git clone https://github.com/Open-Speech-EkStep/vakyansh-tts 184 | cd vakyansh-tts 185 | bash install.sh 186 | python setup.py bdist_wheel 187 | pip install -e . 188 | cd tts_infer 189 | wget https://storage.googleapis.com/vakyansh-open-models/translit_models.zip && unzip -q translit_models.zip 190 | ``` 191 | 192 | Usage: Refer to example file in tts_infer/ 193 | ``` 194 | from tts_infer.tts import TextToMel, MelToWav 195 | from tts_infer.transliterate import XlitEngine 196 | from tts_infer.num_to_word_on_sent import normalize_nums 197 | 198 | import re 199 | from scipy.io.wavfile import write 200 | 201 | text_to_mel = TextToMel(glow_model_dir='/path/to/glow-tts/checkpoint/dir', device='cuda') 202 | mel_to_wav = MelToWav(hifi_model_dir='/path/to/hifi/checkpoint/dir', device='cuda') 203 | 204 | def translit(text, lang): 205 | reg = re.compile(r'[a-zA-Z]') 206 | engine = XlitEngine(lang) 207 | words = [engine.translit_word(word, topk=1)[lang][0] if reg.match(word) else word for word in text.split()] 208 | updated_sent = ' '.join(words) 209 | return updated_sent 210 | 211 | def run_tts(text, lang): 212 | text = text.replace('।', '.') # only for hindi models 213 | text_num_to_word = normalize_nums(text, lang) # converting numbers to words in lang 214 | text_num_to_word_and_transliterated = translit(text_num_to_word, lang) # transliterating english words to lang 215 | 216 | mel = text_to_mel.generate_mel(text_num_to_word_and_transliterated) 217 | audio, sr = mel_to_wav.generate_wav(mel) 218 | write(filename='temp.wav', rate=sr, data=audio) # for saving wav file, if needed 219 | return (sr, audio) 220 | ``` 221 | -------------------------------------------------------------------------------- /src/hifi_gan/meldataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | import torch.utils.data 6 | import numpy as np 7 | from librosa.util import normalize 8 | from scipy.io.wavfile import read 9 | from librosa.filters import mel as librosa_mel_fn 10 | 11 | MAX_WAV_VALUE = 32768.0 12 | 13 | 14 | def load_wav(full_path): 15 | sampling_rate, data = read(full_path) 16 | return data, sampling_rate 17 | 18 | 19 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 20 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 21 | 22 | 23 | def dynamic_range_decompression(x, C=1): 24 | return np.exp(x) / C 25 | 26 | 27 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 28 | return torch.log(torch.clamp(x, min=clip_val) * C) 29 | 30 | 31 | def dynamic_range_decompression_torch(x, C=1): 32 | return torch.exp(x) / C 33 | 34 | 35 | def spectral_normalize_torch(magnitudes): 36 | output = dynamic_range_compression_torch(magnitudes) 37 | return output 38 | 39 | 40 | def spectral_de_normalize_torch(magnitudes): 41 | output = dynamic_range_decompression_torch(magnitudes) 42 | return output 43 | 44 | 45 | mel_basis = {} 46 | hann_window = {} 47 | 48 | 49 | def mel_spectrogram( 50 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False 51 | ): 52 | if torch.min(y) < -1.0: 53 | print("min value is ", torch.min(y)) 54 | if torch.max(y) > 1.0: 55 | print("max value is ", torch.max(y)) 56 | 57 | 
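# The mel filterbank and Hann window are intended to be computed once and cached in the
# module-level dicts below, keyed by "<fmax>_<device>" and "<device>" respectively. Note that
# the membership test compares the raw fmax value against those string keys, so in practice the
# branch is taken on every call and the filterbank is rebuilt each time; this matches the
# upstream HiFi-GAN meldataset.py this file appears to be derived from.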
global mel_basis, hann_window 58 | if fmax not in mel_basis: 59 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 60 | mel_basis[str(fmax) + "_" + str(y.device)] = ( 61 | torch.from_numpy(mel).float().to(y.device) 62 | ) 63 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 64 | 65 | y = torch.nn.functional.pad( 66 | y.unsqueeze(1), 67 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 68 | mode="reflect", 69 | ) 70 | y = y.squeeze(1) 71 | 72 | spec = torch.stft( 73 | y, 74 | n_fft, 75 | hop_length=hop_size, 76 | win_length=win_size, 77 | window=hann_window[str(y.device)], 78 | center=center, 79 | pad_mode="reflect", 80 | normalized=False, 81 | onesided=True, 82 | ) 83 | 84 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 85 | 86 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 87 | spec = spectral_normalize_torch(spec) 88 | 89 | return spec 90 | 91 | 92 | def get_dataset_filelist(a): 93 | with open(a.input_training_file, "r", encoding="utf-8") as fi: 94 | training_files = [x for x in fi.read().split("\n") if len(x) > 0] 95 | 96 | with open(a.input_validation_file, "r", encoding="utf-8") as fi: 97 | validation_files = [x for x in fi.read().split("\n") if len(x) > 0] 98 | return training_files, validation_files 99 | 100 | 101 | class MelDataset(torch.utils.data.Dataset): 102 | def __init__( 103 | self, 104 | training_files, 105 | segment_size, 106 | n_fft, 107 | num_mels, 108 | hop_size, 109 | win_size, 110 | sampling_rate, 111 | fmin, 112 | fmax, 113 | split=True, 114 | shuffle=True, 115 | n_cache_reuse=1, 116 | device=None, 117 | fmax_loss=None, 118 | fine_tuning=False, 119 | base_mels_path=None, 120 | ): 121 | self.audio_files = training_files 122 | random.seed(1234) 123 | if shuffle: 124 | random.shuffle(self.audio_files) 125 | self.segment_size = segment_size 126 | self.sampling_rate = sampling_rate 127 | self.split = split 128 | self.n_fft = n_fft 129 | self.num_mels = num_mels 130 | self.hop_size = hop_size 131 | self.win_size = win_size 132 | self.fmin = fmin 133 | self.fmax = fmax 134 | self.fmax_loss = fmax_loss 135 | self.cached_wav = None 136 | self.n_cache_reuse = n_cache_reuse 137 | self._cache_ref_count = 0 138 | self.device = device 139 | self.fine_tuning = fine_tuning 140 | self.base_mels_path = base_mels_path 141 | 142 | def __getitem__(self, index): 143 | filename = self.audio_files[index] 144 | if self._cache_ref_count == 0: 145 | audio, sampling_rate = load_wav(filename) 146 | audio = audio / MAX_WAV_VALUE 147 | if not self.fine_tuning: 148 | audio = normalize(audio) * 0.95 149 | self.cached_wav = audio 150 | if sampling_rate != self.sampling_rate: 151 | raise ValueError( 152 | "{} SR doesn't match target {} SR".format( 153 | sampling_rate, self.sampling_rate 154 | ) 155 | ) 156 | self._cache_ref_count = self.n_cache_reuse 157 | else: 158 | audio = self.cached_wav 159 | self._cache_ref_count -= 1 160 | 161 | audio = torch.FloatTensor(audio) 162 | audio = audio.unsqueeze(0) 163 | 164 | if not self.fine_tuning: 165 | if self.split: 166 | if audio.size(1) >= self.segment_size: 167 | max_audio_start = audio.size(1) - self.segment_size 168 | audio_start = random.randint(0, max_audio_start) 169 | audio = audio[:, audio_start : audio_start + self.segment_size] 170 | else: 171 | audio = torch.nn.functional.pad( 172 | audio, (0, self.segment_size - audio.size(1)), "constant" 173 | ) 174 | 175 | mel = mel_spectrogram( 176 | audio, 177 | self.n_fft, 178 | self.num_mels, 179 | self.sampling_rate, 180 | 
self.hop_size, 181 | self.win_size, 182 | self.fmin, 183 | self.fmax, 184 | center=False, 185 | ) 186 | else: 187 | mel = np.load( 188 | os.path.join( 189 | self.base_mels_path, 190 | os.path.splitext(os.path.split(filename)[-1])[0] + ".npy", 191 | ) 192 | ) 193 | mel = torch.from_numpy(mel) 194 | 195 | if len(mel.shape) < 3: 196 | mel = mel.unsqueeze(0) 197 | 198 | if self.split: 199 | frames_per_seg = math.ceil(self.segment_size / self.hop_size) 200 | 201 | if audio.size(1) >= self.segment_size: 202 | mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) 203 | mel = mel[:, :, mel_start : mel_start + frames_per_seg] 204 | audio = audio[ 205 | :, 206 | mel_start 207 | * self.hop_size : (mel_start + frames_per_seg) 208 | * self.hop_size, 209 | ] 210 | else: 211 | mel = torch.nn.functional.pad( 212 | mel, (0, frames_per_seg - mel.size(2)), "constant" 213 | ) 214 | audio = torch.nn.functional.pad( 215 | audio, (0, self.segment_size - audio.size(1)), "constant" 216 | ) 217 | 218 | mel_loss = mel_spectrogram( 219 | audio, 220 | self.n_fft, 221 | self.num_mels, 222 | self.sampling_rate, 223 | self.hop_size, 224 | self.win_size, 225 | self.fmin, 226 | self.fmax_loss, 227 | center=False, 228 | ) 229 | 230 | return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) 231 | 232 | def __len__(self): 233 | return len(self.audio_files) 234 | -------------------------------------------------------------------------------- /src/glow_tts/stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) 2017, Prem Seetharaman 5 | All rights reserved. 6 | 7 | * Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, this 14 | list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from this 19 | software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | """ 32 | 33 | import torch 34 | import numpy as np 35 | import torch.nn.functional as F 36 | from torch.autograd import Variable 37 | from scipy.signal import get_window 38 | from librosa.util import pad_center, tiny 39 | from librosa import stft, istft 40 | from audio_processing import window_sumsquare 41 | 42 | 43 | class STFT(torch.nn.Module): 44 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 45 | 46 | def __init__( 47 | self, filter_length=800, hop_length=200, win_length=800, window="hann" 48 | ): 49 | super(STFT, self).__init__() 50 | self.filter_length = filter_length 51 | self.hop_length = hop_length 52 | self.win_length = win_length 53 | self.window = window 54 | self.forward_transform = None 55 | scale = self.filter_length / self.hop_length 56 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 57 | 58 | cutoff = int((self.filter_length / 2 + 1)) 59 | fourier_basis = np.vstack( 60 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] 61 | ) 62 | 63 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 64 | inverse_basis = torch.FloatTensor( 65 | np.linalg.pinv(scale * fourier_basis).T[:, None, :] 66 | ) 67 | 68 | if window is not None: 69 | assert filter_length >= win_length 70 | # get window and zero center pad it to filter_length 71 | fft_window = get_window(window, win_length, fftbins=True) 72 | fft_window = pad_center(fft_window, filter_length) 73 | fft_window = torch.from_numpy(fft_window).float() 74 | 75 | # window the bases 76 | forward_basis *= fft_window 77 | inverse_basis *= fft_window 78 | 79 | self.register_buffer("forward_basis", forward_basis.float()) 80 | self.register_buffer("inverse_basis", inverse_basis.float()) 81 | 82 | def transform(self, input_data): 83 | num_batches = input_data.size(0) 84 | num_samples = input_data.size(1) 85 | 86 | self.num_samples = num_samples 87 | 88 | if input_data.device.type == "cuda": 89 | # similar to librosa, reflect-pad the input 90 | input_data = input_data.view(num_batches, 1, num_samples) 91 | input_data = F.pad( 92 | input_data.unsqueeze(1), 93 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 94 | mode="reflect", 95 | ) 96 | input_data = input_data.squeeze(1) 97 | 98 | forward_transform = F.conv1d( 99 | input_data, self.forward_basis, stride=self.hop_length, padding=0 100 | ) 101 | 102 | cutoff = int((self.filter_length / 2) + 1) 103 | real_part = forward_transform[:, :cutoff, :] 104 | imag_part = forward_transform[:, cutoff:, :] 105 | else: 106 | x = input_data.detach().numpy() 107 | real_part = [] 108 | imag_part = [] 109 | for y in x: 110 | y_ = stft( 111 | y, self.filter_length, self.hop_length, self.win_length, self.window 112 | ) 113 | real_part.append(y_.real[None, :, :]) 114 | imag_part.append(y_.imag[None, :, :]) 115 | real_part = np.concatenate(real_part, 0) 116 | imag_part = np.concatenate(imag_part, 0) 117 | 118 | real_part = torch.from_numpy(real_part).to(input_data.dtype) 119 | imag_part = torch.from_numpy(imag_part).to(input_data.dtype) 120 | 121 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) 122 | phase = torch.atan2(imag_part.data, real_part.data) 123 | 124 | return magnitude, phase 125 | 126 | def inverse(self, magnitude, phase): 127 | recombine_magnitude_phase = torch.cat( 128 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 129 | ) 130 | 131 | if magnitude.device.type == "cuda": 132 | inverse_transform = F.conv_transpose1d( 133 | recombine_magnitude_phase, 134 | self.inverse_basis, 135 
| stride=self.hop_length, 136 | padding=0, 137 | ) 138 | 139 | if self.window is not None: 140 | window_sum = window_sumsquare( 141 | self.window, 142 | magnitude.size(-1), 143 | hop_length=self.hop_length, 144 | win_length=self.win_length, 145 | n_fft=self.filter_length, 146 | dtype=np.float32, 147 | ) 148 | # remove modulation effects 149 | approx_nonzero_indices = torch.from_numpy( 150 | np.where(window_sum > tiny(window_sum))[0] 151 | ) 152 | window_sum = torch.from_numpy(window_sum).to(inverse_transform.device) 153 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ 154 | approx_nonzero_indices 155 | ] 156 | 157 | # scale by hop ratio 158 | inverse_transform *= float(self.filter_length) / self.hop_length 159 | 160 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] 161 | inverse_transform = inverse_transform[ 162 | :, :, : -int(self.filter_length / 2) : 163 | ] 164 | inverse_transform = inverse_transform.squeeze(1) 165 | else: 166 | x_org = recombine_magnitude_phase.detach().numpy() 167 | n_b, n_f, n_t = x_org.shape 168 | x = np.empty([n_b, n_f // 2, n_t], dtype=np.complex64) 169 | x.real = x_org[:, : n_f // 2] 170 | x.imag = x_org[:, n_f // 2 :] 171 | inverse_transform = [] 172 | for y in x: 173 | y_ = istft(y, self.hop_length, self.win_length, self.window) 174 | inverse_transform.append(y_[None, :]) 175 | inverse_transform = np.concatenate(inverse_transform, 0) 176 | inverse_transform = torch.from_numpy(inverse_transform).to( 177 | recombine_magnitude_phase.dtype 178 | ) 179 | 180 | return inverse_transform 181 | 182 | def forward(self, input_data): 183 | self.magnitude, self.phase = self.transform(input_data) 184 | reconstruction = self.inverse(self.magnitude, self.phase) 185 | return reconstruction 186 | -------------------------------------------------------------------------------- /src/glow_tts/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from librosa.filters import mel as librosa_mel_fn 8 | from audio_processing import dynamic_range_compression 9 | from audio_processing import dynamic_range_decompression 10 | from stft import STFT 11 | 12 | 13 | def intersperse(lst, item): 14 | result = [item] * (len(lst) * 2 + 1) 15 | result[1::2] = lst 16 | return result 17 | 18 | 19 | def mle_loss(z, m, logs, logdet, mask): 20 | l = torch.sum(logs) + 0.5 * torch.sum( 21 | torch.exp(-2 * logs) * ((z - m) ** 2) 22 | ) # neg normal likelihood w/o the constant term 23 | l = l - torch.sum(logdet) # log jacobian determinant 24 | l = l / torch.sum( 25 | torch.ones_like(z) * mask 26 | ) # averaging across batch, channel and time axes 27 | l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term 28 | return l 29 | 30 | 31 | def duration_loss(logw, logw_, lengths): 32 | l = torch.sum((logw - logw_) ** 2) / torch.sum(lengths) 33 | return l 34 | 35 | 36 | @torch.jit.script 37 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 38 | n_channels_int = n_channels[0] 39 | in_act = input_a + input_b 40 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 41 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 42 | acts = t_act * s_act 43 | return acts 44 | 45 | 46 | def convert_pad_shape(pad_shape): 47 | l = pad_shape[::-1] 48 | pad_shape = [item for sublist in l for item in sublist] 49 | return pad_shape 50 | 51 | 52 | def shift_1d(x): 53 | x = F.pad(x, 
convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 54 | return x 55 | 56 | 57 | def sequence_mask(length, max_length=None): 58 | if max_length is None: 59 | max_length = length.max() 60 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 61 | return x.unsqueeze(0) < length.unsqueeze(1) 62 | 63 | 64 | def maximum_path(value, mask, max_neg_val=-np.inf): 65 | """Numpy-friendly version. It's about 4 times faster than torch version. 66 | value: [b, t_x, t_y] 67 | mask: [b, t_x, t_y] 68 | """ 69 | value = value * mask 70 | 71 | device = value.device 72 | dtype = value.dtype 73 | value = value.cpu().detach().numpy() 74 | mask = mask.cpu().detach().numpy().astype(np.bool) 75 | 76 | b, t_x, t_y = value.shape 77 | direction = np.zeros(value.shape, dtype=np.int64) 78 | v = np.zeros((b, t_x), dtype=np.float32) 79 | x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1) 80 | for j in range(t_y): 81 | v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[ 82 | :, :-1 83 | ] 84 | v1 = v 85 | max_mask = v1 >= v0 86 | v_max = np.where(max_mask, v1, v0) 87 | direction[:, :, j] = max_mask 88 | 89 | index_mask = x_range <= j 90 | v = np.where(index_mask, v_max + value[:, :, j], max_neg_val) 91 | direction = np.where(mask, direction, 1) 92 | 93 | path = np.zeros(value.shape, dtype=np.float32) 94 | index = mask[:, :, 0].sum(1).astype(np.int64) - 1 95 | index_range = np.arange(b) 96 | for j in reversed(range(t_y)): 97 | path[index_range, index, j] = 1 98 | index = index + direction[index_range, index, j] - 1 99 | path = path * mask.astype(np.float32) 100 | path = torch.from_numpy(path).to(device=device, dtype=dtype) 101 | return path 102 | 103 | 104 | def generate_path(duration, mask): 105 | """ 106 | duration: [b, t_x] 107 | mask: [b, t_x, t_y] 108 | """ 109 | device = duration.device 110 | 111 | b, t_x, t_y = mask.shape 112 | cum_duration = torch.cumsum(duration, 1) 113 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) 114 | 115 | cum_duration_flat = cum_duration.view(b * t_x) 116 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 117 | path = path.view(b, t_x, t_y) 118 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 119 | path = path * mask 120 | return path 121 | 122 | 123 | class Adam: 124 | def __init__( 125 | self, 126 | params, 127 | scheduler, 128 | dim_model, 129 | warmup_steps=4000, 130 | lr=1e0, 131 | betas=(0.9, 0.98), 132 | eps=1e-9, 133 | ): 134 | self.params = params 135 | self.scheduler = scheduler 136 | self.dim_model = dim_model 137 | self.warmup_steps = warmup_steps 138 | self.lr = lr 139 | self.betas = betas 140 | self.eps = eps 141 | 142 | self.step_num = 1 143 | self.cur_lr = lr * self._get_lr_scale() 144 | 145 | self._optim = torch.optim.Adam(params, lr=self.cur_lr, betas=betas, eps=eps) 146 | 147 | def _get_lr_scale(self): 148 | if self.scheduler == "noam": 149 | return np.power(self.dim_model, -0.5) * np.min( 150 | [ 151 | np.power(self.step_num, -0.5), 152 | self.step_num * np.power(self.warmup_steps, -1.5), 153 | ] 154 | ) 155 | else: 156 | return 1 157 | 158 | def _update_learning_rate(self): 159 | self.step_num += 1 160 | if self.scheduler == "noam": 161 | self.cur_lr = self.lr * self._get_lr_scale() 162 | for param_group in self._optim.param_groups: 163 | param_group["lr"] = self.cur_lr 164 | 165 | def get_lr(self): 166 | return self.cur_lr 167 | 168 | def step(self): 169 | self._optim.step() 170 | self._update_learning_rate() 171 | 172 | def zero_grad(self): 173 | 
self._optim.zero_grad() 174 | 175 | def load_state_dict(self, d): 176 | self._optim.load_state_dict(d) 177 | 178 | def state_dict(self): 179 | return self._optim.state_dict() 180 | 181 | 182 | class TacotronSTFT(nn.Module): 183 | def __init__( 184 | self, 185 | filter_length=1024, 186 | hop_length=256, 187 | win_length=1024, 188 | n_mel_channels=80, 189 | sampling_rate=22050, 190 | mel_fmin=0.0, 191 | mel_fmax=8000.0, 192 | ): 193 | super(TacotronSTFT, self).__init__() 194 | self.n_mel_channels = n_mel_channels 195 | self.sampling_rate = sampling_rate 196 | self.stft_fn = STFT(filter_length, hop_length, win_length) 197 | mel_basis = librosa_mel_fn( 198 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax 199 | ) 200 | mel_basis = torch.from_numpy(mel_basis).float() 201 | self.register_buffer("mel_basis", mel_basis) 202 | 203 | def spectral_normalize(self, magnitudes): 204 | output = dynamic_range_compression(magnitudes) 205 | return output 206 | 207 | def spectral_de_normalize(self, magnitudes): 208 | output = dynamic_range_decompression(magnitudes) 209 | return output 210 | 211 | def mel_spectrogram(self, y): 212 | """Computes mel-spectrograms from a batch of waves 213 | PARAMS 214 | ------ 215 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 216 | 217 | RETURNS 218 | ------- 219 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 220 | """ 221 | assert torch.min(y.data) >= -1 222 | assert torch.max(y.data) <= 1 223 | 224 | magnitudes, phases = self.stft_fn.transform(y) 225 | magnitudes = magnitudes.data 226 | mel_output = torch.matmul(self.mel_basis, magnitudes) 227 | mel_output = self.spectral_normalize(mel_output) 228 | return mel_output 229 | 230 | 231 | def clip_grad_value_(parameters, clip_value, norm_type=2): 232 | if isinstance(parameters, torch.Tensor): 233 | parameters = [parameters] 234 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 235 | norm_type = float(norm_type) 236 | clip_value = float(clip_value) 237 | 238 | total_norm = 0 239 | for p in parameters: 240 | param_norm = p.grad.data.norm(norm_type) 241 | total_norm += param_norm.item() ** norm_type 242 | 243 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 244 | total_norm = total_norm ** (1.0 / norm_type) 245 | return total_norm 246 | 247 | 248 | def squeeze(x, x_mask=None, n_sqz=2): 249 | b, c, t = x.size() 250 | 251 | t = (t // n_sqz) * n_sqz 252 | x = x[:, :, :t] 253 | x_sqz = x.view(b, c, t // n_sqz, n_sqz) 254 | x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz) 255 | 256 | if x_mask is not None: 257 | x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz] 258 | else: 259 | x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype) 260 | return x_sqz * x_mask, x_mask 261 | 262 | 263 | def unsqueeze(x, x_mask=None, n_sqz=2): 264 | b, c, t = x.size() 265 | 266 | x_unsqz = x.view(b, n_sqz, c // n_sqz, t) 267 | x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz) 268 | 269 | if x_mask is not None: 270 | x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz) 271 | else: 272 | x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype) 273 | return x_unsqz * x_mask, x_mask 274 | -------------------------------------------------------------------------------- /src/glow_tts/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | 
import subprocess 8 | import numpy as np 9 | from scipy.io.wavfile import read 10 | import torch 11 | 12 | MATPLOTLIB_FLAG = False 13 | 14 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 15 | logger = logging 16 | 17 | 18 | def load_checkpoint(checkpoint_path, model, optimizer=None): 19 | assert os.path.isfile(checkpoint_path) 20 | checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") 21 | iteration = 1 22 | if "iteration" in checkpoint_dict.keys(): 23 | iteration = checkpoint_dict["iteration"] 24 | if "learning_rate" in checkpoint_dict.keys(): 25 | learning_rate = checkpoint_dict["learning_rate"] 26 | if optimizer is not None and "optimizer" in checkpoint_dict.keys(): 27 | optimizer.load_state_dict(checkpoint_dict["optimizer"]) 28 | saved_state_dict = checkpoint_dict["model"] 29 | if hasattr(model, "module"): 30 | state_dict = model.module.state_dict() 31 | else: 32 | state_dict = model.state_dict() 33 | new_state_dict = {} 34 | for k, v in state_dict.items(): 35 | try: 36 | new_state_dict[k] = saved_state_dict[k] 37 | except: 38 | logger.info("%s is not in the checkpoint" % k) 39 | new_state_dict[k] = v 40 | if hasattr(model, "module"): 41 | model.module.load_state_dict(new_state_dict) 42 | else: 43 | model.load_state_dict(new_state_dict) 44 | logger.info( 45 | "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration) 46 | ) 47 | return model, optimizer, learning_rate, iteration 48 | 49 | 50 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 51 | logger.info( 52 | "Saving model and optimizer state at iteration {} to {}".format( 53 | iteration, checkpoint_path 54 | ) 55 | ) 56 | if hasattr(model, "module"): 57 | state_dict = model.module.state_dict() 58 | else: 59 | state_dict = model.state_dict() 60 | torch.save( 61 | { 62 | "model": state_dict, 63 | "iteration": iteration, 64 | "optimizer": optimizer.state_dict(), 65 | "learning_rate": learning_rate, 66 | }, 67 | checkpoint_path, 68 | ) 69 | 70 | 71 | def summarize(writer, global_step, scalars={}, histograms={}, images={}): 72 | for k, v in scalars.items(): 73 | writer.add_scalar(k, v, global_step) 74 | for k, v in histograms.items(): 75 | writer.add_histogram(k, v, global_step) 76 | for k, v in images.items(): 77 | writer.add_image(k, v, global_step, dataformats="HWC") 78 | 79 | 80 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 81 | f_list = glob.glob(os.path.join(dir_path, regex)) 82 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 83 | x = f_list[-1] 84 | print(x) 85 | return x 86 | 87 | 88 | def plot_spectrogram_to_numpy(spectrogram): 89 | global MATPLOTLIB_FLAG 90 | if not MATPLOTLIB_FLAG: 91 | import matplotlib 92 | 93 | matplotlib.use("Agg") 94 | MATPLOTLIB_FLAG = True 95 | mpl_logger = logging.getLogger("matplotlib") 96 | mpl_logger.setLevel(logging.WARNING) 97 | import matplotlib.pylab as plt 98 | import numpy as np 99 | 100 | fig, ax = plt.subplots() 101 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 102 | plt.colorbar(im, ax=ax) 103 | plt.xlabel("Frames") 104 | plt.ylabel("Channels") 105 | plt.tight_layout() 106 | 107 | fig.canvas.draw() 108 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 109 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 110 | plt.close() 111 | return data 112 | 113 | 114 | def plot_alignment_to_numpy(alignment, info=None): 115 | global MATPLOTLIB_FLAG 116 | if not MATPLOTLIB_FLAG: 117 | import matplotlib 118 | 119 | 
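        # select the non-interactive Agg backend once so plotting works without a display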
matplotlib.use("Agg") 120 | MATPLOTLIB_FLAG = True 121 | mpl_logger = logging.getLogger("matplotlib") 122 | mpl_logger.setLevel(logging.WARNING) 123 | import matplotlib.pylab as plt 124 | import numpy as np 125 | 126 | fig, ax = plt.subplots(figsize=(6, 4)) 127 | im = ax.imshow(alignment, aspect="auto", origin="lower", interpolation="none") 128 | fig.colorbar(im, ax=ax) 129 | xlabel = "Decoder timestep" 130 | if info is not None: 131 | xlabel += "\n\n" + info 132 | plt.xlabel(xlabel) 133 | plt.ylabel("Encoder timestep") 134 | plt.tight_layout() 135 | 136 | fig.canvas.draw() 137 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 138 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 139 | plt.close() 140 | return data 141 | 142 | 143 | def load_wav_to_torch(full_path): 144 | sampling_rate, data = read(full_path) 145 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 146 | 147 | 148 | def load_filepaths_and_text(filename, split="|"): 149 | with open(filename, encoding="utf-8") as f: 150 | filepaths_and_text = [line.strip().split(split) for line in f] 151 | return filepaths_and_text 152 | 153 | 154 | def get_hparams(init=True): 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument("-c", "--config", type=str, help="JSON file for configuration") 157 | parser.add_argument("-m", "--model", type=str, help="Model name") 158 | # parser.add_argument('-g', '--gan', type=str, 159 | # help='Model name') 160 | parser.add_argument("-l", "--logs", type=str, help="logs name") 161 | # parser.add_argument('-s', '--mels', type=str, 162 | # help='logs name') 163 | 164 | args = parser.parse_args() 165 | # model_dir = os.path.join("./logs", args.model) 166 | model_dir = args.model 167 | if not os.path.exists(model_dir): 168 | os.makedirs(model_dir) 169 | 170 | config_path = args.config 171 | config_save_path = os.path.join(model_dir, "config.json") 172 | 173 | # if not config_path : config_path = config_save_path 174 | 175 | if init: 176 | with open(config_path, "r") as f: 177 | data = f.read() 178 | with open(config_save_path, "w") as f: 179 | f.write(data) 180 | else: 181 | with open(config_save_path, "r") as f: 182 | data = f.read() 183 | config = json.loads(data) 184 | 185 | hparams = HParams(**config) 186 | hparams.model_dir = model_dir 187 | hparams.log_dir = args.logs 188 | # hparams.mels_dir = args.mels 189 | # hparams.gan_dir = args.gan 190 | return hparams 191 | 192 | 193 | def get_hparams_from_dir(model_dir): 194 | config_save_path = os.path.join(model_dir, "config.json") 195 | with open(config_save_path, "r") as f: 196 | data = f.read() 197 | config = json.loads(data) 198 | 199 | hparams = HParams(**config) 200 | hparams.model_dir = model_dir 201 | return hparams 202 | 203 | 204 | def get_hparams_from_file(config_path): 205 | with open(config_path, "r") as f: 206 | data = f.read() 207 | config = json.loads(data) 208 | 209 | hparams = HParams(**config) 210 | return hparams 211 | 212 | 213 | def check_git_hash(model_dir): 214 | source_dir = os.path.dirname(os.path.realpath(__file__)) 215 | if not os.path.exists(os.path.join(source_dir, ".git")): 216 | logger.warn( 217 | "{} is not a git repository, therefore hash value comparison will be ignored.".format( 218 | source_dir 219 | ) 220 | ) 221 | return 222 | 223 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 224 | 225 | path = os.path.join(model_dir, "githash") 226 | if os.path.exists(path): 227 | saved_hash = open(path).read() 228 | if saved_hash != cur_hash: 229 | logger.warn( 
230 | "git hash values are different. {}(saved) != {}(current)".format( 231 | saved_hash[:8], cur_hash[:8] 232 | ) 233 | ) 234 | else: 235 | open(path, "w").write(cur_hash) 236 | 237 | 238 | def get_logger(model_dir, filename="train.log"): 239 | global logger 240 | logger = logging.getLogger(os.path.basename(model_dir)) 241 | logger.setLevel(logging.DEBUG) 242 | 243 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 244 | if not os.path.exists(model_dir): 245 | os.makedirs(model_dir) 246 | h = logging.FileHandler(os.path.join(model_dir, filename)) 247 | h.setLevel(logging.DEBUG) 248 | h.setFormatter(formatter) 249 | logger.addHandler(h) 250 | return logger 251 | 252 | 253 | class HParams: 254 | def __init__(self, **kwargs): 255 | for k, v in kwargs.items(): 256 | if type(v) == dict: 257 | v = HParams(**v) 258 | self[k] = v 259 | 260 | def keys(self): 261 | return self.__dict__.keys() 262 | 263 | def items(self): 264 | return self.__dict__.items() 265 | 266 | def values(self): 267 | return self.__dict__.values() 268 | 269 | def __len__(self): 270 | return len(self.__dict__) 271 | 272 | def __getitem__(self, key): 273 | return getattr(self, key) 274 | 275 | def __setitem__(self, key, value): 276 | return setattr(self, key, value) 277 | 278 | def __contains__(self, key): 279 | return key in self.__dict__ 280 | 281 | def __repr__(self): 282 | return self.__dict__.__repr__() 283 | -------------------------------------------------------------------------------- /src/glow_tts/modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import scipy 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | import commons 10 | 11 | 12 | class LayerNorm(nn.Module): 13 | def __init__(self, channels, eps=1e-4): 14 | super().__init__() 15 | self.channels = channels 16 | self.eps = eps 17 | 18 | self.gamma = nn.Parameter(torch.ones(channels)) 19 | self.beta = nn.Parameter(torch.zeros(channels)) 20 | 21 | def forward(self, x): 22 | n_dims = len(x.shape) 23 | mean = torch.mean(x, 1, keepdim=True) 24 | variance = torch.mean((x - mean) ** 2, 1, keepdim=True) 25 | 26 | x = (x - mean) * torch.rsqrt(variance + self.eps) 27 | 28 | shape = [1, -1] + [1] * (n_dims - 2) 29 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 30 | return x 31 | 32 | 33 | class ConvReluNorm(nn.Module): 34 | def __init__( 35 | self, 36 | in_channels, 37 | hidden_channels, 38 | out_channels, 39 | kernel_size, 40 | n_layers, 41 | p_dropout, 42 | ): 43 | super().__init__() 44 | self.in_channels = in_channels 45 | self.hidden_channels = hidden_channels 46 | self.out_channels = out_channels 47 | self.kernel_size = kernel_size 48 | self.n_layers = n_layers 49 | self.p_dropout = p_dropout 50 | assert n_layers > 1, "Number of layers should be larger than 0." 
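        # first conv maps in_channels -> hidden_channels; the remaining n_layers - 1 convs keep the hidden size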
51 | 52 | self.conv_layers = nn.ModuleList() 53 | self.norm_layers = nn.ModuleList() 54 | self.conv_layers.append( 55 | nn.Conv1d( 56 | in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 57 | ) 58 | ) 59 | self.norm_layers.append(LayerNorm(hidden_channels)) 60 | self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) 61 | for _ in range(n_layers - 1): 62 | self.conv_layers.append( 63 | nn.Conv1d( 64 | hidden_channels, 65 | hidden_channels, 66 | kernel_size, 67 | padding=kernel_size // 2, 68 | ) 69 | ) 70 | self.norm_layers.append(LayerNorm(hidden_channels)) 71 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 72 | self.proj.weight.data.zero_() 73 | self.proj.bias.data.zero_() 74 | 75 | def forward(self, x, x_mask): 76 | x_org = x 77 | for i in range(self.n_layers): 78 | x = self.conv_layers[i](x * x_mask) 79 | x = self.norm_layers[i](x) 80 | x = self.relu_drop(x) 81 | x = x_org + self.proj(x) 82 | return x * x_mask 83 | 84 | 85 | class WN(torch.nn.Module): 86 | def __init__( 87 | self, 88 | in_channels, 89 | hidden_channels, 90 | kernel_size, 91 | dilation_rate, 92 | n_layers, 93 | gin_channels=0, 94 | p_dropout=0, 95 | ): 96 | super(WN, self).__init__() 97 | assert kernel_size % 2 == 1 98 | assert hidden_channels % 2 == 0 99 | self.in_channels = in_channels 100 | self.hidden_channels = hidden_channels 101 | self.kernel_size = (kernel_size,) 102 | self.dilation_rate = dilation_rate 103 | self.n_layers = n_layers 104 | self.gin_channels = gin_channels 105 | self.p_dropout = p_dropout 106 | 107 | self.in_layers = torch.nn.ModuleList() 108 | self.res_skip_layers = torch.nn.ModuleList() 109 | self.drop = nn.Dropout(p_dropout) 110 | 111 | if gin_channels != 0: 112 | cond_layer = torch.nn.Conv1d( 113 | gin_channels, 2 * hidden_channels * n_layers, 1 114 | ) 115 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") 116 | 117 | for i in range(n_layers): 118 | dilation = dilation_rate ** i 119 | padding = int((kernel_size * dilation - dilation) / 2) 120 | in_layer = torch.nn.Conv1d( 121 | hidden_channels, 122 | 2 * hidden_channels, 123 | kernel_size, 124 | dilation=dilation, 125 | padding=padding, 126 | ) 127 | in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") 128 | self.in_layers.append(in_layer) 129 | 130 | # last one is not necessary 131 | if i < n_layers - 1: 132 | res_skip_channels = 2 * hidden_channels 133 | else: 134 | res_skip_channels = hidden_channels 135 | 136 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 137 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") 138 | self.res_skip_layers.append(res_skip_layer) 139 | 140 | def forward(self, x, x_mask=None, g=None, **kwargs): 141 | output = torch.zeros_like(x) 142 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 143 | 144 | if g is not None: 145 | g = self.cond_layer(g) 146 | 147 | for i in range(self.n_layers): 148 | x_in = self.in_layers[i](x) 149 | x_in = self.drop(x_in) 150 | if g is not None: 151 | cond_offset = i * 2 * self.hidden_channels 152 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] 153 | else: 154 | g_l = torch.zeros_like(x_in) 155 | 156 | acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) 157 | 158 | res_skip_acts = self.res_skip_layers[i](acts) 159 | if i < self.n_layers - 1: 160 | x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask 161 | output = output + res_skip_acts[:, self.hidden_channels :, :] 162 | else: 163 | output = output + 
res_skip_acts 164 | return output * x_mask 165 | 166 | def remove_weight_norm(self): 167 | if self.gin_channels != 0: 168 | torch.nn.utils.remove_weight_norm(self.cond_layer) 169 | for l in self.in_layers: 170 | torch.nn.utils.remove_weight_norm(l) 171 | for l in self.res_skip_layers: 172 | torch.nn.utils.remove_weight_norm(l) 173 | 174 | 175 | class ActNorm(nn.Module): 176 | def __init__(self, channels, ddi=False, **kwargs): 177 | super().__init__() 178 | self.channels = channels 179 | self.initialized = not ddi 180 | 181 | self.logs = nn.Parameter(torch.zeros(1, channels, 1)) 182 | self.bias = nn.Parameter(torch.zeros(1, channels, 1)) 183 | 184 | def forward(self, x, x_mask=None, reverse=False, **kwargs): 185 | if x_mask is None: 186 | x_mask = torch.ones(x.size(0), 1, x.size(2)).to( 187 | device=x.device, dtype=x.dtype 188 | ) 189 | x_len = torch.sum(x_mask, [1, 2]) 190 | if not self.initialized: 191 | self.initialize(x, x_mask) 192 | self.initialized = True 193 | 194 | if reverse: 195 | z = (x - self.bias) * torch.exp(-self.logs) * x_mask 196 | logdet = None 197 | else: 198 | z = (self.bias + torch.exp(self.logs) * x) * x_mask 199 | logdet = torch.sum(self.logs) * x_len # [b] 200 | 201 | return z, logdet 202 | 203 | def store_inverse(self): 204 | pass 205 | 206 | def set_ddi(self, ddi): 207 | self.initialized = not ddi 208 | 209 | def initialize(self, x, x_mask): 210 | with torch.no_grad(): 211 | denom = torch.sum(x_mask, [0, 2]) 212 | m = torch.sum(x * x_mask, [0, 2]) / denom 213 | m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom 214 | v = m_sq - (m ** 2) 215 | logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) 216 | 217 | bias_init = ( 218 | (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) 219 | ) 220 | logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype) 221 | 222 | self.bias.data.copy_(bias_init) 223 | self.logs.data.copy_(logs_init) 224 | 225 | 226 | class InvConvNear(nn.Module): 227 | def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs): 228 | super().__init__() 229 | assert n_split % 2 == 0 230 | self.channels = channels 231 | self.n_split = n_split 232 | self.no_jacobian = no_jacobian 233 | 234 | w_init = torch.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0] 235 | if torch.det(w_init) < 0: 236 | w_init[:, 0] = -1 * w_init[:, 0] 237 | self.weight = nn.Parameter(w_init) 238 | 239 | def forward(self, x, x_mask=None, reverse=False, **kwargs): 240 | b, c, t = x.size() 241 | assert c % self.n_split == 0 242 | if x_mask is None: 243 | x_mask = 1 244 | x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t 245 | else: 246 | x_len = torch.sum(x_mask, [1, 2]) 247 | 248 | x = x.view(b, 2, c // self.n_split, self.n_split // 2, t) 249 | x = ( 250 | x.permute(0, 1, 3, 2, 4) 251 | .contiguous() 252 | .view(b, self.n_split, c // self.n_split, t) 253 | ) 254 | 255 | if reverse: 256 | if hasattr(self, "weight_inv"): 257 | weight = self.weight_inv 258 | else: 259 | weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) 260 | logdet = None 261 | else: 262 | weight = self.weight 263 | if self.no_jacobian: 264 | logdet = 0 265 | else: 266 | logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b] 267 | 268 | weight = weight.view(self.n_split, self.n_split, 1, 1) 269 | z = F.conv2d(x, weight) 270 | 271 | z = z.view(b, 2, self.n_split // 2, c // self.n_split, t) 272 | z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask 273 | return z, logdet 274 | 275 | def store_inverse(self): 276 | 
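        # cache the inverted weight so repeated reverse passes skip the matrix inversion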
self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) 277 | -------------------------------------------------------------------------------- /src/glow_tts/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import math 5 | import torch 6 | from torch import nn, optim 7 | from torch.nn import functional as F 8 | from torch.utils.data import DataLoader 9 | from torch.utils.tensorboard import SummaryWriter 10 | import torch.multiprocessing as mp 11 | import torch.distributed as dist 12 | from apex.parallel import DistributedDataParallel as DDP 13 | from apex import amp 14 | 15 | from data_utils import TextMelLoader, TextMelCollate 16 | import models 17 | import commons 18 | import utils 19 | 20 | 21 | global_step = 0 22 | 23 | 24 | def main(): 25 | """Assume Single Node Multi GPUs Training Only""" 26 | assert torch.cuda.is_available(), "CPU training is not allowed." 27 | 28 | n_gpus = torch.cuda.device_count() 29 | os.environ["MASTER_ADDR"] = "localhost" 30 | os.environ["MASTER_PORT"] = "80000" 31 | 32 | hps = utils.get_hparams() 33 | mp.spawn( 34 | train_and_eval, 35 | nprocs=n_gpus, 36 | args=( 37 | n_gpus, 38 | hps, 39 | ), 40 | ) 41 | 42 | 43 | def train_and_eval(rank, n_gpus, hps): 44 | global global_step 45 | if rank == 0: 46 | logger = utils.get_logger(hps.log_dir) 47 | logger.info(hps) 48 | utils.check_git_hash(hps.log_dir) 49 | writer = SummaryWriter(log_dir=hps.log_dir) 50 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.log_dir, "eval")) 51 | 52 | dist.init_process_group( 53 | backend="nccl", init_method="env://", world_size=n_gpus, rank=rank 54 | ) 55 | torch.manual_seed(hps.train.seed) 56 | torch.cuda.set_device(rank) 57 | 58 | train_dataset = TextMelLoader(hps.data.training_files, hps.data) 59 | train_sampler = torch.utils.data.distributed.DistributedSampler( 60 | train_dataset, num_replicas=n_gpus, rank=rank, shuffle=True 61 | ) 62 | collate_fn = TextMelCollate(1) 63 | train_loader = DataLoader( 64 | train_dataset, 65 | num_workers=8, 66 | shuffle=False, 67 | batch_size=hps.train.batch_size, 68 | pin_memory=True, 69 | drop_last=True, 70 | collate_fn=collate_fn, 71 | sampler=train_sampler, 72 | ) 73 | if rank == 0: 74 | val_dataset = TextMelLoader(hps.data.validation_files, hps.data) 75 | val_loader = DataLoader( 76 | val_dataset, 77 | num_workers=8, 78 | shuffle=False, 79 | batch_size=hps.train.batch_size, 80 | pin_memory=True, 81 | drop_last=True, 82 | collate_fn=collate_fn, 83 | ) 84 | symbols = hps.data.punc + hps.data.chars 85 | generator = models.FlowGenerator( 86 | n_vocab=len(symbols) + getattr(hps.data, "add_blank", False), 87 | out_channels=hps.data.n_mel_channels, 88 | **hps.model 89 | ).cuda(rank) 90 | optimizer_g = commons.Adam( 91 | generator.parameters(), 92 | scheduler=hps.train.scheduler, 93 | dim_model=hps.model.hidden_channels, 94 | warmup_steps=hps.train.warmup_steps, 95 | lr=hps.train.learning_rate, 96 | betas=hps.train.betas, 97 | eps=hps.train.eps, 98 | ) 99 | if hps.train.fp16_run: 100 | generator, optimizer_g._optim = amp.initialize( 101 | generator, optimizer_g._optim, opt_level="O1" 102 | ) 103 | generator = DDP(generator) 104 | epoch_str = 1 105 | global_step = 0 106 | try: 107 | _, _, _, epoch_str = utils.load_checkpoint( 108 | utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), 109 | generator, 110 | optimizer_g, 111 | ) 112 | epoch_str += 1 113 | optimizer_g.step_num = (epoch_str - 1) * len(train_loader) 114 | 
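        # realign the optimizer step counter (and hence the noam warm-up schedule) with the resumed epoch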
optimizer_g._update_learning_rate() 115 | global_step = (epoch_str - 1) * len(train_loader) 116 | except: 117 | if hps.train.ddi and os.path.isfile(os.path.join(hps.model_dir, "ddi_G.pth")): 118 | _ = utils.load_checkpoint( 119 | os.path.join(hps.model_dir, "ddi_G.pth"), generator, optimizer_g 120 | ) 121 | 122 | for epoch in range(epoch_str, hps.train.epochs + 1): 123 | if rank == 0: 124 | train( 125 | rank, epoch, hps, generator, optimizer_g, train_loader, logger, writer 126 | ) 127 | evaluate( 128 | rank, 129 | epoch, 130 | hps, 131 | generator, 132 | optimizer_g, 133 | val_loader, 134 | logger, 135 | writer_eval, 136 | ) 137 | if epoch % hps.train.save_epoch == 0: 138 | utils.save_checkpoint( 139 | generator, 140 | optimizer_g, 141 | hps.train.learning_rate, 142 | epoch, 143 | os.path.join(hps.model_dir, "G_{}.pth".format(epoch)), 144 | ) 145 | else: 146 | train(rank, epoch, hps, generator, optimizer_g, train_loader, None, None) 147 | 148 | 149 | def train(rank, epoch, hps, generator, optimizer_g, train_loader, logger, writer): 150 | train_loader.sampler.set_epoch(epoch) 151 | global global_step 152 | 153 | generator.train() 154 | for batch_idx, (x, x_lengths, y, y_lengths) in enumerate(train_loader): 155 | x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda( 156 | rank, non_blocking=True 157 | ) 158 | y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda( 159 | rank, non_blocking=True 160 | ) 161 | 162 | # Train Generator 163 | optimizer_g.zero_grad() 164 | 165 | ( 166 | (z, z_m, z_logs, logdet, z_mask), 167 | (x_m, x_logs, x_mask), 168 | (attn, logw, logw_), 169 | ) = generator(x, x_lengths, y, y_lengths, gen=False) 170 | l_mle = commons.mle_loss(z, z_m, z_logs, logdet, z_mask) 171 | l_length = commons.duration_loss(logw, logw_, x_lengths) 172 | 173 | loss_gs = [l_mle, l_length] 174 | loss_g = sum(loss_gs) 175 | 176 | if hps.train.fp16_run: 177 | with amp.scale_loss(loss_g, optimizer_g._optim) as scaled_loss: 178 | scaled_loss.backward() 179 | grad_norm = commons.clip_grad_value_( 180 | amp.master_params(optimizer_g._optim), 5 181 | ) 182 | else: 183 | loss_g.backward() 184 | grad_norm = commons.clip_grad_value_(generator.parameters(), 5) 185 | optimizer_g.step() 186 | 187 | if rank == 0: 188 | if batch_idx % hps.train.log_interval == 0: 189 | (y_gen, *_), *_ = generator.module(x[:1], x_lengths[:1], gen=True) 190 | logger.info( 191 | "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( 192 | epoch, 193 | batch_idx * len(x), 194 | len(train_loader.dataset), 195 | 100.0 * batch_idx / len(train_loader), 196 | loss_g.item(), 197 | ) 198 | ) 199 | logger.info( 200 | [x.item() for x in loss_gs] + [global_step, optimizer_g.get_lr()] 201 | ) 202 | 203 | scalar_dict = { 204 | "loss/g/total": loss_g, 205 | "learning_rate": optimizer_g.get_lr(), 206 | "grad_norm": grad_norm, 207 | } 208 | scalar_dict.update( 209 | {"loss/g/{}".format(i): v for i, v in enumerate(loss_gs)} 210 | ) 211 | utils.summarize( 212 | writer=writer, 213 | global_step=global_step, 214 | images={ 215 | "y_org": utils.plot_spectrogram_to_numpy( 216 | y[0].data.cpu().numpy() 217 | ), 218 | "y_gen": utils.plot_spectrogram_to_numpy( 219 | y_gen[0].data.cpu().numpy() 220 | ), 221 | "attn": utils.plot_alignment_to_numpy( 222 | attn[0, 0].data.cpu().numpy() 223 | ), 224 | }, 225 | scalars=scalar_dict, 226 | ) 227 | global_step += 1 228 | 229 | if rank == 0: 230 | logger.info("====> Epoch: {}".format(epoch)) 231 | 232 | 233 | def evaluate(rank, epoch, hps, generator, optimizer_g, val_loader, logger, 
writer_eval): 234 | if rank == 0: 235 | global global_step 236 | generator.eval() 237 | losses_tot = [] 238 | with torch.no_grad(): 239 | for batch_idx, (x, x_lengths, y, y_lengths) in enumerate(val_loader): 240 | x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda( 241 | rank, non_blocking=True 242 | ) 243 | y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda( 244 | rank, non_blocking=True 245 | ) 246 | 247 | ( 248 | (z, z_m, z_logs, logdet, z_mask), 249 | (x_m, x_logs, x_mask), 250 | (attn, logw, logw_), 251 | ) = generator(x, x_lengths, y, y_lengths, gen=False) 252 | l_mle = commons.mle_loss(z, z_m, z_logs, logdet, z_mask) 253 | l_length = commons.duration_loss(logw, logw_, x_lengths) 254 | 255 | loss_gs = [l_mle, l_length] 256 | loss_g = sum(loss_gs) 257 | 258 | if batch_idx == 0: 259 | losses_tot = loss_gs 260 | else: 261 | losses_tot = [x + y for (x, y) in zip(losses_tot, loss_gs)] 262 | 263 | if batch_idx % hps.train.log_interval == 0: 264 | logger.info( 265 | "Eval Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( 266 | epoch, 267 | batch_idx * len(x), 268 | len(val_loader.dataset), 269 | 100.0 * batch_idx / len(val_loader), 270 | loss_g.item(), 271 | ) 272 | ) 273 | logger.info([x.item() for x in loss_gs]) 274 | 275 | losses_tot = [x / len(val_loader) for x in losses_tot] 276 | loss_tot = sum(losses_tot) 277 | scalar_dict = {"loss/g/total": loss_tot} 278 | scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_tot)}) 279 | utils.summarize( 280 | writer=writer_eval, global_step=global_step, scalars=scalar_dict 281 | ) 282 | logger.info("====> Epoch: {}".format(epoch)) 283 | 284 | 285 | if __name__ == "__main__": 286 | main() 287 | -------------------------------------------------------------------------------- /src/glow_tts/data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.utils.data 5 | 6 | import commons 7 | from utils import load_wav_to_torch, load_filepaths_and_text 8 | from text import text_to_sequence 9 | 10 | class TextMelLoader(torch.utils.data.Dataset): 11 | """ 12 | 1) loads audio,text pairs 13 | 2) normalizes text and converts them to sequences of one-hot vectors 14 | 3) computes mel-spectrograms from audio files. 
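       (audio must already be at hparams.sampling_rate; get_mel raises ValueError on a mismatch)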
15 | """ 16 | 17 | def __init__(self, audiopaths_and_text, hparams): 18 | self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text) 19 | self.text_cleaners = hparams.text_cleaners 20 | self.max_wav_value = hparams.max_wav_value 21 | self.sampling_rate = hparams.sampling_rate 22 | self.load_mel_from_disk = hparams.load_mel_from_disk 23 | self.add_noise = hparams.add_noise 24 | self.symbols = hparams.punc + hparams.chars 25 | self.add_blank = getattr(hparams, "add_blank", False) # improved version 26 | self.stft = commons.TacotronSTFT( 27 | hparams.filter_length, 28 | hparams.hop_length, 29 | hparams.win_length, 30 | hparams.n_mel_channels, 31 | hparams.sampling_rate, 32 | hparams.mel_fmin, 33 | hparams.mel_fmax, 34 | ) 35 | random.seed(1234) 36 | random.shuffle(self.audiopaths_and_text) 37 | 38 | def get_mel_text_pair(self, audiopath_and_text): 39 | # separate filename and text 40 | audiopath, text = audiopath_and_text[0], audiopath_and_text[1] 41 | text = self.get_text(text) 42 | mel = self.get_mel(audiopath) 43 | return (text, mel) 44 | 45 | def get_mel(self, filename): 46 | if not self.load_mel_from_disk: 47 | audio, sampling_rate = load_wav_to_torch(filename) 48 | if sampling_rate != self.stft.sampling_rate: 49 | raise ValueError( 50 | "{} {} SR doesn't match target {} SR".format( 51 | sampling_rate, self.stft.sampling_rate 52 | ) 53 | ) 54 | if self.add_noise: 55 | audio = audio + torch.rand_like(audio) 56 | audio_norm = audio / self.max_wav_value 57 | audio_norm = audio_norm.unsqueeze(0) 58 | melspec = self.stft.mel_spectrogram(audio_norm) 59 | melspec = torch.squeeze(melspec, 0) 60 | else: 61 | melspec = torch.from_numpy(np.load(filename)) 62 | assert ( 63 | melspec.size(0) == self.stft.n_mel_channels 64 | ), "Mel dimension mismatch: given {}, expected {}".format( 65 | melspec.size(0), self.stft.n_mel_channels 66 | ) 67 | 68 | return melspec 69 | 70 | def get_text(self, text): 71 | text_norm = text_to_sequence(text, self.symbols, self.text_cleaners) 72 | if self.add_blank: 73 | text_norm = commons.intersperse( 74 | text_norm, len(self.symbols) 75 | ) # add a blank token, whose id number is len(symbols) 76 | text_norm = torch.IntTensor(text_norm) 77 | return text_norm 78 | 79 | def __getitem__(self, index): 80 | return self.get_mel_text_pair(self.audiopaths_and_text[index]) 81 | 82 | def __len__(self): 83 | return len(self.audiopaths_and_text) 84 | 85 | 86 | class TextMelCollate: 87 | """Zero-pads model inputs and targets based on number of frames per step""" 88 | 89 | def __init__(self, n_frames_per_step=1): 90 | self.n_frames_per_step = n_frames_per_step 91 | 92 | def __call__(self, batch): 93 | """Collate's training batch from normalized text and mel-spectrogram 94 | PARAMS 95 | ------ 96 | batch: [text_normalized, mel_normalized] 97 | """ 98 | # Right zero-pad all one-hot text sequences to max input length 99 | input_lengths, ids_sorted_decreasing = torch.sort( 100 | torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True 101 | ) 102 | max_input_len = input_lengths[0] 103 | 104 | text_padded = torch.LongTensor(len(batch), max_input_len) 105 | text_padded.zero_() 106 | for i in range(len(ids_sorted_decreasing)): 107 | text = batch[ids_sorted_decreasing[i]][0] 108 | text_padded[i, : text.size(0)] = text 109 | 110 | # Right zero-pad mel-spec 111 | num_mels = batch[0][1].size(0) 112 | max_target_len = max([x[1].size(1) for x in batch]) 113 | if max_target_len % self.n_frames_per_step != 0: 114 | max_target_len += ( 115 | self.n_frames_per_step - 
max_target_len % self.n_frames_per_step 116 | ) 117 | assert max_target_len % self.n_frames_per_step == 0 118 | 119 | # include mel padded 120 | mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len) 121 | mel_padded.zero_() 122 | output_lengths = torch.LongTensor(len(batch)) 123 | for i in range(len(ids_sorted_decreasing)): 124 | mel = batch[ids_sorted_decreasing[i]][1] 125 | mel_padded[i, :, : mel.size(1)] = mel 126 | output_lengths[i] = mel.size(1) 127 | 128 | return text_padded, input_lengths, mel_padded, output_lengths 129 | 130 | 131 | """Multi speaker version""" 132 | 133 | 134 | class TextMelSpeakerLoader(torch.utils.data.Dataset): 135 | """ 136 | 1) loads audio, speaker_id, text pairs 137 | 2) normalizes text and converts them to sequences of one-hot vectors 138 | 3) computes mel-spectrograms from audio files. 139 | """ 140 | 141 | def __init__(self, audiopaths_sid_text, hparams): 142 | self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text) 143 | self.text_cleaners = hparams.text_cleaners 144 | self.max_wav_value = hparams.max_wav_value 145 | self.sampling_rate = hparams.sampling_rate 146 | self.load_mel_from_disk = hparams.load_mel_from_disk 147 | self.add_noise = hparams.add_noise 148 | self.symbols = hparams.punc + hparams.chars 149 | self.add_blank = getattr(hparams, "add_blank", False) # improved version 150 | self.min_text_len = getattr(hparams, "min_text_len", 1) 151 | self.max_text_len = getattr(hparams, "max_text_len", 190) 152 | self.stft = commons.TacotronSTFT( 153 | hparams.filter_length, 154 | hparams.hop_length, 155 | hparams.win_length, 156 | hparams.n_mel_channels, 157 | hparams.sampling_rate, 158 | hparams.mel_fmin, 159 | hparams.mel_fmax, 160 | ) 161 | 162 | self._filter_text_len() 163 | random.seed(1234) 164 | random.shuffle(self.audiopaths_sid_text) 165 | 166 | def _filter_text_len(self): 167 | audiopaths_sid_text_new = [] 168 | for audiopath, sid, text in self.audiopaths_sid_text: 169 | if self.min_text_len <= len(text) and len(text) <= self.max_text_len: 170 | audiopaths_sid_text_new.append([audiopath, sid, text]) 171 | self.audiopaths_sid_text = audiopaths_sid_text_new 172 | 173 | def get_mel_text_speaker_pair(self, audiopath_sid_text): 174 | # separate filename, speaker_id and text 175 | audiopath, sid, text = ( 176 | audiopath_sid_text[0], 177 | audiopath_sid_text[1], 178 | audiopath_sid_text[2], 179 | ) 180 | text = self.get_text(text) 181 | mel = self.get_mel(audiopath) 182 | sid = self.get_sid(sid) 183 | return (text, mel, sid) 184 | 185 | def get_mel(self, filename): 186 | if not self.load_mel_from_disk: 187 | audio, sampling_rate = load_wav_to_torch(filename) 188 | if sampling_rate != self.stft.sampling_rate: 189 | raise ValueError( 190 | "{} {} SR doesn't match target {} SR".format( 191 | sampling_rate, self.stft.sampling_rate 192 | ) 193 | ) 194 | if self.add_noise: 195 | audio = audio + torch.rand_like(audio) 196 | audio_norm = audio / self.max_wav_value 197 | audio_norm = audio_norm.unsqueeze(0) 198 | melspec = self.stft.mel_spectrogram(audio_norm) 199 | melspec = torch.squeeze(melspec, 0) 200 | else: 201 | melspec = torch.from_numpy(np.load(filename)) 202 | assert ( 203 | melspec.size(0) == self.stft.n_mel_channels 204 | ), "Mel dimension mismatch: given {}, expected {}".format( 205 | melspec.size(0), self.stft.n_mel_channels 206 | ) 207 | 208 | return melspec 209 | 210 | def get_text(self, text): 211 | text_norm = text_to_sequence(text, self.symbols, self.text_cleaners) 212 | if self.add_blank: 213 | text_norm = 
commons.intersperse( 214 | text_norm, len(self.symbols) 215 | ) # add a blank token, whose id number is len(symbols) 216 | text_norm = torch.IntTensor(text_norm) 217 | return text_norm 218 | 219 | def get_sid(self, sid): 220 | sid = torch.IntTensor([int(sid)]) 221 | return sid 222 | 223 | def __getitem__(self, index): 224 | return self.get_mel_text_speaker_pair(self.audiopaths_sid_text[index]) 225 | 226 | def __len__(self): 227 | return len(self.audiopaths_sid_text) 228 | 229 | 230 | class TextMelSpeakerCollate: 231 | """Zero-pads model inputs and targets based on number of frames per step""" 232 | 233 | def __init__(self, n_frames_per_step=1): 234 | self.n_frames_per_step = n_frames_per_step 235 | 236 | def __call__(self, batch): 237 | """Collate's training batch from normalized text and mel-spectrogram 238 | PARAMS 239 | ------ 240 | batch: [text_normalized, mel_normalized] 241 | """ 242 | # Right zero-pad all one-hot text sequences to max input length 243 | input_lengths, ids_sorted_decreasing = torch.sort( 244 | torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True 245 | ) 246 | max_input_len = input_lengths[0] 247 | 248 | text_padded = torch.LongTensor(len(batch), max_input_len) 249 | text_padded.zero_() 250 | for i in range(len(ids_sorted_decreasing)): 251 | text = batch[ids_sorted_decreasing[i]][0] 252 | text_padded[i, : text.size(0)] = text 253 | 254 | # Right zero-pad mel-spec 255 | num_mels = batch[0][1].size(0) 256 | max_target_len = max([x[1].size(1) for x in batch]) 257 | if max_target_len % self.n_frames_per_step != 0: 258 | max_target_len += ( 259 | self.n_frames_per_step - max_target_len % self.n_frames_per_step 260 | ) 261 | assert max_target_len % self.n_frames_per_step == 0 262 | 263 | # include mel padded & sid 264 | mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len) 265 | mel_padded.zero_() 266 | output_lengths = torch.LongTensor(len(batch)) 267 | sid = torch.LongTensor(len(batch)) 268 | for i in range(len(ids_sorted_decreasing)): 269 | mel = batch[ids_sorted_decreasing[i]][1] 270 | mel_padded[i, :, : mel.size(1)] = mel 271 | output_lengths[i] = mel.size(1) 272 | sid[i] = batch[ids_sorted_decreasing[i]][2] 273 | 274 | return text_padded, input_lengths, mel_padded, output_lengths, sid 275 | -------------------------------------------------------------------------------- /src/hifi_gan/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 6 | from utils import init_weights, get_padding 7 | 8 | LRELU_SLOPE = 0.1 9 | 10 | 11 | class ResBlock1(torch.nn.Module): 12 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 13 | super(ResBlock1, self).__init__() 14 | self.h = h 15 | self.convs1 = nn.ModuleList( 16 | [ 17 | weight_norm( 18 | Conv1d( 19 | channels, 20 | channels, 21 | kernel_size, 22 | 1, 23 | dilation=dilation[0], 24 | padding=get_padding(kernel_size, dilation[0]), 25 | ) 26 | ), 27 | weight_norm( 28 | Conv1d( 29 | channels, 30 | channels, 31 | kernel_size, 32 | 1, 33 | dilation=dilation[1], 34 | padding=get_padding(kernel_size, dilation[1]), 35 | ) 36 | ), 37 | weight_norm( 38 | Conv1d( 39 | channels, 40 | channels, 41 | kernel_size, 42 | 1, 43 | dilation=dilation[2], 44 | padding=get_padding(kernel_size, dilation[2]), 45 | ) 46 | ), 47 | ] 48 | ) 49 | 
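        # convs1 above uses the increasing dilations; convs2 below refines with dilation 1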
self.convs1.apply(init_weights) 50 | 51 | self.convs2 = nn.ModuleList( 52 | [ 53 | weight_norm( 54 | Conv1d( 55 | channels, 56 | channels, 57 | kernel_size, 58 | 1, 59 | dilation=1, 60 | padding=get_padding(kernel_size, 1), 61 | ) 62 | ), 63 | weight_norm( 64 | Conv1d( 65 | channels, 66 | channels, 67 | kernel_size, 68 | 1, 69 | dilation=1, 70 | padding=get_padding(kernel_size, 1), 71 | ) 72 | ), 73 | weight_norm( 74 | Conv1d( 75 | channels, 76 | channels, 77 | kernel_size, 78 | 1, 79 | dilation=1, 80 | padding=get_padding(kernel_size, 1), 81 | ) 82 | ), 83 | ] 84 | ) 85 | self.convs2.apply(init_weights) 86 | 87 | def forward(self, x): 88 | for c1, c2 in zip(self.convs1, self.convs2): 89 | xt = F.leaky_relu(x, LRELU_SLOPE) 90 | xt = c1(xt) 91 | xt = F.leaky_relu(xt, LRELU_SLOPE) 92 | xt = c2(xt) 93 | x = xt + x 94 | return x 95 | 96 | def remove_weight_norm(self): 97 | for l in self.convs1: 98 | remove_weight_norm(l) 99 | for l in self.convs2: 100 | remove_weight_norm(l) 101 | 102 | 103 | class ResBlock2(torch.nn.Module): 104 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 105 | super(ResBlock2, self).__init__() 106 | self.h = h 107 | self.convs = nn.ModuleList( 108 | [ 109 | weight_norm( 110 | Conv1d( 111 | channels, 112 | channels, 113 | kernel_size, 114 | 1, 115 | dilation=dilation[0], 116 | padding=get_padding(kernel_size, dilation[0]), 117 | ) 118 | ), 119 | weight_norm( 120 | Conv1d( 121 | channels, 122 | channels, 123 | kernel_size, 124 | 1, 125 | dilation=dilation[1], 126 | padding=get_padding(kernel_size, dilation[1]), 127 | ) 128 | ), 129 | ] 130 | ) 131 | self.convs.apply(init_weights) 132 | 133 | def forward(self, x): 134 | for c in self.convs: 135 | xt = F.leaky_relu(x, LRELU_SLOPE) 136 | xt = c(xt) 137 | x = xt + x 138 | return x 139 | 140 | def remove_weight_norm(self): 141 | for l in self.convs: 142 | remove_weight_norm(l) 143 | 144 | 145 | class Generator(torch.nn.Module): 146 | def __init__(self, h): 147 | super(Generator, self).__init__() 148 | self.h = h 149 | self.num_kernels = len(h.resblock_kernel_sizes) 150 | self.num_upsamples = len(h.upsample_rates) 151 | self.conv_pre = weight_norm( 152 | Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3) 153 | ) 154 | resblock = ResBlock1 if h.resblock == "1" else ResBlock2 155 | 156 | self.ups = nn.ModuleList() 157 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 158 | self.ups.append( 159 | weight_norm( 160 | ConvTranspose1d( 161 | h.upsample_initial_channel // (2 ** i), 162 | h.upsample_initial_channel // (2 ** (i + 1)), 163 | k, 164 | u, 165 | padding=(k - u) // 2, 166 | ) 167 | ) 168 | ) 169 | 170 | self.resblocks = nn.ModuleList() 171 | for i in range(len(self.ups)): 172 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 173 | for j, (k, d) in enumerate( 174 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 175 | ): 176 | self.resblocks.append(resblock(h, ch, k, d)) 177 | 178 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 179 | self.ups.apply(init_weights) 180 | self.conv_post.apply(init_weights) 181 | 182 | def forward(self, x): 183 | x = self.conv_pre(x) 184 | for i in range(self.num_upsamples): 185 | x = F.leaky_relu(x, LRELU_SLOPE) 186 | x = self.ups[i](x) 187 | xs = None 188 | for j in range(self.num_kernels): 189 | if xs is None: 190 | xs = self.resblocks[i * self.num_kernels + j](x) 191 | else: 192 | xs += self.resblocks[i * self.num_kernels + j](x) 193 | x = xs / self.num_kernels 194 | x = F.leaky_relu(x) 195 | x = self.conv_post(x) 196 | 
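        # tanh keeps the generated waveform in [-1, 1]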
x = torch.tanh(x) 197 | 198 | return x 199 | 200 | def remove_weight_norm(self): 201 | print("Removing weight norm...") 202 | for l in self.ups: 203 | remove_weight_norm(l) 204 | for l in self.resblocks: 205 | l.remove_weight_norm() 206 | remove_weight_norm(self.conv_pre) 207 | remove_weight_norm(self.conv_post) 208 | 209 | 210 | class DiscriminatorP(torch.nn.Module): 211 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 212 | super(DiscriminatorP, self).__init__() 213 | self.period = period 214 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 215 | self.convs = nn.ModuleList( 216 | [ 217 | norm_f( 218 | Conv2d( 219 | 1, 220 | 32, 221 | (kernel_size, 1), 222 | (stride, 1), 223 | padding=(get_padding(5, 1), 0), 224 | ) 225 | ), 226 | norm_f( 227 | Conv2d( 228 | 32, 229 | 128, 230 | (kernel_size, 1), 231 | (stride, 1), 232 | padding=(get_padding(5, 1), 0), 233 | ) 234 | ), 235 | norm_f( 236 | Conv2d( 237 | 128, 238 | 512, 239 | (kernel_size, 1), 240 | (stride, 1), 241 | padding=(get_padding(5, 1), 0), 242 | ) 243 | ), 244 | norm_f( 245 | Conv2d( 246 | 512, 247 | 1024, 248 | (kernel_size, 1), 249 | (stride, 1), 250 | padding=(get_padding(5, 1), 0), 251 | ) 252 | ), 253 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 254 | ] 255 | ) 256 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 257 | 258 | def forward(self, x): 259 | fmap = [] 260 | 261 | # 1d to 2d 262 | b, c, t = x.shape 263 | if t % self.period != 0: # pad first 264 | n_pad = self.period - (t % self.period) 265 | x = F.pad(x, (0, n_pad), "reflect") 266 | t = t + n_pad 267 | x = x.view(b, c, t // self.period, self.period) 268 | 269 | for l in self.convs: 270 | x = l(x) 271 | x = F.leaky_relu(x, LRELU_SLOPE) 272 | fmap.append(x) 273 | x = self.conv_post(x) 274 | fmap.append(x) 275 | x = torch.flatten(x, 1, -1) 276 | 277 | return x, fmap 278 | 279 | 280 | class MultiPeriodDiscriminator(torch.nn.Module): 281 | def __init__(self): 282 | super(MultiPeriodDiscriminator, self).__init__() 283 | self.discriminators = nn.ModuleList( 284 | [ 285 | DiscriminatorP(2), 286 | DiscriminatorP(3), 287 | DiscriminatorP(5), 288 | DiscriminatorP(7), 289 | DiscriminatorP(11), 290 | ] 291 | ) 292 | 293 | def forward(self, y, y_hat): 294 | y_d_rs = [] 295 | y_d_gs = [] 296 | fmap_rs = [] 297 | fmap_gs = [] 298 | for i, d in enumerate(self.discriminators): 299 | y_d_r, fmap_r = d(y) 300 | y_d_g, fmap_g = d(y_hat) 301 | y_d_rs.append(y_d_r) 302 | fmap_rs.append(fmap_r) 303 | y_d_gs.append(y_d_g) 304 | fmap_gs.append(fmap_g) 305 | 306 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 307 | 308 | 309 | class DiscriminatorS(torch.nn.Module): 310 | def __init__(self, use_spectral_norm=False): 311 | super(DiscriminatorS, self).__init__() 312 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 313 | self.convs = nn.ModuleList( 314 | [ 315 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 316 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 317 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 318 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 319 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 320 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 321 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 322 | ] 323 | ) 324 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 325 | 326 | def forward(self, x): 327 | fmap = [] 328 | for l in self.convs: 329 | x = l(x) 330 | x = F.leaky_relu(x, LRELU_SLOPE) 331 | 
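            # keep each intermediate activation for the feature-matching loss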
fmap.append(x) 332 | x = self.conv_post(x) 333 | fmap.append(x) 334 | x = torch.flatten(x, 1, -1) 335 | 336 | return x, fmap 337 | 338 | 339 | class MultiScaleDiscriminator(torch.nn.Module): 340 | def __init__(self): 341 | super(MultiScaleDiscriminator, self).__init__() 342 | self.discriminators = nn.ModuleList( 343 | [ 344 | DiscriminatorS(use_spectral_norm=True), 345 | DiscriminatorS(), 346 | DiscriminatorS(), 347 | ] 348 | ) 349 | self.meanpools = nn.ModuleList( 350 | [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] 351 | ) 352 | 353 | def forward(self, y, y_hat): 354 | y_d_rs = [] 355 | y_d_gs = [] 356 | fmap_rs = [] 357 | fmap_gs = [] 358 | for i, d in enumerate(self.discriminators): 359 | if i != 0: 360 | y = self.meanpools[i - 1](y) 361 | y_hat = self.meanpools[i - 1](y_hat) 362 | y_d_r, fmap_r = d(y) 363 | y_d_g, fmap_g = d(y_hat) 364 | y_d_rs.append(y_d_r) 365 | fmap_rs.append(fmap_r) 366 | y_d_gs.append(y_d_g) 367 | fmap_gs.append(fmap_g) 368 | 369 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 370 | 371 | 372 | def feature_loss(fmap_r, fmap_g): 373 | loss = 0 374 | for dr, dg in zip(fmap_r, fmap_g): 375 | for rl, gl in zip(dr, dg): 376 | loss += torch.mean(torch.abs(rl - gl)) 377 | 378 | return loss * 2 379 | 380 | 381 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 382 | loss = 0 383 | r_losses = [] 384 | g_losses = [] 385 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 386 | r_loss = torch.mean((1 - dr) ** 2) 387 | g_loss = torch.mean(dg ** 2) 388 | loss += r_loss + g_loss 389 | r_losses.append(r_loss.item()) 390 | g_losses.append(g_loss.item()) 391 | 392 | return loss, r_losses, g_losses 393 | 394 | 395 | def generator_loss(disc_outputs): 396 | loss = 0 397 | gen_losses = [] 398 | for dg in disc_outputs: 399 | l = torch.mean((1 - dg) ** 2) 400 | gen_losses.append(l) 401 | loss += l 402 | 403 | return loss, gen_losses 404 | -------------------------------------------------------------------------------- /src/glow_tts/hifi/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 6 | from .utils import init_weights, get_padding 7 | 8 | LRELU_SLOPE = 0.1 9 | 10 | 11 | class ResBlock1(torch.nn.Module): 12 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 13 | super(ResBlock1, self).__init__() 14 | self.h = h 15 | self.convs1 = nn.ModuleList( 16 | [ 17 | weight_norm( 18 | Conv1d( 19 | channels, 20 | channels, 21 | kernel_size, 22 | 1, 23 | dilation=dilation[0], 24 | padding=get_padding(kernel_size, dilation[0]), 25 | ) 26 | ), 27 | weight_norm( 28 | Conv1d( 29 | channels, 30 | channels, 31 | kernel_size, 32 | 1, 33 | dilation=dilation[1], 34 | padding=get_padding(kernel_size, dilation[1]), 35 | ) 36 | ), 37 | weight_norm( 38 | Conv1d( 39 | channels, 40 | channels, 41 | kernel_size, 42 | 1, 43 | dilation=dilation[2], 44 | padding=get_padding(kernel_size, dilation[2]), 45 | ) 46 | ), 47 | ] 48 | ) 49 | self.convs1.apply(init_weights) 50 | 51 | self.convs2 = nn.ModuleList( 52 | [ 53 | weight_norm( 54 | Conv1d( 55 | channels, 56 | channels, 57 | kernel_size, 58 | 1, 59 | dilation=1, 60 | padding=get_padding(kernel_size, 1), 61 | ) 62 | ), 63 | weight_norm( 64 | Conv1d( 65 | channels, 66 | channels, 67 | kernel_size, 68 | 1, 69 | dilation=1, 70 | 
padding=get_padding(kernel_size, 1), 71 | ) 72 | ), 73 | weight_norm( 74 | Conv1d( 75 | channels, 76 | channels, 77 | kernel_size, 78 | 1, 79 | dilation=1, 80 | padding=get_padding(kernel_size, 1), 81 | ) 82 | ), 83 | ] 84 | ) 85 | self.convs2.apply(init_weights) 86 | 87 | def forward(self, x): 88 | for c1, c2 in zip(self.convs1, self.convs2): 89 | xt = F.leaky_relu(x, LRELU_SLOPE) 90 | xt = c1(xt) 91 | xt = F.leaky_relu(xt, LRELU_SLOPE) 92 | xt = c2(xt) 93 | x = xt + x 94 | return x 95 | 96 | def remove_weight_norm(self): 97 | for l in self.convs1: 98 | remove_weight_norm(l) 99 | for l in self.convs2: 100 | remove_weight_norm(l) 101 | 102 | 103 | class ResBlock2(torch.nn.Module): 104 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 105 | super(ResBlock2, self).__init__() 106 | self.h = h 107 | self.convs = nn.ModuleList( 108 | [ 109 | weight_norm( 110 | Conv1d( 111 | channels, 112 | channels, 113 | kernel_size, 114 | 1, 115 | dilation=dilation[0], 116 | padding=get_padding(kernel_size, dilation[0]), 117 | ) 118 | ), 119 | weight_norm( 120 | Conv1d( 121 | channels, 122 | channels, 123 | kernel_size, 124 | 1, 125 | dilation=dilation[1], 126 | padding=get_padding(kernel_size, dilation[1]), 127 | ) 128 | ), 129 | ] 130 | ) 131 | self.convs.apply(init_weights) 132 | 133 | def forward(self, x): 134 | for c in self.convs: 135 | xt = F.leaky_relu(x, LRELU_SLOPE) 136 | xt = c(xt) 137 | x = xt + x 138 | return x 139 | 140 | def remove_weight_norm(self): 141 | for l in self.convs: 142 | remove_weight_norm(l) 143 | 144 | 145 | class Generator(torch.nn.Module): 146 | def __init__(self, h): 147 | super(Generator, self).__init__() 148 | self.h = h 149 | self.num_kernels = len(h.resblock_kernel_sizes) 150 | self.num_upsamples = len(h.upsample_rates) 151 | self.conv_pre = weight_norm( 152 | Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3) 153 | ) 154 | resblock = ResBlock1 if h.resblock == "1" else ResBlock2 155 | 156 | self.ups = nn.ModuleList() 157 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 158 | self.ups.append( 159 | weight_norm( 160 | ConvTranspose1d( 161 | h.upsample_initial_channel // (2 ** i), 162 | h.upsample_initial_channel // (2 ** (i + 1)), 163 | k, 164 | u, 165 | padding=(k - u) // 2, 166 | ) 167 | ) 168 | ) 169 | 170 | self.resblocks = nn.ModuleList() 171 | for i in range(len(self.ups)): 172 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 173 | for j, (k, d) in enumerate( 174 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 175 | ): 176 | self.resblocks.append(resblock(h, ch, k, d)) 177 | 178 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 179 | self.ups.apply(init_weights) 180 | self.conv_post.apply(init_weights) 181 | 182 | def forward(self, x): 183 | x = self.conv_pre(x) 184 | for i in range(self.num_upsamples): 185 | x = F.leaky_relu(x, LRELU_SLOPE) 186 | x = self.ups[i](x) 187 | xs = None 188 | for j in range(self.num_kernels): 189 | if xs is None: 190 | xs = self.resblocks[i * self.num_kernels + j](x) 191 | else: 192 | xs += self.resblocks[i * self.num_kernels + j](x) 193 | x = xs / self.num_kernels 194 | x = F.leaky_relu(x) 195 | x = self.conv_post(x) 196 | x = torch.tanh(x) 197 | 198 | return x 199 | 200 | def remove_weight_norm(self): 201 | print("Removing weight norm...") 202 | for l in self.ups: 203 | remove_weight_norm(l) 204 | for l in self.resblocks: 205 | l.remove_weight_norm() 206 | remove_weight_norm(self.conv_pre) 207 | remove_weight_norm(self.conv_post) 208 | 209 | 210 | class 
DiscriminatorP(torch.nn.Module): 211 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 212 | super(DiscriminatorP, self).__init__() 213 | self.period = period 214 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 215 | self.convs = nn.ModuleList( 216 | [ 217 | norm_f( 218 | Conv2d( 219 | 1, 220 | 32, 221 | (kernel_size, 1), 222 | (stride, 1), 223 | padding=(get_padding(5, 1), 0), 224 | ) 225 | ), 226 | norm_f( 227 | Conv2d( 228 | 32, 229 | 128, 230 | (kernel_size, 1), 231 | (stride, 1), 232 | padding=(get_padding(5, 1), 0), 233 | ) 234 | ), 235 | norm_f( 236 | Conv2d( 237 | 128, 238 | 512, 239 | (kernel_size, 1), 240 | (stride, 1), 241 | padding=(get_padding(5, 1), 0), 242 | ) 243 | ), 244 | norm_f( 245 | Conv2d( 246 | 512, 247 | 1024, 248 | (kernel_size, 1), 249 | (stride, 1), 250 | padding=(get_padding(5, 1), 0), 251 | ) 252 | ), 253 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 254 | ] 255 | ) 256 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 257 | 258 | def forward(self, x): 259 | fmap = [] 260 | 261 | # 1d to 2d 262 | b, c, t = x.shape 263 | if t % self.period != 0: # pad first 264 | n_pad = self.period - (t % self.period) 265 | x = F.pad(x, (0, n_pad), "reflect") 266 | t = t + n_pad 267 | x = x.view(b, c, t // self.period, self.period) 268 | 269 | for l in self.convs: 270 | x = l(x) 271 | x = F.leaky_relu(x, LRELU_SLOPE) 272 | fmap.append(x) 273 | x = self.conv_post(x) 274 | fmap.append(x) 275 | x = torch.flatten(x, 1, -1) 276 | 277 | return x, fmap 278 | 279 | 280 | class MultiPeriodDiscriminator(torch.nn.Module): 281 | def __init__(self): 282 | super(MultiPeriodDiscriminator, self).__init__() 283 | self.discriminators = nn.ModuleList( 284 | [ 285 | DiscriminatorP(2), 286 | DiscriminatorP(3), 287 | DiscriminatorP(5), 288 | DiscriminatorP(7), 289 | DiscriminatorP(11), 290 | ] 291 | ) 292 | 293 | def forward(self, y, y_hat): 294 | y_d_rs = [] 295 | y_d_gs = [] 296 | fmap_rs = [] 297 | fmap_gs = [] 298 | for i, d in enumerate(self.discriminators): 299 | y_d_r, fmap_r = d(y) 300 | y_d_g, fmap_g = d(y_hat) 301 | y_d_rs.append(y_d_r) 302 | fmap_rs.append(fmap_r) 303 | y_d_gs.append(y_d_g) 304 | fmap_gs.append(fmap_g) 305 | 306 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 307 | 308 | 309 | class DiscriminatorS(torch.nn.Module): 310 | def __init__(self, use_spectral_norm=False): 311 | super(DiscriminatorS, self).__init__() 312 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 313 | self.convs = nn.ModuleList( 314 | [ 315 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 316 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 317 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 318 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 319 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 320 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 321 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 322 | ] 323 | ) 324 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 325 | 326 | def forward(self, x): 327 | fmap = [] 328 | for l in self.convs: 329 | x = l(x) 330 | x = F.leaky_relu(x, LRELU_SLOPE) 331 | fmap.append(x) 332 | x = self.conv_post(x) 333 | fmap.append(x) 334 | x = torch.flatten(x, 1, -1) 335 | 336 | return x, fmap 337 | 338 | 339 | class MultiScaleDiscriminator(torch.nn.Module): 340 | def __init__(self): 341 | super(MultiScaleDiscriminator, self).__init__() 342 | self.discriminators = nn.ModuleList( 343 | [ 344 | 
DiscriminatorS(use_spectral_norm=True), 345 | DiscriminatorS(), 346 | DiscriminatorS(), 347 | ] 348 | ) 349 | self.meanpools = nn.ModuleList( 350 | [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] 351 | ) 352 | 353 | def forward(self, y, y_hat): 354 | y_d_rs = [] 355 | y_d_gs = [] 356 | fmap_rs = [] 357 | fmap_gs = [] 358 | for i, d in enumerate(self.discriminators): 359 | if i != 0: 360 | y = self.meanpools[i - 1](y) 361 | y_hat = self.meanpools[i - 1](y_hat) 362 | y_d_r, fmap_r = d(y) 363 | y_d_g, fmap_g = d(y_hat) 364 | y_d_rs.append(y_d_r) 365 | fmap_rs.append(fmap_r) 366 | y_d_gs.append(y_d_g) 367 | fmap_gs.append(fmap_g) 368 | 369 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 370 | 371 | 372 | def feature_loss(fmap_r, fmap_g): 373 | loss = 0 374 | for dr, dg in zip(fmap_r, fmap_g): 375 | for rl, gl in zip(dr, dg): 376 | loss += torch.mean(torch.abs(rl - gl)) 377 | 378 | return loss * 2 379 | 380 | 381 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 382 | loss = 0 383 | r_losses = [] 384 | g_losses = [] 385 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 386 | r_loss = torch.mean((1 - dr) ** 2) 387 | g_loss = torch.mean(dg ** 2) 388 | loss += r_loss + g_loss 389 | r_losses.append(r_loss.item()) 390 | g_losses.append(g_loss.item()) 391 | 392 | return loss, r_losses, g_losses 393 | 394 | 395 | def generator_loss(disc_outputs): 396 | loss = 0 397 | gen_losses = [] 398 | for dg in disc_outputs: 399 | l = torch.mean((1 - dg) ** 2) 400 | gen_losses.append(l) 401 | loss += l 402 | 403 | return loss, gen_losses 404 | -------------------------------------------------------------------------------- /src/glow_tts/attentions.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | import commons 9 | import modules 10 | from modules import LayerNorm 11 | 12 | 13 | class Encoder(nn.Module): 14 | def __init__( 15 | self, 16 | hidden_channels, 17 | filter_channels, 18 | n_heads, 19 | n_layers, 20 | kernel_size=1, 21 | p_dropout=0.0, 22 | window_size=None, 23 | block_length=None, 24 | **kwargs 25 | ): 26 | super().__init__() 27 | self.hidden_channels = hidden_channels 28 | self.filter_channels = filter_channels 29 | self.n_heads = n_heads 30 | self.n_layers = n_layers 31 | self.kernel_size = kernel_size 32 | self.p_dropout = p_dropout 33 | self.window_size = window_size 34 | self.block_length = block_length 35 | 36 | self.drop = nn.Dropout(p_dropout) 37 | self.attn_layers = nn.ModuleList() 38 | self.norm_layers_1 = nn.ModuleList() 39 | self.ffn_layers = nn.ModuleList() 40 | self.norm_layers_2 = nn.ModuleList() 41 | for i in range(self.n_layers): 42 | self.attn_layers.append( 43 | MultiHeadAttention( 44 | hidden_channels, 45 | hidden_channels, 46 | n_heads, 47 | window_size=window_size, 48 | p_dropout=p_dropout, 49 | block_length=block_length, 50 | ) 51 | ) 52 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 53 | self.ffn_layers.append( 54 | FFN( 55 | hidden_channels, 56 | hidden_channels, 57 | filter_channels, 58 | kernel_size, 59 | p_dropout=p_dropout, 60 | ) 61 | ) 62 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 63 | 64 | def forward(self, x, x_mask): 65 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 66 | for i in range(self.n_layers): 67 | x = x * x_mask 68 | y = self.attn_layers[i](x, x, attn_mask) 69 | y = self.drop(y) 70 | x = self.norm_layers_1[i](x + y) 
71 | 72 | y = self.ffn_layers[i](x, x_mask) 73 | y = self.drop(y) 74 | x = self.norm_layers_2[i](x + y) 75 | x = x * x_mask 76 | return x 77 | 78 | 79 | class CouplingBlock(nn.Module): 80 | def __init__( 81 | self, 82 | in_channels, 83 | hidden_channels, 84 | kernel_size, 85 | dilation_rate, 86 | n_layers, 87 | gin_channels=0, 88 | p_dropout=0, 89 | sigmoid_scale=False, 90 | ): 91 | super().__init__() 92 | self.in_channels = in_channels 93 | self.hidden_channels = hidden_channels 94 | self.kernel_size = kernel_size 95 | self.dilation_rate = dilation_rate 96 | self.n_layers = n_layers 97 | self.gin_channels = gin_channels 98 | self.p_dropout = p_dropout 99 | self.sigmoid_scale = sigmoid_scale 100 | 101 | start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1) 102 | start = torch.nn.utils.weight_norm(start) 103 | self.start = start 104 | # Initializing last layer to 0 makes the affine coupling layers 105 | # do nothing at first. It helps to stabilze training. 106 | end = torch.nn.Conv1d(hidden_channels, in_channels, 1) 107 | end.weight.data.zero_() 108 | end.bias.data.zero_() 109 | self.end = end 110 | 111 | self.wn = modules.WN( 112 | in_channels, 113 | hidden_channels, 114 | kernel_size, 115 | dilation_rate, 116 | n_layers, 117 | gin_channels, 118 | p_dropout, 119 | ) 120 | 121 | def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): 122 | b, c, t = x.size() 123 | if x_mask is None: 124 | x_mask = 1 125 | x_0, x_1 = x[:, : self.in_channels // 2], x[:, self.in_channels // 2 :] 126 | 127 | x = self.start(x_0) * x_mask 128 | x = self.wn(x, x_mask, g) 129 | out = self.end(x) 130 | 131 | z_0 = x_0 132 | m = out[:, : self.in_channels // 2, :] 133 | logs = out[:, self.in_channels // 2 :, :] 134 | if self.sigmoid_scale: 135 | logs = torch.log(1e-6 + torch.sigmoid(logs + 2)) 136 | 137 | if reverse: 138 | z_1 = (x_1 - m) * torch.exp(-logs) * x_mask 139 | logdet = None 140 | else: 141 | z_1 = (m + torch.exp(logs) * x_1) * x_mask 142 | logdet = torch.sum(logs * x_mask, [1, 2]) 143 | 144 | z = torch.cat([z_0, z_1], 1) 145 | return z, logdet 146 | 147 | def store_inverse(self): 148 | self.wn.remove_weight_norm() 149 | 150 | 151 | class MultiHeadAttention(nn.Module): 152 | def __init__( 153 | self, 154 | channels, 155 | out_channels, 156 | n_heads, 157 | window_size=None, 158 | heads_share=True, 159 | p_dropout=0.0, 160 | block_length=None, 161 | proximal_bias=False, 162 | proximal_init=False, 163 | ): 164 | super().__init__() 165 | assert channels % n_heads == 0 166 | 167 | self.channels = channels 168 | self.out_channels = out_channels 169 | self.n_heads = n_heads 170 | self.window_size = window_size 171 | self.heads_share = heads_share 172 | self.block_length = block_length 173 | self.proximal_bias = proximal_bias 174 | self.p_dropout = p_dropout 175 | self.attn = None 176 | 177 | self.k_channels = channels // n_heads 178 | self.conv_q = nn.Conv1d(channels, channels, 1) 179 | self.conv_k = nn.Conv1d(channels, channels, 1) 180 | self.conv_v = nn.Conv1d(channels, channels, 1) 181 | if window_size is not None: 182 | n_heads_rel = 1 if heads_share else n_heads 183 | rel_stddev = self.k_channels ** -0.5 184 | self.emb_rel_k = nn.Parameter( 185 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 186 | * rel_stddev 187 | ) 188 | self.emb_rel_v = nn.Parameter( 189 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 190 | * rel_stddev 191 | ) 192 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 193 | self.drop = nn.Dropout(p_dropout) 194 | 195 | 
nn.init.xavier_uniform_(self.conv_q.weight) 196 | nn.init.xavier_uniform_(self.conv_k.weight) 197 | if proximal_init: 198 | self.conv_k.weight.data.copy_(self.conv_q.weight.data) 199 | self.conv_k.bias.data.copy_(self.conv_q.bias.data) 200 | nn.init.xavier_uniform_(self.conv_v.weight) 201 | 202 | def forward(self, x, c, attn_mask=None): 203 | q = self.conv_q(x) 204 | k = self.conv_k(c) 205 | v = self.conv_v(c) 206 | 207 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 208 | 209 | x = self.conv_o(x) 210 | return x 211 | 212 | def attention(self, query, key, value, mask=None): 213 | # reshape [b, d, t] -> [b, n_h, t, d_k] 214 | b, d, t_s, t_t = (*key.size(), query.size(2)) 215 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 216 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 217 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 218 | 219 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) 220 | if self.window_size is not None: 221 | assert ( 222 | t_s == t_t 223 | ), "Relative attention is only available for self-attention." 224 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 225 | rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) 226 | rel_logits = self._relative_position_to_absolute_position(rel_logits) 227 | scores_local = rel_logits / math.sqrt(self.k_channels) 228 | scores = scores + scores_local 229 | if self.proximal_bias: 230 | assert t_s == t_t, "Proximal bias is only available for self-attention." 231 | scores = scores + self._attention_bias_proximal(t_s).to( 232 | device=scores.device, dtype=scores.dtype 233 | ) 234 | if mask is not None: 235 | scores = scores.masked_fill(mask == 0, -1e4) 236 | if self.block_length is not None: 237 | block_mask = ( 238 | torch.ones_like(scores) 239 | .triu(-self.block_length) 240 | .tril(self.block_length) 241 | ) 242 | scores = scores * block_mask + -1e4 * (1 - block_mask) 243 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 244 | p_attn = self.drop(p_attn) 245 | output = torch.matmul(p_attn, value) 246 | if self.window_size is not None: 247 | relative_weights = self._absolute_position_to_relative_position(p_attn) 248 | value_relative_embeddings = self._get_relative_embeddings( 249 | self.emb_rel_v, t_s 250 | ) 251 | output = output + self._matmul_with_relative_values( 252 | relative_weights, value_relative_embeddings 253 | ) 254 | output = ( 255 | output.transpose(2, 3).contiguous().view(b, d, t_t) 256 | ) # [b, n_h, t_t, d_k] -> [b, d, t_t] 257 | return output, p_attn 258 | 259 | def _matmul_with_relative_values(self, x, y): 260 | """ 261 | x: [b, h, l, m] 262 | y: [h or 1, m, d] 263 | ret: [b, h, l, d] 264 | """ 265 | ret = torch.matmul(x, y.unsqueeze(0)) 266 | return ret 267 | 268 | def _matmul_with_relative_keys(self, x, y): 269 | """ 270 | x: [b, h, l, d] 271 | y: [h or 1, m, d] 272 | ret: [b, h, l, m] 273 | """ 274 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 275 | return ret 276 | 277 | def _get_relative_embeddings(self, relative_embeddings, length): 278 | max_relative_position = 2 * self.window_size + 1 279 | # Pad first before slice to avoid using cond ops. 
280 | pad_length = max(length - (self.window_size + 1), 0) 281 | slice_start_position = max((self.window_size + 1) - length, 0) 282 | slice_end_position = slice_start_position + 2 * length - 1 283 | if pad_length > 0: 284 | padded_relative_embeddings = F.pad( 285 | relative_embeddings, 286 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), 287 | ) 288 | else: 289 | padded_relative_embeddings = relative_embeddings 290 | used_relative_embeddings = padded_relative_embeddings[ 291 | :, slice_start_position:slice_end_position 292 | ] 293 | return used_relative_embeddings 294 | 295 | def _relative_position_to_absolute_position(self, x): 296 | """ 297 | x: [b, h, l, 2*l-1] 298 | ret: [b, h, l, l] 299 | """ 300 | batch, heads, length, _ = x.size() 301 | # Concat columns of pad to shift from relative to absolute indexing. 302 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) 303 | 304 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 305 | x_flat = x.view([batch, heads, length * 2 * length]) 306 | x_flat = F.pad( 307 | x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) 308 | ) 309 | 310 | # Reshape and slice out the padded elements. 311 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ 312 | :, :, :length, length - 1 : 313 | ] 314 | return x_final 315 | 316 | def _absolute_position_to_relative_position(self, x): 317 | """ 318 | x: [b, h, l, l] 319 | ret: [b, h, l, 2*l-1] 320 | """ 321 | batch, heads, length, _ = x.size() 322 | # padd along column 323 | x = F.pad( 324 | x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) 325 | ) 326 | x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)]) 327 | # add 0's in the beginning that will skew the elements after reshape 328 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 329 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] 330 | return x_final 331 | 332 | def _attention_bias_proximal(self, length): 333 | """Bias for self-attention to encourage attention to close positions. 334 | Args: 335 | length: an integer scalar. 336 | Returns: 337 | a Tensor with shape [1, 1, length, length] 338 | """ 339 | r = torch.arange(length, dtype=torch.float32) 340 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 341 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 342 | 343 | 344 | class FFN(nn.Module): 345 | def __init__( 346 | self, 347 | in_channels, 348 | out_channels, 349 | filter_channels, 350 | kernel_size, 351 | p_dropout=0.0, 352 | activation=None, 353 | ): 354 | super().__init__() 355 | self.in_channels = in_channels 356 | self.out_channels = out_channels 357 | self.filter_channels = filter_channels 358 | self.kernel_size = kernel_size 359 | self.p_dropout = p_dropout 360 | self.activation = activation 361 | 362 | self.conv_1 = nn.Conv1d( 363 | in_channels, filter_channels, kernel_size, padding=kernel_size // 2 364 | ) 365 | self.conv_2 = nn.Conv1d( 366 | filter_channels, out_channels, kernel_size, padding=kernel_size // 2 367 | ) 368 | self.drop = nn.Dropout(p_dropout) 369 | 370 | def forward(self, x, x_mask): 371 | x = self.conv_1(x * x_mask) 372 | if self.activation == "gelu": 373 | x = x * torch.sigmoid(1.702 * x) 374 | else: 375 | x = torch.relu(x) 376 | x = self.drop(x) 377 | x = self.conv_2(x * x_mask) 378 | return x * x_mask 379 | --------------------------------------------------------------------------------
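Editor's note: the attention stack listed above (Encoder, MultiHeadAttention, FFN, CouplingBlock) is consumed by the Glow-TTS text encoder elsewhere in src/glow_tts. The following is a minimal smoke-test sketch, not part of the repository; it assumes it is run from inside src/glow_tts so that `commons` and `modules` resolve, and the hyperparameters are illustrative stand-ins rather than values read from config/glow/base.json.

# Hedged sketch (not part of the repo): exercising the Encoder defined in
# attentions.py. Hyperparameters below are illustrative assumptions.
import torch
from attentions import Encoder

enc = Encoder(
    hidden_channels=192,   # per-position feature size
    filter_channels=768,   # FFN inner width
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
    window_size=4,         # enables relative positional attention
)

x = torch.randn(2, 192, 57)             # (batch, channels, text length)
x_lengths = torch.tensor([57, 40])
# 1 where a position is real text, 0 where it is padding
x_mask = (torch.arange(57)[None, :] < x_lengths[:, None]).unsqueeze(1).float()

with torch.no_grad():
    y = enc(x, x_mask)                  # same shape as x: (2, 192, 57)
print(y.shape)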
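Editor's note: models.py earlier in this section defines the HiFi-GAN Generator (mel-spectrogram in, raw waveform out). The sketch below mirrors the inference pattern used by the repository's inference scripts but builds a hand-rolled config and feeds random mel frames instead of loading a real checkpoint; the config values are typical HiFi-GAN settings assumed only for illustration.

# Hedged sketch (not part of the repo): pushing a mel-spectrogram through the
# Generator from models.py. SimpleNamespace stands in for the AttrDict the repo
# builds from config/hifi/*.json; all values below are illustrative assumptions.
from types import SimpleNamespace
import torch
from models import Generator

h = SimpleNamespace(
    resblock="1",                                  # selects ResBlock1
    upsample_rates=[8, 8, 2, 2],                   # total upsampling = 256x
    upsample_kernel_sizes=[16, 16, 4, 4],
    upsample_initial_channel=512,
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
)

generator = Generator(h)
# In real use, load trained weights from a checkpoint here before inference.
generator.eval()
generator.remove_weight_norm()                     # fold weight norm for inference

mel = torch.randn(1, 80, 32)                       # (batch, n_mels=80, frames)
with torch.no_grad():
    audio = generator(mel)                         # (1, 1, 32 * 256), tanh range
print(audio.shape)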
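Editor's note: the MultiPeriodDiscriminator, MultiScaleDiscriminator and the three loss helpers in models.py combine into the usual HiFi-GAN adversarial objective. Below is a hedged sketch of one training step with random tensors standing in for a real batch; the repository's full training loop also adds an L1 mel-spectrogram reconstruction loss, which is omitted here.

# Hedged sketch (not part of the repo): one adversarial step using the
# discriminators and loss helpers from models.py.
import torch
from models import (
    MultiPeriodDiscriminator,
    MultiScaleDiscriminator,
    discriminator_loss,
    feature_loss,
    generator_loss,
)

mpd, msd = MultiPeriodDiscriminator(), MultiScaleDiscriminator()

y = torch.randn(1, 1, 8192)        # real waveform batch (stand-in)
y_hat = torch.randn(1, 1, 8192)    # generator output (stand-in)

# --- discriminator update: push real scores toward 1, fake scores toward 0 ---
y_df_r, y_df_g, _, _ = mpd(y, y_hat.detach())
y_ds_r, y_ds_g, _, _ = msd(y, y_hat.detach())
loss_d = discriminator_loss(y_df_r, y_df_g)[0] + discriminator_loss(y_ds_r, y_ds_g)[0]

# --- generator update: adversarial + feature-matching terms ---
y_df_r, y_df_g, fmap_f_r, fmap_f_g = mpd(y, y_hat)
y_ds_r, y_ds_g, fmap_s_r, fmap_s_g = msd(y, y_hat)
loss_g = (
    generator_loss(y_df_g)[0]
    + generator_loss(y_ds_g)[0]
    + feature_loss(fmap_f_r, fmap_f_g)
    + feature_loss(fmap_s_r, fmap_s_g)
)
print(loss_d.item(), loss_g.item())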