├── models ├── hubert │ └── Put hubert checkpoint here.txt └── sovits │ └── Put so-vits checkpoint in a folder here.txt ├── crowdin.yml ├── .gitignore ├── README.md ├── localizations ├── zh_CN.json └── en.json ├── requirements.txt ├── modules ├── path.py ├── localization.py ├── devices.py ├── model.py ├── utils.py ├── vits_model.py ├── options.py ├── safe.py ├── process.py ├── sovits_model.py └── ui.py ├── vits ├── monotonic_align │ ├── __init__.py │ └── core.py ├── text │ ├── symbols.py │ ├── __init__.py │ ├── LICENSE │ ├── thai.py │ ├── ngu_dialect.py │ ├── sanskrit.py │ ├── cantonese.py │ ├── shanghainese.py │ ├── japanese.py │ ├── english.py │ ├── korean.py │ ├── cleaners.py │ └── mandarin.py ├── mel_processing.py ├── commons.py ├── utils.py ├── transforms.py ├── attentions.py ├── modules.py └── models.py ├── webui.py ├── script.js ├── scripts └── localization.js └── style.css /models/hubert/Put hubert checkpoint here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/sovits/Put so-vits checkpoint in a folder here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /crowdin.yml: -------------------------------------------------------------------------------- 1 | files: 2 | - source: /localizations/en.json 3 | translation: /localizations/%locale_with_underscore%.json 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | 4 | __pycache__ 5 | venv 6 | 7 | models/vits/** 8 | models/sovits/** 9 | models/hubert/** 10 | 11 | outputs/* 12 | temp/* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VITS web UI 2 | 3 | A browser interface based on Gradio library for VITS/SO-VITS. 4 | 5 | ## Installation and Running 6 | 7 | Just run it. Trust yourself. You can make it! 
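For example, a minimal setup might look like this (a sketch only, assuming a working Python environment; paths and flags below are the ones defined in this repo):

```bash
# install dependencies
pip install -r requirements.txt

# place a VITS checkpoint (.pth) together with its config (.json) in its own folder under models/vits/
# (for SO-VITS, also put the hubert checkpoint under models/hubert/ and the so-vits checkpoint in a folder under models/sovits/)

# launch the web UI (serves on port 8860 by default; add --listen to bind 0.0.0.0)
python webui.py
```

Other command-line flags such as `--share`, `--autolaunch`, `--device-id` and `--use-cpu` are defined in `modules/options.py`.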
-------------------------------------------------------------------------------- /localizations/zh_CN.json: -------------------------------------------------------------------------------- 1 | { 2 | "Text (press Ctrl+Enter or Alt+Enter to generate)": "文字 (按下 Ctrl+Enter 或 Alt+Enter 开始生成)", 3 | "Generate": "生成", 4 | "Speakers": "说话人", 5 | "VITS Checkpoint": "VITS 模型", 6 | "Process Method": "处理方式", 7 | "Speed": "语速", 8 | "Output Message": "输出信息", 9 | "Output Audio": "输出音频", 10 | "Open Folder": "打开文件夹", 11 | "Save": "保存" 12 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numba 2 | librosa 3 | matplotlib 4 | numpy 5 | phonemizer 6 | scipy 7 | tensorboard 8 | torch 9 | torchvision 10 | torchaudio 11 | unidecode 12 | pyopenjtalk>=0.3.0 13 | jamo 14 | pypinyin 15 | ko_pron 16 | jieba 17 | cn2an 18 | protobuf 19 | inflect 20 | eng_to_ipa 21 | indic_transliteration 22 | num_thai 23 | opencc 24 | gradio==3.15 25 | praat-parselmouth 26 | tqdm -------------------------------------------------------------------------------- /modules/path.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | webui_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 5 | sys.path.insert(0, webui_path) 6 | 7 | paths = [ 8 | { 9 | "t": 0, 10 | "p": os.path.join(webui_path, "repositories/sovits") 11 | } 12 | ] 13 | 14 | 15 | def insert_repositories_path(): 16 | for p in paths: 17 | if p["t"] == 0: 18 | sys.path.insert(0, p["p"]) 19 | else: 20 | sys.path.append(p["p"]) 21 | 22 | 23 | # insert_repositories_path() 24 | -------------------------------------------------------------------------------- /modules/localization.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | localization_dir = "localizations" 4 | localization_files = os.listdir(localization_dir) 5 | 6 | 7 | def gen_localization_js(name): 8 | if name not in localization_files: 9 | print(f"Load localization file {name} failed. 
Try set another localization file in settings panel.") 10 | return "" 11 | with open(os.path.join("localizations", name), "r", encoding="utf8") as lf: 12 | localization_file = lf.read() 13 | js = f"\n" 14 | return js 15 | -------------------------------------------------------------------------------- /localizations/en.json: -------------------------------------------------------------------------------- 1 | { 2 | "Drop Audio Here": "Drop Audio Here", 3 | "Click to Upload": "Click to Upload", 4 | 5 | "Text (press Ctrl+Enter or Alt+Enter to generate)": "Text (press Ctrl+Enter or Alt+Enter to generate)", 6 | "Generate": "Generate", 7 | "Speakers": "Speakers", 8 | "VITS Checkpoint": "VITS Checkpoint", 9 | "Process Method": "Process Method", 10 | "Speed": "Speed", 11 | "Output Message": "Output Message", 12 | "Output Audio": "Output Audio", 13 | "Open Folder": "Open Folder", 14 | "Save": "Save", 15 | "Simple": "Simple", 16 | "Batch Process": "Batch Process", 17 | "Multi Speakers": "Multi Speakers", 18 | "Apply settings": "Apply settings", 19 | "Reload UI": "Reload UI" 20 | } -------------------------------------------------------------------------------- /vits/monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | from numpy import zeros, int32, float32 2 | from torch import from_numpy 3 | 4 | from .core import maximum_path_jit 5 | 6 | 7 | def maximum_path(neg_cent, mask): 8 | """ numba optimized version. 9 | neg_cent: [b, t_t, t_s] 10 | mask: [b, t_t, t_s] 11 | """ 12 | device = neg_cent.device 13 | dtype = neg_cent.dtype 14 | neg_cent = neg_cent.data.cpu().numpy().astype(float32) 15 | path = zeros(neg_cent.shape, dtype=int32) 16 | 17 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32) 18 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32) 19 | maximum_path_jit(path, neg_cent, t_t_max, t_s_max) 20 | return from_numpy(path).to(device=device, dtype=dtype) 21 | 22 | -------------------------------------------------------------------------------- /vits/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 
5 | ''' 6 | _pad = '_' 7 | _punctuation = ';:,.!?¡¿—…"«»“” ' 8 | 9 | _punctuation_zh = ';:,。!?-“”《》、()BP…—~.\·『』・ ' 10 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 11 | 12 | _numbers = '1234567890' 13 | _others = '' 14 | 15 | _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 16 | 17 | # Export all symbols: 18 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) 19 | 20 | symbols_zh = [_pad] + list(_punctuation_zh) + list(_letters) + list(_numbers) 21 | 22 | # Special symbol ids 23 | SPACE_ID = symbols.index(" ") 24 | -------------------------------------------------------------------------------- /modules/devices.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules import options 3 | 4 | cpu = torch.device("cpu") 5 | cuda_available = torch.cuda.is_available() 6 | 7 | 8 | def get_cuda_device(): 9 | if options.cmd_opts.device_id is not None: 10 | return f"cuda:{options.cmd_opts.device_id}" 11 | 12 | return "cuda" 13 | 14 | 15 | def get_optimal_device(): 16 | if cuda_available: 17 | return torch.device(get_cuda_device()) 18 | return cpu 19 | 20 | 21 | def torch_gc(): 22 | if cuda_available: 23 | with torch.cuda.device(get_cuda_device()): 24 | torch.cuda.empty_cache() 25 | torch.cuda.ipc_collect() 26 | 27 | 28 | device = cpu if options.cmd_opts.use_cpu else get_optimal_device() 29 | 30 | if not cuda_available: 31 | print("CUDA is not available, using cpu mode...") 32 | -------------------------------------------------------------------------------- /vits/text/__init__.py: -------------------------------------------------------------------------------- 1 | from . import cleaners 2 | 3 | 4 | def text_to_sequence(text, symbols, cleaner_names: str, cleaner=None): 5 | """ 6 | modified t2s: custom symbols and cleaner 7 | """ 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | 10 | sequence = [] 11 | if cleaner: 12 | clean_text = cleaner(text) 13 | else: 14 | clean_text = _clean_text(text, cleaner_names) 15 | for symbol in clean_text: 16 | if symbol not in _symbol_to_id.keys(): 17 | continue 18 | symbol_id = _symbol_to_id[symbol] 19 | sequence += [symbol_id] 20 | return sequence 21 | 22 | 23 | def _clean_text(text, cleaner_names): 24 | for name in cleaner_names: 25 | cleaner = getattr(cleaners, name) 26 | if not cleaner: 27 | raise Exception('Unknown cleaner: %s' % name) 28 | text = cleaner(text) 29 | return text 30 | -------------------------------------------------------------------------------- /vits/text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /vits/text/thai.py: -------------------------------------------------------------------------------- 1 | import re 2 | from num_thai.thainumbers import NumThai 3 | 4 | 5 | num = NumThai() 6 | 7 | # List of (Latin alphabet, Thai) pairs: 8 | _latin_to_thai = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 9 | ('a', 'เอ'), 10 | ('b','บี'), 11 | ('c','ซี'), 12 | ('d','ดี'), 13 | ('e','อี'), 14 | ('f','เอฟ'), 15 | ('g','จี'), 16 | ('h','เอช'), 17 | ('i','ไอ'), 18 | ('j','เจ'), 19 | ('k','เค'), 20 | ('l','แอล'), 21 | ('m','เอ็ม'), 22 | ('n','เอ็น'), 23 | ('o','โอ'), 24 | ('p','พี'), 25 | ('q','คิว'), 26 | ('r','แอร์'), 27 | ('s','เอส'), 28 | ('t','ที'), 29 | ('u','ยู'), 30 | ('v','วี'), 31 | ('w','ดับเบิลยู'), 32 | ('x','เอ็กซ์'), 33 | ('y','วาย'), 34 | ('z','ซี') 35 | ]] 36 | 37 | 38 | def num_to_thai(text): 39 | return re.sub(r'(?:\d+(?:,?\d+)?)+(?:\.\d+(?:,?\d+)?)?', lambda x: ''.join(num.NumberToTextThai(float(x.group(0).replace(',', '')))), text) 40 | 41 | def latin_to_thai(text): 42 | for regex, replacement in _latin_to_thai: 43 | text = re.sub(regex, replacement, text) 44 | return text 45 | -------------------------------------------------------------------------------- /vits/text/ngu_dialect.py: -------------------------------------------------------------------------------- 1 | import re 2 | import opencc 3 | 4 | 5 | dialects = {'SZ': 'suzhou', 'WX': 'wuxi', 'CZ': 'changzhou', 'HZ': 'hangzhou', 6 | 'SX': 'shaoxing', 'NB': 'ningbo', 'JJ': 'jingjiang', 'YX': 'yixing', 7 | 'JD': 'jiading', 'ZR': 'zhenru', 'PH': 'pinghu', 'TX': 'tongxiang', 8 | 'JS': 'jiashan', 'HN': 'xiashi', 'LP': 'linping', 'XS': 'xiaoshan', 9 | 'FY': 'fuyang', 'RA': 'ruao', 'CX': 'cixi', 'SM': 'sanmen', 10 | 'TT': 'tiantai', 'WZ': 'wenzhou', 'SC': 'suichang', 'YB': 'youbu'} 11 | 12 | converters = {} 13 | 14 | for dialect in dialects.values(): 15 | try: 16 | converters[dialect] = opencc.OpenCC(dialect) 17 | except: 18 | pass 19 | 20 | 21 | def ngu_dialect_to_ipa(text, dialect): 22 | dialect = dialects[dialect] 23 | text = converters[dialect].convert(text).replace('-','').replace('$',' ') 24 | text = re.sub(r'[、;:]', ',', text) 25 | text = re.sub(r'\s*,\s*', ', ', text) 26 | text = re.sub(r'\s*。\s*', '. ', text) 27 | text = re.sub(r'\s*?\s*', '? ', text) 28 | text = re.sub(r'\s*!\s*', '! 
', text) 29 | text = re.sub(r'\s*$', '', text) 30 | return text 31 | -------------------------------------------------------------------------------- /vits/monotonic_align/core.py: -------------------------------------------------------------------------------- 1 | import numba 2 | 3 | 4 | @numba.jit(numba.void(numba.int32[:, :, ::1], numba.float32[:, :, ::1], numba.int32[::1], numba.int32[::1]), 5 | nopython=True, nogil=True) 6 | def maximum_path_jit(paths, values, t_ys, t_xs): 7 | b = paths.shape[0] 8 | max_neg_val = -1e9 9 | for i in range(int(b)): 10 | path = paths[i] 11 | value = values[i] 12 | t_y = t_ys[i] 13 | t_x = t_xs[i] 14 | 15 | v_prev = v_cur = 0.0 16 | index = t_x - 1 17 | 18 | for y in range(t_y): 19 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 20 | if x == y: 21 | v_cur = max_neg_val 22 | else: 23 | v_cur = value[y - 1, x] 24 | if x == 0: 25 | if y == 0: 26 | v_prev = 0. 27 | else: 28 | v_prev = max_neg_val 29 | else: 30 | v_prev = value[y - 1, x - 1] 31 | value[y, x] += max(v_prev, v_cur) 32 | 33 | for y in range(t_y - 1, -1, -1): 34 | path[y, index] = 1 35 | if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): 36 | index = index - 1 37 | -------------------------------------------------------------------------------- /modules/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from modules.utils import search_ext_file 3 | 4 | 5 | class ModelInfo: 6 | model_name: str 7 | model_folder: str 8 | model_hash: str 9 | checkpoint_path: str 10 | config_path: str 11 | 12 | def __init__(self, model_name, model_folder, model_hash, checkpoint_path, config_path): 13 | self.model_name = model_name 14 | self.model_folder = model_folder 15 | self.model_hash = model_hash 16 | self.checkpoint_path = checkpoint_path 17 | self.config_path = config_path 18 | self.custom_symbols = None 19 | 20 | 21 | def refresh_model_list(model_path): 22 | dirs = os.listdir(model_path) 23 | models = [] 24 | for d in dirs: 25 | p = os.path.join(model_path, d) 26 | if not os.path.isdir(p): 27 | continue 28 | pth_path = search_ext_file(p, ".pth") 29 | if not pth_path: 30 | print(f"Path {p} does not have a pth file, pass") 31 | continue 32 | config_path = search_ext_file(p, ".json") 33 | if not config_path: 34 | print(f"Path {p} does not have a config file, pass") 35 | continue 36 | models.append({ 37 | "dir": d, 38 | "pth": pth_path, 39 | "config": config_path 40 | }) 41 | return models 42 | -------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import re 4 | import subprocess 5 | from typing import Optional 6 | 7 | # for export 8 | from vits.utils import get_hparams_from_file, HParams 9 | 10 | 11 | def search_ext_file(path: str, ext: str) -> Optional[str]: 12 | files = os.listdir(path) 13 | for f in files: 14 | if f.endswith(ext): 15 | return os.path.join(path, f) 16 | return None 17 | 18 | 19 | def model_hash(filename): 20 | try: 21 | with open(filename, "rb") as file: 22 | import hashlib 23 | m = hashlib.sha256() 24 | 25 | file.seek(0x100000) 26 | m.update(file.read(0x10000)) 27 | return m.hexdigest()[0:8] 28 | except FileNotFoundError: 29 | return 'FileNotFound' 30 | 31 | 32 | def open_folder(f): 33 | if not os.path.exists(f): 34 | print(f'Folder "{f}" does not exist.') 35 | return 36 | elif not os.path.isdir(f): 37 | return 38 | 39 | path = 
os.path.normpath(f) 40 | if platform.system() == "Windows": 41 | os.startfile(path) 42 | elif platform.system() == "Darwin": 43 | subprocess.Popen(["open", path]) 44 | elif "microsoft-standard-WSL2" in platform.uname().release: 45 | subprocess.Popen(["wsl-open", path]) 46 | else: 47 | subprocess.Popen(["xdg-open", path]) 48 | 49 | 50 | def windows_filename(s: str): 51 | return re.sub('[<>:"\/\\|?*\n\t\r]+', "", s) 52 | -------------------------------------------------------------------------------- /webui.py: -------------------------------------------------------------------------------- 1 | # 転がる岩、君に朝が降る 2 | # Code with Love by @Akegarasu 3 | 4 | import os 5 | import sys 6 | 7 | # must import before other modules load model. 8 | import modules.safe 9 | import modules.path 10 | 11 | from modules.ui import create_ui 12 | import modules.vits_model as vits_model 13 | import modules.sovits_model as sovits_model 14 | from modules.options import cmd_opts 15 | 16 | 17 | def init(): 18 | print(f"Launching webui with arguments: {' '.join(sys.argv[1:])}") 19 | ensure_output_dirs() 20 | vits_model.refresh_list() 21 | sovits_model.refresh_list() 22 | if cmd_opts.ui_debug_mode: 23 | return 24 | # todo: autoload last model 25 | # load_last_model() 26 | 27 | 28 | def ensure_output_dirs(): 29 | folders = ["outputs/vits", "outputs/vits-batch", "outputs/sovits", "outputs/sovits", "outputs/sovits-batch", "temp"] 30 | 31 | def check_and_create(p): 32 | if not os.path.exists(p): 33 | os.makedirs(p) 34 | 35 | for i in folders: 36 | check_and_create(i) 37 | 38 | 39 | def run(): 40 | init() 41 | app = create_ui() 42 | if cmd_opts.server_name: 43 | server_name = cmd_opts.server_name 44 | else: 45 | server_name = "0.0.0.0" if cmd_opts.listen else None 46 | 47 | app.queue(default_enabled=False).launch( 48 | share=cmd_opts.share, 49 | server_name=server_name, 50 | server_port=cmd_opts.port, 51 | inbrowser=cmd_opts.autolaunch, 52 | show_api=False 53 | ) 54 | 55 | 56 | if __name__ == "__main__": 57 | run() 58 | -------------------------------------------------------------------------------- /vits/text/sanskrit.py: -------------------------------------------------------------------------------- 1 | import re 2 | from indic_transliteration import sanscript 3 | 4 | 5 | # List of (iast, ipa) pairs: 6 | _iast_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 7 | ('a', 'ə'), 8 | ('ā', 'aː'), 9 | ('ī', 'iː'), 10 | ('ū', 'uː'), 11 | ('ṛ', 'ɹ`'), 12 | ('ṝ', 'ɹ`ː'), 13 | ('ḷ', 'l`'), 14 | ('ḹ', 'l`ː'), 15 | ('e', 'eː'), 16 | ('o', 'oː'), 17 | ('k', 'k⁼'), 18 | ('k⁼h', 'kʰ'), 19 | ('g', 'g⁼'), 20 | ('g⁼h', 'gʰ'), 21 | ('ṅ', 'ŋ'), 22 | ('c', 'ʧ⁼'), 23 | ('ʧ⁼h', 'ʧʰ'), 24 | ('j', 'ʥ⁼'), 25 | ('ʥ⁼h', 'ʥʰ'), 26 | ('ñ', 'n^'), 27 | ('ṭ', 't`⁼'), 28 | ('t`⁼h', 't`ʰ'), 29 | ('ḍ', 'd`⁼'), 30 | ('d`⁼h', 'd`ʰ'), 31 | ('ṇ', 'n`'), 32 | ('t', 't⁼'), 33 | ('t⁼h', 'tʰ'), 34 | ('d', 'd⁼'), 35 | ('d⁼h', 'dʰ'), 36 | ('p', 'p⁼'), 37 | ('p⁼h', 'pʰ'), 38 | ('b', 'b⁼'), 39 | ('b⁼h', 'bʰ'), 40 | ('y', 'j'), 41 | ('ś', 'ʃ'), 42 | ('ṣ', 's`'), 43 | ('r', 'ɾ'), 44 | ('l̤', 'l`'), 45 | ('h', 'ɦ'), 46 | ("'", ''), 47 | ('~', '^'), 48 | ('ṃ', '^') 49 | ]] 50 | 51 | 52 | def devanagari_to_ipa(text): 53 | text = text.replace('ॐ', 'ओम्') 54 | text = re.sub(r'\s*।\s*$', '.', text) 55 | text = re.sub(r'\s*।\s*', ', ', text) 56 | text = re.sub(r'\s*॥', '.', text) 57 | text = sanscript.transliterate(text, sanscript.DEVANAGARI, sanscript.IAST) 58 | for regex, replacement in _iast_to_ipa: 59 | text = re.sub(regex, replacement, text) 60 | text = 
re.sub('(.)[`ː]*ḥ', lambda x: x.group(0) 61 | [:-1]+'h'+x.group(1)+'*', text) 62 | return text 63 | -------------------------------------------------------------------------------- /vits/text/cantonese.py: -------------------------------------------------------------------------------- 1 | import re 2 | import cn2an 3 | import opencc 4 | 5 | 6 | converter = opencc.OpenCC('jyutjyu') 7 | 8 | # List of (Latin alphabet, ipa) pairs: 9 | _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 10 | ('A', 'ei˥'), 11 | ('B', 'biː˥'), 12 | ('C', 'siː˥'), 13 | ('D', 'tiː˥'), 14 | ('E', 'iː˥'), 15 | ('F', 'e˥fuː˨˩'), 16 | ('G', 'tsiː˥'), 17 | ('H', 'ɪk̚˥tsʰyː˨˩'), 18 | ('I', 'ɐi˥'), 19 | ('J', 'tsei˥'), 20 | ('K', 'kʰei˥'), 21 | ('L', 'e˥llou˨˩'), 22 | ('M', 'ɛːm˥'), 23 | ('N', 'ɛːn˥'), 24 | ('O', 'ou˥'), 25 | ('P', 'pʰiː˥'), 26 | ('Q', 'kʰiːu˥'), 27 | ('R', 'aː˥lou˨˩'), 28 | ('S', 'ɛː˥siː˨˩'), 29 | ('T', 'tʰiː˥'), 30 | ('U', 'juː˥'), 31 | ('V', 'wiː˥'), 32 | ('W', 'tʊk̚˥piː˥juː˥'), 33 | ('X', 'ɪk̚˥siː˨˩'), 34 | ('Y', 'waːi˥'), 35 | ('Z', 'iː˨sɛːt̚˥') 36 | ]] 37 | 38 | 39 | def number_to_cantonese(text): 40 | return re.sub(r'\d+(?:\.?\d+)?', lambda x: cn2an.an2cn(x.group()), text) 41 | 42 | 43 | def latin_to_ipa(text): 44 | for regex, replacement in _latin_to_ipa: 45 | text = re.sub(regex, replacement, text) 46 | return text 47 | 48 | 49 | def cantonese_to_ipa(text): 50 | text = number_to_cantonese(text.upper()) 51 | text = converter.convert(text).replace('-','').replace('$',' ') 52 | text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group())+' ', text) 53 | text = re.sub(r'[、;:]', ',', text) 54 | text = re.sub(r'\s*,\s*', ', ', text) 55 | text = re.sub(r'\s*。\s*', '. ', text) 56 | text = re.sub(r'\s*?\s*', '? ', text) 57 | text = re.sub(r'\s*!\s*', '! ', text) 58 | text = re.sub(r'\s*$', '', text) 59 | return text 60 | -------------------------------------------------------------------------------- /vits/text/shanghainese.py: -------------------------------------------------------------------------------- 1 | import re 2 | import cn2an 3 | import opencc 4 | 5 | 6 | converter = opencc.OpenCC('zaonhe') 7 | 8 | # List of (Latin alphabet, ipa) pairs: 9 | _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 10 | ('A', 'ᴇ'), 11 | ('B', 'bi'), 12 | ('C', 'si'), 13 | ('D', 'di'), 14 | ('E', 'i'), 15 | ('F', 'ᴇf'), 16 | ('G', 'dʑi'), 17 | ('H', 'ᴇtɕʰ'), 18 | ('I', 'ᴀi'), 19 | ('J', 'dʑᴇ'), 20 | ('K', 'kʰᴇ'), 21 | ('L', 'ᴇl'), 22 | ('M', 'ᴇm'), 23 | ('N', 'ᴇn'), 24 | ('O', 'o'), 25 | ('P', 'pʰi'), 26 | ('Q', 'kʰiu'), 27 | ('R', 'ᴀl'), 28 | ('S', 'ᴇs'), 29 | ('T', 'tʰi'), 30 | ('U', 'ɦiu'), 31 | ('V', 'vi'), 32 | ('W', 'dᴀbɤliu'), 33 | ('X', 'ᴇks'), 34 | ('Y', 'uᴀi'), 35 | ('Z', 'zᴇ') 36 | ]] 37 | 38 | 39 | def _number_to_shanghainese(num): 40 | num = cn2an.an2cn(num).replace('一十','十').replace('二十', '廿').replace('二', '两') 41 | return re.sub(r'((?:^|[^三四五六七八九])十|廿)两', r'\1二', num) 42 | 43 | 44 | def number_to_shanghainese(text): 45 | return re.sub(r'\d+(?:\.?\d+)?', lambda x: _number_to_shanghainese(x.group()), text) 46 | 47 | 48 | def latin_to_ipa(text): 49 | for regex, replacement in _latin_to_ipa: 50 | text = re.sub(regex, replacement, text) 51 | return text 52 | 53 | 54 | def shanghainese_to_ipa(text): 55 | text = number_to_shanghainese(text.upper()) 56 | text = converter.convert(text).replace('-','').replace('$',' ') 57 | text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group())+' ', text) 58 | text = re.sub(r'[、;:]', ',', text) 59 | text = re.sub(r'\s*,\s*', ', ', text) 60 | text = 
re.sub(r'\s*。\s*', '. ', text) 61 | text = re.sub(r'\s*?\s*', '? ', text) 62 | text = re.sub(r'\s*!\s*', '! ', text) 63 | text = re.sub(r'\s*$', '', text) 64 | return text 65 | -------------------------------------------------------------------------------- /script.js: -------------------------------------------------------------------------------- 1 | let uiUpdateCallbacks = [] 2 | let uiTabChangeCallbacks = [] 3 | let uiCurrentTab = null; 4 | 5 | 6 | function gradioApp() { 7 | const gradioShadowRoot = document.getElementsByTagName('gradio-app')[0].shadowRoot 8 | return !!gradioShadowRoot ? gradioShadowRoot : document; 9 | } 10 | 11 | function get_uiCurrentTab() { 12 | return gradioApp().querySelector('.tabs button:not(.border-transparent)') 13 | } 14 | 15 | function get_uiCurrentTabContent() { 16 | return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])') 17 | } 18 | 19 | function onUiUpdate(callback) { 20 | uiUpdateCallbacks.push(callback) 21 | } 22 | 23 | function onUiTabChange(callback) { 24 | uiTabChangeCallbacks.push(callback) 25 | } 26 | 27 | function runCallback(x, m) { 28 | try { 29 | x(m) 30 | } catch (e) { 31 | (console.error || console.log).call(console, e.message, e); 32 | } 33 | } 34 | 35 | function executeCallbacks(queue, m) { 36 | queue.forEach(function (x) { 37 | runCallback(x, m) 38 | }) 39 | } 40 | 41 | document.addEventListener("DOMContentLoaded", function () { 42 | let mutationObserver = new MutationObserver(function (m) { 43 | executeCallbacks(uiUpdateCallbacks, m); 44 | const newTab = get_uiCurrentTab(); 45 | if (newTab && (newTab !== uiCurrentTab)) { 46 | uiCurrentTab = newTab; 47 | executeCallbacks(uiTabChangeCallbacks); 48 | } 49 | }); 50 | mutationObserver.observe(gradioApp(), {childList: true, subtree: true}) 51 | }); 52 | 53 | document.addEventListener('keydown', function (e) { 54 | let handled = false; 55 | if (e.key !== undefined) { 56 | if ((e.key === "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; 57 | } else if (e.keyCode !== undefined) { 58 | if ((e.keyCode === 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; 59 | } 60 | if (handled) { 61 | let button = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); 62 | if (button) { 63 | button.click(); 64 | } 65 | e.preventDefault(); 66 | } 67 | }) 68 | -------------------------------------------------------------------------------- /scripts/localization.js: -------------------------------------------------------------------------------- 1 | // localization = {} -- the dict with translations is created by the backend 2 | 3 | ignore_ids_for_localization = {} 4 | 5 | re_num = /^[\.\d]+$/ 6 | re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u 7 | 8 | original_lines = {} 9 | translated_lines = {} 10 | 11 | function textNodesUnder(el) { 12 | var n, a = [], walk = document.createTreeWalker(el, NodeFilter.SHOW_TEXT, null, false); 13 | while (n = walk.nextNode()) a.push(n); 14 | return a; 15 | } 16 | 17 | function canBeTranslated(node, text) { 18 | if (!text) return false; 19 | if (!node.parentElement) return false; 20 | 21 | parentType = node.parentElement.nodeName 22 | if (parentType == 'SCRIPT' || parentType == 'STYLE' || parentType == 'TEXTAREA') return false; 23 | 24 | if (parentType == 'OPTION' || parentType == 'SPAN') { 25 | pnode = node 26 | for (var level = 0; level < 4; level++) { 27 | pnode = pnode.parentElement 28 | if (!pnode) break; 29 | 30 | if (ignore_ids_for_localization[pnode.id] == parentType) return 
false; 31 | } 32 | } 33 | 34 | if (re_num.test(text)) return false; 35 | if (re_emoji.test(text)) return false; 36 | return true 37 | } 38 | 39 | function getTranslation(text) { 40 | if (!text) return undefined 41 | 42 | if (translated_lines[text] === undefined) { 43 | original_lines[text] = 1 44 | } 45 | 46 | tl = localization[text] 47 | if (tl !== undefined) { 48 | translated_lines[tl] = 1 49 | } 50 | 51 | return tl 52 | } 53 | 54 | function processTextNode(node) { 55 | text = node.textContent.trim() 56 | 57 | if (!canBeTranslated(node, text)) return 58 | 59 | tl = getTranslation(text) 60 | if (tl !== undefined) { 61 | node.textContent = tl 62 | } 63 | } 64 | 65 | function processNode(node) { 66 | if (node.nodeType == 3) { 67 | processTextNode(node) 68 | return 69 | } 70 | 71 | if (node.title) { 72 | tl = getTranslation(node.title) 73 | if (tl !== undefined) { 74 | node.title = tl 75 | } 76 | } 77 | 78 | if (node.placeholder) { 79 | tl = getTranslation(node.placeholder) 80 | if (tl !== undefined) { 81 | node.placeholder = tl 82 | } 83 | } 84 | 85 | textNodesUnder(node).forEach(function (node) { 86 | processTextNode(node) 87 | }) 88 | } 89 | 90 | onUiUpdate(function (m) { 91 | m.forEach(function (mutation) { 92 | mutation.addedNodes.forEach(function (node) { 93 | processNode(node) 94 | }) 95 | }); 96 | }) 97 | 98 | 99 | document.addEventListener("DOMContentLoaded", function () { 100 | processNode(gradioApp()) 101 | 102 | if (localization.rtl) { // if the language is from right to left, 103 | (new MutationObserver((mutations, observer) => { // wait for the style to load 104 | mutations.forEach(mutation => { 105 | mutation.addedNodes.forEach(node => { 106 | if (node.tagName === 'STYLE') { 107 | observer.disconnect(); 108 | 109 | for (const x of node.sheet.rules) { // find all rtl media rules 110 | if (Array.from(x.media || []).includes('rtl')) { 111 | x.media.appendMedium('all'); // enable them 112 | } 113 | } 114 | } 115 | }) 116 | }); 117 | })).observe(gradioApp(), {childList: true}); 118 | } 119 | }) 120 | -------------------------------------------------------------------------------- /vits/mel_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from librosa.filters import mel as librosa_mel_fn 4 | 5 | MAX_WAV_VALUE = 32768.0 6 | 7 | 8 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 9 | """ 10 | PARAMS 11 | ------ 12 | C: compression factor 13 | """ 14 | return torch.log(torch.clamp(x, min=clip_val) * C) 15 | 16 | 17 | def dynamic_range_decompression_torch(x, C=1): 18 | """ 19 | PARAMS 20 | ------ 21 | C: compression factor used to compress 22 | """ 23 | return torch.exp(x) / C 24 | 25 | 26 | def spectral_normalize_torch(magnitudes): 27 | output = dynamic_range_compression_torch(magnitudes) 28 | return output 29 | 30 | 31 | def spectral_de_normalize_torch(magnitudes): 32 | output = dynamic_range_decompression_torch(magnitudes) 33 | return output 34 | 35 | 36 | mel_basis = {} 37 | hann_window = {} 38 | 39 | 40 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 41 | if torch.min(y) < -1.: 42 | print('min value is ', torch.min(y)) 43 | if torch.max(y) > 1.: 44 | print('max value is ', torch.max(y)) 45 | 46 | global hann_window 47 | dtype_device = str(y.dtype) + '_' + str(y.device) 48 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 49 | if wnsize_dtype_device not in hann_window: 50 | hann_window[wnsize_dtype_device] = 
torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 51 | 52 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 53 | mode='reflect') 54 | y = y.squeeze(1) 55 | 56 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 57 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 58 | 59 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 60 | return spec 61 | 62 | 63 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 64 | global mel_basis 65 | dtype_device = str(spec.dtype) + '_' + str(spec.device) 66 | fmax_dtype_device = str(fmax) + '_' + dtype_device 67 | if fmax_dtype_device not in mel_basis: 68 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 69 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) 70 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 71 | spec = spectral_normalize_torch(spec) 72 | return spec 73 | 74 | 75 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 76 | if torch.min(y) < -1.: 77 | print('min value is ', torch.min(y)) 78 | if torch.max(y) > 1.: 79 | print('max value is ', torch.max(y)) 80 | 81 | global mel_basis, hann_window 82 | dtype_device = str(y.dtype) + '_' + str(y.device) 83 | fmax_dtype_device = str(fmax) + '_' + dtype_device 84 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 85 | if fmax_dtype_device not in mel_basis: 86 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 87 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) 88 | if wnsize_dtype_device not in hann_window: 89 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 90 | 91 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 92 | mode='reflect') 93 | y = y.squeeze(1) 94 | 95 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 96 | center=center, pad_mode='reflect', normalized=False, onesided=True) 97 | 98 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 99 | 100 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 101 | spec = spectral_normalize_torch(spec) 102 | 103 | return spec 104 | -------------------------------------------------------------------------------- /vits/text/japanese.py: -------------------------------------------------------------------------------- 1 | import re 2 | from unidecode import unidecode 3 | import pyopenjtalk 4 | 5 | 6 | # Regular expression matching Japanese without punctuation marks: 7 | _japanese_characters = re.compile( 8 | r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 9 | 10 | # Regular expression matching non-Japanese characters or punctuation marks: 11 | _japanese_marks = re.compile( 12 | r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 13 | 14 | # List of (symbol, Japanese) pairs for marks: 15 | _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [ 16 | ('%', 'パーセント') 17 | ]] 18 | 19 | # List of (romaji, ipa) pairs for marks: 20 | _romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 21 | ('ts', 'ʦ'), 22 | ('u', 'ɯ'), 23 | ('j', 'ʥ'), 24 | ('y', 'j'), 25 | ('ni', 'n^i'), 26 | ('nj', 'n^'), 27 | ('hi', 'çi'), 
28 | ('hj', 'ç'), 29 | ('f', 'ɸ'), 30 | ('I', 'i*'), 31 | ('U', 'ɯ*'), 32 | ('r', 'ɾ') 33 | ]] 34 | 35 | # List of (romaji, ipa2) pairs for marks: 36 | _romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 37 | ('u', 'ɯ'), 38 | ('ʧ', 'tʃ'), 39 | ('j', 'dʑ'), 40 | ('y', 'j'), 41 | ('ni', 'n^i'), 42 | ('nj', 'n^'), 43 | ('hi', 'çi'), 44 | ('hj', 'ç'), 45 | ('f', 'ɸ'), 46 | ('I', 'i*'), 47 | ('U', 'ɯ*'), 48 | ('r', 'ɾ') 49 | ]] 50 | 51 | # List of (consonant, sokuon) pairs: 52 | _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [ 53 | (r'Q([↑↓]*[kg])', r'k#\1'), 54 | (r'Q([↑↓]*[tdjʧ])', r't#\1'), 55 | (r'Q([↑↓]*[sʃ])', r's\1'), 56 | (r'Q([↑↓]*[pb])', r'p#\1') 57 | ]] 58 | 59 | # List of (consonant, hatsuon) pairs: 60 | _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [ 61 | (r'N([↑↓]*[pbm])', r'm\1'), 62 | (r'N([↑↓]*[ʧʥj])', r'n^\1'), 63 | (r'N([↑↓]*[tdn])', r'n\1'), 64 | (r'N([↑↓]*[kg])', r'ŋ\1') 65 | ]] 66 | 67 | 68 | def symbols_to_japanese(text): 69 | for regex, replacement in _symbols_to_japanese: 70 | text = re.sub(regex, replacement, text) 71 | return text 72 | 73 | 74 | def japanese_to_romaji_with_accent(text): 75 | '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html''' 76 | text = symbols_to_japanese(text) 77 | sentences = re.split(_japanese_marks, text) 78 | marks = re.findall(_japanese_marks, text) 79 | text = '' 80 | for i, sentence in enumerate(sentences): 81 | if re.match(_japanese_characters, sentence): 82 | if text != '': 83 | text += ' ' 84 | labels = pyopenjtalk.extract_fullcontext(sentence) 85 | for n, label in enumerate(labels): 86 | phoneme = re.search(r'\-([^\+]*)\+', label).group(1) 87 | if phoneme not in ['sil', 'pau']: 88 | text += phoneme.replace('ch', 'ʧ').replace('sh', 89 | 'ʃ').replace('cl', 'Q') 90 | else: 91 | continue 92 | # n_moras = int(re.search(r'/F:(\d+)_', label).group(1)) 93 | a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1)) 94 | a2 = int(re.search(r"\+(\d+)\+", label).group(1)) 95 | a3 = int(re.search(r"\+(\d+)/", label).group(1)) 96 | if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']: 97 | a2_next = -1 98 | else: 99 | a2_next = int( 100 | re.search(r"\+(\d+)\+", labels[n + 1]).group(1)) 101 | # Accent phrase boundary 102 | if a3 == 1 and a2_next == 1: 103 | text += ' ' 104 | # Falling 105 | elif a1 == 0 and a2_next == a2 + 1: 106 | text += '↓' 107 | # Rising 108 | elif a2 == 1 and a2_next == 2: 109 | text += '↑' 110 | if i < len(marks): 111 | text += unidecode(marks[i]).replace(' ', '') 112 | return text 113 | 114 | 115 | def get_real_sokuon(text): 116 | for regex, replacement in _real_sokuon: 117 | text = re.sub(regex, replacement, text) 118 | return text 119 | 120 | 121 | def get_real_hatsuon(text): 122 | for regex, replacement in _real_hatsuon: 123 | text = re.sub(regex, replacement, text) 124 | return text 125 | 126 | 127 | def japanese_to_ipa(text): 128 | text = japanese_to_romaji_with_accent(text).replace('...', '…') 129 | text = re.sub( 130 | r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text) 131 | text = get_real_sokuon(text) 132 | text = get_real_hatsuon(text) 133 | for regex, replacement in _romaji_to_ipa: 134 | text = re.sub(regex, replacement, text) 135 | return text 136 | 137 | 138 | def japanese_to_ipa2(text): 139 | text = japanese_to_romaji_with_accent(text).replace('...', '…') 140 | text = get_real_sokuon(text) 141 | text = get_real_hatsuon(text) 142 | for regex, replacement in _romaji_to_ipa2: 143 | text = re.sub(regex, 
replacement, text) 144 | return text 145 | 146 | 147 | def japanese_to_ipa3(text): 148 | text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace( 149 | 'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a') 150 | text = re.sub( 151 | r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text) 152 | text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text) 153 | return text 154 | -------------------------------------------------------------------------------- /modules/vits_model.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import os.path 3 | from typing import List, Dict 4 | 5 | import torch 6 | 7 | import modules.devices as devices 8 | import vits.utils 9 | from modules.model import ModelInfo, refresh_model_list 10 | from modules.utils import search_ext_file, model_hash 11 | from vits.models import SynthesizerTrn 12 | from vits.text.symbols import symbols as builtin_symbols 13 | from vits.utils import HParams 14 | 15 | # todo: cmdline here 16 | MODEL_PATH = os.path.join(os.path.join(os.getcwd(), "models"), "vits") 17 | 18 | 19 | class VITSModel: 20 | model: SynthesizerTrn 21 | hps: HParams 22 | symbols: List[str] 23 | 24 | model_name: str 25 | model_folder: str 26 | checkpoint_path: str 27 | config_path: str 28 | 29 | speakers: List[str] 30 | 31 | def __init__(self, info: ModelInfo): 32 | self.model_name = info.model_name 33 | self.model_folder = info.model_folder 34 | self.checkpoint_path = info.checkpoint_path 35 | self.config_path = info.config_path 36 | self.custom_symbols = None 37 | # self.state = "" # maybe for multiprocessing 38 | 39 | def load_model(self): 40 | hps = vits.utils.get_hparams_from_file(self.config_path) 41 | self.load_custom_symbols(f"{self.model_folder}/symbols.py") 42 | if self.custom_symbols: 43 | _symbols = self.custom_symbols.symbols 44 | elif "symbols" in hps: 45 | _symbols = hps.symbols 46 | else: 47 | _symbols = builtin_symbols 48 | 49 | if hasattr(self.custom_symbols, "symbols_zh"): 50 | hps["symbols_zh"] = self.custom_symbols.symbols_zh 51 | 52 | model = SynthesizerTrn( 53 | len(_symbols), 54 | hps.data.filter_length // 2 + 1, 55 | hps.train.segment_size // hps.data.hop_length, 56 | n_speakers=hps.data.n_speakers, 57 | **hps.model) 58 | model, _, _, _ = load_checkpoint(checkpoint_path=self.checkpoint_path, 59 | model=model, optimizer=None) 60 | model.eval().to(devices.device) 61 | 62 | self.speakers = [name for sid, name in enumerate(hps.speakers) if name != "None"] 63 | self.model = model 64 | self.hps = hps 65 | self.symbols = _symbols 66 | 67 | def load_custom_symbols(self, symbol_path): 68 | if os.path.exists(symbol_path): 69 | spec = importlib.util.spec_from_file_location('symbols', symbol_path) 70 | _sym = importlib.util.module_from_spec(spec) 71 | spec.loader.exec_module(_sym) 72 | if not hasattr(_sym, "symbols"): 73 | print(f"Loading symbol file {symbol_path} failed, so such attr") 74 | return 75 | self.custom_symbols = _sym 76 | 77 | 78 | vits_model_list: Dict[str, ModelInfo] = {} 79 | curr_vits_model: VITSModel = None 80 | 81 | 82 | def get_model() -> VITSModel: 83 | return curr_vits_model 84 | 85 | 86 | def get_model_name(): 87 | return curr_vits_model.model_name if curr_vits_model is not None else None 88 | 89 | 90 | def get_model_list(): 91 | return [k for k, _ in vits_model_list.items()] 92 | 93 | 94 | def get_speakers(): 95 | return curr_vits_model.speakers if curr_vits_model is not None else ["None"] 96 | 97 | 98 | def refresh_list(): 99 | 
vits_model_list.clear() 100 | model_list = refresh_model_list(model_path=MODEL_PATH) 101 | for m in model_list: 102 | d = m["dir"] 103 | p = os.path.join(MODEL_PATH, m["dir"]) 104 | pth_path = m["pth"] 105 | config_path = m["config"] 106 | vits_model_list[d] = ModelInfo( 107 | model_name=d, 108 | model_folder=p, 109 | model_hash=model_hash(pth_path), 110 | checkpoint_path=pth_path, 111 | config_path=config_path 112 | ) 113 | if len(vits_model_list.items()) == 0: 114 | print("No vits model found. Please put a model in models/vits") 115 | 116 | 117 | def init_load_model(): 118 | info = next(iter(vits_model_list.values())) 119 | load_model(info.model_name) 120 | 121 | 122 | def init_model(): 123 | global curr_vits_model 124 | info = next(iter(vits_model_list.values())) 125 | curr_vits_model = VITSModel(info) 126 | # load_model(info.model_name) 127 | 128 | 129 | def load_model(model_name: str): 130 | global curr_vits_model, vits_model_list 131 | if curr_vits_model and curr_vits_model.model_name == model_name: 132 | return 133 | info = vits_model_list[model_name] 134 | print(f"Loading weights [{info.model_hash}] from {info.checkpoint_path}...") 135 | m = VITSModel(info) 136 | m.load_model() 137 | curr_vits_model = m 138 | print("Model loaded.") 139 | 140 | 141 | def load_checkpoint(checkpoint_path, model, optimizer=None): 142 | assert os.path.isfile(checkpoint_path) 143 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 144 | iteration = checkpoint_dict['iteration'] 145 | learning_rate = checkpoint_dict['learning_rate'] 146 | if optimizer is not None: 147 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 148 | saved_state_dict = checkpoint_dict['model'] 149 | if hasattr(model, 'module'): 150 | state_dict = model.module.state_dict() 151 | else: 152 | state_dict = model.state_dict() 153 | new_state_dict = {} 154 | for k, v in state_dict.items(): 155 | try: 156 | new_state_dict[k] = saved_state_dict[k] 157 | except: 158 | new_state_dict[k] = v 159 | if hasattr(model, 'module'): 160 | model.module.load_state_dict(new_state_dict) 161 | else: 162 | model.load_state_dict(new_state_dict) 163 | return model, optimizer, learning_rate, iteration 164 | -------------------------------------------------------------------------------- /vits/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn import functional as F 4 | import torch.jit 5 | 6 | 7 | def script_method(fn, _rcb=None): 8 | return fn 9 | 10 | 11 | def script(obj, optimize=True, _frames_up=0, _rcb=None): 12 | return obj 13 | 14 | 15 | torch.jit.script_method = script_method 16 | torch.jit.script = script 17 | 18 | 19 | def init_weights(m, mean=0.0, std=0.01): 20 | classname = m.__class__.__name__ 21 | if classname.find("Conv") != -1: 22 | m.weight.data.normal_(mean, std) 23 | 24 | 25 | def get_padding(kernel_size, dilation=1): 26 | return int((kernel_size * dilation - dilation) / 2) 27 | 28 | 29 | def convert_pad_shape(pad_shape): 30 | l = pad_shape[::-1] 31 | pad_shape = [item for sublist in l for item in sublist] 32 | return pad_shape 33 | 34 | 35 | def intersperse(lst, item): 36 | result = [item] * (len(lst) * 2 + 1) 37 | result[1::2] = lst 38 | return result 39 | 40 | 41 | def kl_divergence(m_p, logs_p, m_q, logs_q): 42 | """KL(P||Q)""" 43 | kl = (logs_q - logs_p) - 0.5 44 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2. 
* logs_q) 45 | return kl 46 | 47 | 48 | def rand_gumbel(shape): 49 | """Sample from the Gumbel distribution, protect from overflows.""" 50 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 51 | return -torch.log(-torch.log(uniform_samples)) 52 | 53 | 54 | def rand_gumbel_like(x): 55 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 56 | return g 57 | 58 | 59 | def slice_segments(x, ids_str, segment_size=4): 60 | ret = torch.zeros_like(x[:, :, :segment_size]) 61 | for i in range(x.size(0)): 62 | idx_str = ids_str[i] 63 | idx_end = idx_str + segment_size 64 | ret[i] = x[i, :, idx_str:idx_end] 65 | return ret 66 | 67 | 68 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 69 | b, d, t = x.size() 70 | if x_lengths is None: 71 | x_lengths = t 72 | ids_str_max = x_lengths - segment_size + 1 73 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 74 | ret = slice_segments(x, ids_str, segment_size) 75 | return ret, ids_str 76 | 77 | 78 | def get_timing_signal_1d( 79 | length, channels, min_timescale=1.0, max_timescale=1.0e4): 80 | position = torch.arange(length, dtype=torch.float) 81 | num_timescales = channels // 2 82 | log_timescale_increment = ( 83 | math.log(float(max_timescale) / float(min_timescale)) / 84 | (num_timescales - 1)) 85 | inv_timescales = min_timescale * torch.exp( 86 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 87 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 88 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 89 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 90 | signal = signal.view(1, channels, length) 91 | return signal 92 | 93 | 94 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 95 | b, channels, length = x.size() 96 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 97 | return x + signal.to(dtype=x.dtype, device=x.device) 98 | 99 | 100 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 101 | b, channels, length = x.size() 102 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 103 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 104 | 105 | 106 | def subsequent_mask(length): 107 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 108 | return mask 109 | 110 | 111 | @torch.jit.script 112 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 113 | n_channels_int = n_channels[0] 114 | in_act = input_a + input_b 115 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 116 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 117 | acts = t_act * s_act 118 | return acts 119 | 120 | 121 | def convert_pad_shape(pad_shape): 122 | l = pad_shape[::-1] 123 | pad_shape = [item for sublist in l for item in sublist] 124 | return pad_shape 125 | 126 | 127 | def shift_1d(x): 128 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 129 | return x 130 | 131 | 132 | def sequence_mask(length, max_length=None): 133 | if max_length is None: 134 | max_length = length.max() 135 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 136 | return x.unsqueeze(0) < length.unsqueeze(1) 137 | 138 | 139 | def generate_path(duration, mask): 140 | """ 141 | duration: [b, 1, t_x] 142 | mask: [b, 1, t_y, t_x] 143 | """ 144 | device = duration.device 145 | 146 | b, _, t_y, t_x = mask.shape 147 | cum_duration = torch.cumsum(duration, -1) 148 | 149 | 
cum_duration_flat = cum_duration.view(b * t_x) 150 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 151 | path = path.view(b, t_x, t_y) 152 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 153 | path = path.unsqueeze(1).transpose(2, 3) * mask 154 | return path 155 | 156 | 157 | def clip_grad_value_(parameters, clip_value, norm_type=2): 158 | if isinstance(parameters, torch.Tensor): 159 | parameters = [parameters] 160 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 161 | norm_type = float(norm_type) 162 | if clip_value is not None: 163 | clip_value = float(clip_value) 164 | 165 | total_norm = 0 166 | for p in parameters: 167 | param_norm = p.grad.data.norm(norm_type) 168 | total_norm += param_norm.item() ** norm_type 169 | if clip_value is not None: 170 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 171 | total_norm = total_norm ** (1. / norm_type) 172 | return total_norm 173 | -------------------------------------------------------------------------------- /modules/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import gradio as gr 5 | 6 | from modules.localization import localization_files 7 | 8 | config_file = "config.json" 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name") 12 | parser.add_argument("--port", type=int, help="launch gradio with given server port, defaults to 8860 if available", default="8860") 13 | parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") 14 | parser.add_argument("--server-name", type=str, help="sets hostname of server", default=None) 15 | parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) 16 | parser.add_argument("--device-id", type=str, help="select the default CUDA device to use", default=None) 17 | parser.add_argument("--use-cpu", action='store_true', help="use cpu") 18 | parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable safe unpickle") 19 | parser.add_argument("--freeze-settings", action='store_true', help="freeze settings") 20 | parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") 21 | 22 | # todo: use logger 23 | parser.add_argument("--debug", action='store_true', help="Output debug log") 24 | 25 | cmd_opts = parser.parse_args() 26 | 27 | 28 | class OptionInfo: 29 | def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None): 30 | self.default = default 31 | self.label = label 32 | self.component = component 33 | self.component_args = component_args 34 | self.onchange = onchange 35 | self.section = section 36 | self.refresh = refresh 37 | 38 | 39 | options_templates = {} 40 | 41 | # options_templates.update({ 42 | # "vits_model": OptionInfo(None, "VITS checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), 43 | # }) 44 | 45 | options_templates.update({ 46 | "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + localization_files}), 47 | }) 48 | 49 | 50 | class Options: 51 | data = None 52 | data_labels = options_templates 53 | type_map = {int: float} 54 | 55 | def 
__init__(self): 56 | self.data = {k: v.default for k, v in self.data_labels.items()} 57 | 58 | def __setattr__(self, key, value): 59 | if self.data is not None: 60 | if key in self.data or key in self.data_labels: 61 | assert not cmd_opts.freeze_settings, "changing settings is disabled" 62 | 63 | info = opts.data_labels.get(key, None) 64 | comp_args = info.component_args if info else None 65 | if isinstance(comp_args, dict) and comp_args.get('visible', True) is False: 66 | raise RuntimeError(f"not possible to set {key} because it is restricted") 67 | 68 | self.data[key] = value 69 | return 70 | 71 | return super(Options, self).__setattr__(key, value) 72 | 73 | def __getattr__(self, item): 74 | if self.data is not None: 75 | if item in self.data: 76 | return self.data[item] 77 | 78 | if item in self.data_labels: 79 | return self.data_labels[item].default 80 | 81 | return super(Options, self).__getattribute__(item) 82 | 83 | def set(self, key, value): 84 | oldval = self.data.get(key, None) 85 | if oldval == value: 86 | return False 87 | 88 | try: 89 | setattr(self, key, value) 90 | except RuntimeError: 91 | return False 92 | 93 | if self.data_labels[key].onchange is not None: 94 | try: 95 | self.data_labels[key].onchange() 96 | except Exception as e: 97 | print(e) 98 | print(f"Error when handling onchange event: changing setting {key} to {value}") 99 | setattr(self, key, oldval) 100 | return False 101 | 102 | return True 103 | 104 | def save(self, filename=config_file): 105 | assert not cmd_opts.freeze_settings, "saving settings is disabled" 106 | 107 | with open(filename, "w", encoding="utf8") as file: 108 | json.dump(self.data, file, indent=4, ensure_ascii=False) 109 | 110 | def same_type(self, x, y): 111 | if x is None or y is None: 112 | return True 113 | 114 | type_x = self.type_map.get(type(x), type(x)) 115 | type_y = self.type_map.get(type(y), type(y)) 116 | 117 | return type_x == type_y 118 | 119 | def load(self, filename): 120 | with open(filename, "r", encoding="utf8") as file: 121 | self.data = json.load(file) 122 | 123 | bad_settings = 0 124 | for k, v in self.data.items(): 125 | info = self.data_labels.get(k, None) 126 | if info is not None and not self.same_type(info.default, v): 127 | print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})") 128 | bad_settings += 1 129 | 130 | if bad_settings > 0: 131 | print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.") 132 | 133 | def onchange(self, key, func, call=True): 134 | item = self.data_labels.get(key) 135 | item.onchange = func 136 | 137 | if call: 138 | func() 139 | 140 | def dump_json(self): 141 | d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} 142 | return json.dumps(d) 143 | 144 | def add_option(self, key, info): 145 | self.data_labels[key] = info 146 | 147 | 148 | opts = Options() 149 | if os.path.exists(config_file): 150 | opts.load(config_file) 151 | else: 152 | print("Config not found, generating default config...") 153 | opts.save() 154 | -------------------------------------------------------------------------------- /vits/text/english.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 
5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | 16 | # Regular expression matching whitespace: 17 | 18 | 19 | import re 20 | import inflect 21 | from unidecode import unidecode 22 | import eng_to_ipa as ipa 23 | _inflect = inflect.engine() 24 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 25 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 26 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 27 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 28 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 29 | _number_re = re.compile(r'[0-9]+') 30 | 31 | # List of (regular expression, replacement) pairs for abbreviations: 32 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 33 | ('mrs', 'misess'), 34 | ('mr', 'mister'), 35 | ('dr', 'doctor'), 36 | ('st', 'saint'), 37 | ('co', 'company'), 38 | ('jr', 'junior'), 39 | ('maj', 'major'), 40 | ('gen', 'general'), 41 | ('drs', 'doctors'), 42 | ('rev', 'reverend'), 43 | ('lt', 'lieutenant'), 44 | ('hon', 'honorable'), 45 | ('sgt', 'sergeant'), 46 | ('capt', 'captain'), 47 | ('esq', 'esquire'), 48 | ('ltd', 'limited'), 49 | ('col', 'colonel'), 50 | ('ft', 'fort'), 51 | ]] 52 | 53 | 54 | # List of (ipa, lazy ipa) pairs: 55 | _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 56 | ('r', 'ɹ'), 57 | ('æ', 'e'), 58 | ('ɑ', 'a'), 59 | ('ɔ', 'o'), 60 | ('ð', 'z'), 61 | ('θ', 's'), 62 | ('ɛ', 'e'), 63 | ('ɪ', 'i'), 64 | ('ʊ', 'u'), 65 | ('ʒ', 'ʥ'), 66 | ('ʤ', 'ʥ'), 67 | ('ˈ', '↓'), 68 | ]] 69 | 70 | # List of (ipa, lazy ipa2) pairs: 71 | _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 72 | ('r', 'ɹ'), 73 | ('ð', 'z'), 74 | ('θ', 's'), 75 | ('ʒ', 'ʑ'), 76 | ('ʤ', 'dʑ'), 77 | ('ˈ', '↓'), 78 | ]] 79 | 80 | # List of (ipa, ipa2) pairs 81 | _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 82 | ('r', 'ɹ'), 83 | ('ʤ', 'dʒ'), 84 | ('ʧ', 'tʃ') 85 | ]] 86 | 87 | 88 | def expand_abbreviations(text): 89 | for regex, replacement in _abbreviations: 90 | text = re.sub(regex, replacement, text) 91 | return text 92 | 93 | 94 | def collapse_whitespace(text): 95 | return re.sub(r'\s+', ' ', text) 96 | 97 | 98 | def _remove_commas(m): 99 | return m.group(1).replace(',', '') 100 | 101 | 102 | def _expand_decimal_point(m): 103 | return m.group(1).replace('.', ' point ') 104 | 105 | 106 | def _expand_dollars(m): 107 | match = m.group(1) 108 | parts = match.split('.') 109 | if len(parts) > 2: 110 | return match + ' dollars' # Unexpected format 111 | dollars = int(parts[0]) if parts[0] else 0 112 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 113 | if dollars and cents: 114 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 115 | cent_unit = 'cent' if cents == 1 else 'cents' 116 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 117 | elif dollars: 118 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 119 | return '%s %s' % (dollars, dollar_unit) 120 | elif cents: 121 | cent_unit = 'cent' if cents == 1 else 'cents' 122 | return '%s %s' % (cents, cent_unit) 123 | else: 
124 | return 'zero dollars' 125 | 126 | 127 | def _expand_ordinal(m): 128 | return _inflect.number_to_words(m.group(0)) 129 | 130 | 131 | def _expand_number(m): 132 | num = int(m.group(0)) 133 | if num > 1000 and num < 3000: 134 | if num == 2000: 135 | return 'two thousand' 136 | elif num > 2000 and num < 2010: 137 | return 'two thousand ' + _inflect.number_to_words(num % 100) 138 | elif num % 100 == 0: 139 | return _inflect.number_to_words(num // 100) + ' hundred' 140 | else: 141 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 142 | else: 143 | return _inflect.number_to_words(num, andword='') 144 | 145 | 146 | def normalize_numbers(text): 147 | text = re.sub(_comma_number_re, _remove_commas, text) 148 | text = re.sub(_pounds_re, r'\1 pounds', text) 149 | text = re.sub(_dollars_re, _expand_dollars, text) 150 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 151 | text = re.sub(_ordinal_re, _expand_ordinal, text) 152 | text = re.sub(_number_re, _expand_number, text) 153 | return text 154 | 155 | 156 | def mark_dark_l(text): 157 | return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text) 158 | 159 | 160 | def english_to_ipa(text): 161 | text = unidecode(text).lower() 162 | text = expand_abbreviations(text) 163 | text = normalize_numbers(text) 164 | phonemes = ipa.convert(text) 165 | phonemes = collapse_whitespace(phonemes) 166 | return phonemes 167 | 168 | 169 | def english_to_lazy_ipa(text): 170 | text = english_to_ipa(text) 171 | for regex, replacement in _lazy_ipa: 172 | text = re.sub(regex, replacement, text) 173 | return text 174 | 175 | 176 | def english_to_ipa2(text): 177 | text = english_to_ipa(text) 178 | text = mark_dark_l(text) 179 | for regex, replacement in _ipa_to_ipa2: 180 | text = re.sub(regex, replacement, text) 181 | return text.replace('...', '…') 182 | 183 | 184 | def english_to_lazy_ipa2(text): 185 | text = english_to_ipa(text) 186 | for regex, replacement in _lazy_ipa2: 187 | text = re.sub(regex, replacement, text) 188 | return text 189 | -------------------------------------------------------------------------------- /vits/text/korean.py: -------------------------------------------------------------------------------- 1 | import re 2 | from jamo import h2j, j2hcj 3 | import ko_pron 4 | 5 | 6 | # This is a list of Korean classifiers preceded by pure Korean numerals. 
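# (e.g. 마리 for animals, 번 for occurrences); number_to_hangul() below reads a number attached to one of these with native Korean numerals (sino=False), so '2마리' becomes '두마리' rather than '이마리'.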
7 | _korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통' 8 | 9 | # List of (hangul, hangul divided) pairs: 10 | _hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [ 11 | ('ㄳ', 'ㄱㅅ'), 12 | ('ㄵ', 'ㄴㅈ'), 13 | ('ㄶ', 'ㄴㅎ'), 14 | ('ㄺ', 'ㄹㄱ'), 15 | ('ㄻ', 'ㄹㅁ'), 16 | ('ㄼ', 'ㄹㅂ'), 17 | ('ㄽ', 'ㄹㅅ'), 18 | ('ㄾ', 'ㄹㅌ'), 19 | ('ㄿ', 'ㄹㅍ'), 20 | ('ㅀ', 'ㄹㅎ'), 21 | ('ㅄ', 'ㅂㅅ'), 22 | ('ㅘ', 'ㅗㅏ'), 23 | ('ㅙ', 'ㅗㅐ'), 24 | ('ㅚ', 'ㅗㅣ'), 25 | ('ㅝ', 'ㅜㅓ'), 26 | ('ㅞ', 'ㅜㅔ'), 27 | ('ㅟ', 'ㅜㅣ'), 28 | ('ㅢ', 'ㅡㅣ'), 29 | ('ㅑ', 'ㅣㅏ'), 30 | ('ㅒ', 'ㅣㅐ'), 31 | ('ㅕ', 'ㅣㅓ'), 32 | ('ㅖ', 'ㅣㅔ'), 33 | ('ㅛ', 'ㅣㅗ'), 34 | ('ㅠ', 'ㅣㅜ') 35 | ]] 36 | 37 | # List of (Latin alphabet, hangul) pairs: 38 | _latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 39 | ('a', '에이'), 40 | ('b', '비'), 41 | ('c', '시'), 42 | ('d', '디'), 43 | ('e', '이'), 44 | ('f', '에프'), 45 | ('g', '지'), 46 | ('h', '에이치'), 47 | ('i', '아이'), 48 | ('j', '제이'), 49 | ('k', '케이'), 50 | ('l', '엘'), 51 | ('m', '엠'), 52 | ('n', '엔'), 53 | ('o', '오'), 54 | ('p', '피'), 55 | ('q', '큐'), 56 | ('r', '아르'), 57 | ('s', '에스'), 58 | ('t', '티'), 59 | ('u', '유'), 60 | ('v', '브이'), 61 | ('w', '더블유'), 62 | ('x', '엑스'), 63 | ('y', '와이'), 64 | ('z', '제트') 65 | ]] 66 | 67 | # List of (ipa, lazy ipa) pairs: 68 | _ipa_to_lazy_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 69 | ('t͡ɕ','ʧ'), 70 | ('d͡ʑ','ʥ'), 71 | ('ɲ','n^'), 72 | ('ɕ','ʃ'), 73 | ('ʷ','w'), 74 | ('ɭ','l`'), 75 | ('ʎ','ɾ'), 76 | ('ɣ','ŋ'), 77 | ('ɰ','ɯ'), 78 | ('ʝ','j'), 79 | ('ʌ','ə'), 80 | ('ɡ','g'), 81 | ('\u031a','#'), 82 | ('\u0348','='), 83 | ('\u031e',''), 84 | ('\u0320',''), 85 | ('\u0339','') 86 | ]] 87 | 88 | 89 | def latin_to_hangul(text): 90 | for regex, replacement in _latin_to_hangul: 91 | text = re.sub(regex, replacement, text) 92 | return text 93 | 94 | 95 | def divide_hangul(text): 96 | text = j2hcj(h2j(text)) 97 | for regex, replacement in _hangul_divided: 98 | text = re.sub(regex, replacement, text) 99 | return text 100 | 101 | 102 | def hangul_number(num, sino=True): 103 | '''Reference https://github.com/Kyubyong/g2pK''' 104 | num = re.sub(',', '', num) 105 | 106 | if num == '0': 107 | return '영' 108 | if not sino and num == '20': 109 | return '스무' 110 | 111 | digits = '123456789' 112 | names = '일이삼사오육칠팔구' 113 | digit2name = {d: n for d, n in zip(digits, names)} 114 | 115 | modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉' 116 | decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔' 117 | digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())} 118 | digit2dec = {d: dec for d, dec in zip(digits, decimals.split())} 119 | 120 | spelledout = [] 121 | for i, digit in enumerate(num): 122 | i = len(num) - i - 1 123 | if sino: 124 | if i == 0: 125 | name = digit2name.get(digit, '') 126 | elif i == 1: 127 | name = digit2name.get(digit, '') + '십' 128 | name = name.replace('일십', '십') 129 | else: 130 | if i == 0: 131 | name = digit2mod.get(digit, '') 132 | elif i == 1: 133 | name = digit2dec.get(digit, '') 134 | if digit == '0': 135 | if i % 4 == 0: 136 | last_three = spelledout[-min(3, len(spelledout)):] 137 | if ''.join(last_three) == '': 138 | spelledout.append('') 139 | continue 140 | else: 141 | spelledout.append('') 142 | continue 143 | if i == 2: 144 | name = digit2name.get(digit, '') + '백' 145 | name = name.replace('일백', '백') 146 | elif i == 3: 147 | name = digit2name.get(digit, '') + '천' 148 | name = name.replace('일천', '천') 149 | elif i == 4: 150 | name = digit2name.get(digit, '') + '만' 151 | name = name.replace('일만', '만') 152 | elif i == 5: 
153 | name = digit2name.get(digit, '') + '십' 154 | name = name.replace('일십', '십') 155 | elif i == 6: 156 | name = digit2name.get(digit, '') + '백' 157 | name = name.replace('일백', '백') 158 | elif i == 7: 159 | name = digit2name.get(digit, '') + '천' 160 | name = name.replace('일천', '천') 161 | elif i == 8: 162 | name = digit2name.get(digit, '') + '억' 163 | elif i == 9: 164 | name = digit2name.get(digit, '') + '십' 165 | elif i == 10: 166 | name = digit2name.get(digit, '') + '백' 167 | elif i == 11: 168 | name = digit2name.get(digit, '') + '천' 169 | elif i == 12: 170 | name = digit2name.get(digit, '') + '조' 171 | elif i == 13: 172 | name = digit2name.get(digit, '') + '십' 173 | elif i == 14: 174 | name = digit2name.get(digit, '') + '백' 175 | elif i == 15: 176 | name = digit2name.get(digit, '') + '천' 177 | spelledout.append(name) 178 | return ''.join(elem for elem in spelledout) 179 | 180 | 181 | def number_to_hangul(text): 182 | '''Reference https://github.com/Kyubyong/g2pK''' 183 | tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text)) 184 | for token in tokens: 185 | num, classifier = token 186 | if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers: 187 | spelledout = hangul_number(num, sino=False) 188 | else: 189 | spelledout = hangul_number(num, sino=True) 190 | text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}') 191 | # digit by digit for remaining digits 192 | digits = '0123456789' 193 | names = '영일이삼사오육칠팔구' 194 | for d, n in zip(digits, names): 195 | text = text.replace(d, n) 196 | return text 197 | 198 | 199 | def korean_to_lazy_ipa(text): 200 | text = latin_to_hangul(text) 201 | text = number_to_hangul(text) 202 | text=re.sub('[\uac00-\ud7af]+',lambda x:ko_pron.romanise(x.group(0),'ipa').split('] ~ [')[0],text) 203 | for regex, replacement in _ipa_to_lazy_ipa: 204 | text = re.sub(regex, replacement, text) 205 | return text 206 | 207 | 208 | def korean_to_ipa(text): 209 | text = korean_to_lazy_ipa(text) 210 | return text.replace('ʧ','tʃ').replace('ʥ','dʑ') 211 | -------------------------------------------------------------------------------- /vits/text/cleaners.py: -------------------------------------------------------------------------------- 1 | import re 2 | import pyopenjtalk 3 | 4 | pyopenjtalk._lazy_init() 5 | 6 | 7 | def japanese_cleaners(text): 8 | from .japanese import japanese_to_romaji_with_accent 9 | text = japanese_to_romaji_with_accent(text) 10 | text = re.sub(r'([A-Za-z])$', r'\1.', text) 11 | return text 12 | 13 | 14 | def japanese_cleaners2(text): 15 | return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…') 16 | 17 | 18 | def korean_cleaners(text): 19 | """Pipeline for Korean text""" 20 | from .korean import latin_to_hangul, number_to_hangul, divide_hangul 21 | text = latin_to_hangul(text) 22 | text = number_to_hangul(text) 23 | text = divide_hangul(text) 24 | text = re.sub(r'([\u3131-\u3163])$', r'\1.', text) 25 | return text 26 | 27 | 28 | def chinese_cleaners(text): 29 | """Pipeline for Chinese text""" 30 | from .mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo 31 | text = number_to_chinese(text) 32 | text = chinese_to_bopomofo(text) 33 | text = latin_to_bopomofo(text) 34 | text = re.sub(r'([ˉˊˇˋ˙])$', r'\1。', text) 35 | return text 36 | 37 | 38 | def chinese_cleaners1(text): 39 | from pypinyin import Style, pinyin 40 | 41 | phones = [phone[0] for phone in pinyin(text, style=Style.TONE3)] 42 | return ' '.join(phones) 43 | 44 | 45 | def chinese_cleaners2(text): 46 | from 
pypinyin import Style, pinyin 47 | from pypinyin.style._utils import get_finals, get_initials 48 | return " ".join([ 49 | p 50 | for phone in pinyin(text, style=Style.TONE3, v_to_u=True) 51 | for p in [ 52 | get_initials(phone[0], strict=True), 53 | get_finals(phone[0][:-1], strict=True) + phone[0][-1] 54 | if phone[0][-1].isdigit() 55 | else get_finals(phone[0], strict=True) 56 | if phone[0][-1].isalnum() 57 | else phone[0], 58 | ] 59 | if len(p) != 0 and not p.isdigit() 60 | ]) 61 | 62 | 63 | def zh_ja_mixture_cleaners(text): 64 | from .mandarin import chinese_to_romaji 65 | from .japanese import japanese_to_romaji_with_accent 66 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 67 | lambda x: chinese_to_romaji(x.group(1)) + ' ', text) 68 | text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_romaji_with_accent( 69 | x.group(1)).replace('ts', 'ʦ').replace('u', 'ɯ').replace('...', '…') + ' ', text) 70 | text = re.sub(r'\s+$', '', text) 71 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 72 | return text 73 | 74 | 75 | def sanskrit_cleaners(text): 76 | text = text.replace('॥', '।').replace('ॐ', 'ओम्') 77 | if text[-1] != '।': 78 | text += ' ।' 79 | return text 80 | 81 | 82 | def cjks_cleaners(text): 83 | from .mandarin import chinese_to_lazy_ipa 84 | from .japanese import japanese_to_ipa 85 | from .korean import korean_to_lazy_ipa 86 | from .sanskrit import devanagari_to_ipa 87 | from .english import english_to_lazy_ipa 88 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 89 | lambda x: chinese_to_lazy_ipa(x.group(1)) + ' ', text) 90 | text = re.sub(r'\[JA\](.*?)\[JA\]', 91 | lambda x: japanese_to_ipa(x.group(1)) + ' ', text) 92 | text = re.sub(r'\[KO\](.*?)\[KO\]', 93 | lambda x: korean_to_lazy_ipa(x.group(1)) + ' ', text) 94 | text = re.sub(r'\[SA\](.*?)\[SA\]', 95 | lambda x: devanagari_to_ipa(x.group(1)) + ' ', text) 96 | text = re.sub(r'\[EN\](.*?)\[EN\]', 97 | lambda x: english_to_lazy_ipa(x.group(1)) + ' ', text) 98 | text = re.sub(r'\s+$', '', text) 99 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 100 | return text 101 | 102 | 103 | def cjke_cleaners(text): 104 | from .mandarin import chinese_to_lazy_ipa 105 | from .japanese import japanese_to_ipa 106 | from .korean import korean_to_ipa 107 | from .english import english_to_ipa2 108 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', lambda x: chinese_to_lazy_ipa(x.group(1)).replace( 109 | 'ʧ', 'tʃ').replace('ʦ', 'ts').replace('ɥan', 'ɥæn') + ' ', text) 110 | text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_ipa(x.group(1)).replace('ʧ', 'tʃ').replace( 111 | 'ʦ', 'ts').replace('ɥan', 'ɥæn').replace('ʥ', 'dz') + ' ', text) 112 | text = re.sub(r'\[KO\](.*?)\[KO\]', 113 | lambda x: korean_to_ipa(x.group(1)) + ' ', text) 114 | text = re.sub(r'\[EN\](.*?)\[EN\]', lambda x: english_to_ipa2(x.group(1)).replace('ɑ', 'a').replace( 115 | 'ɔ', 'o').replace('ɛ', 'e').replace('ɪ', 'i').replace('ʊ', 'u') + ' ', text) 116 | text = re.sub(r'\s+$', '', text) 117 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 118 | return text 119 | 120 | 121 | def cjke_cleaners2(text): 122 | from .mandarin import chinese_to_ipa 123 | from .japanese import japanese_to_ipa2 124 | from .korean import korean_to_ipa 125 | from .english import english_to_ipa2 126 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 127 | lambda x: chinese_to_ipa(x.group(1)) + ' ', text) 128 | text = re.sub(r'\[JA\](.*?)\[JA\]', 129 | lambda x: japanese_to_ipa2(x.group(1)) + ' ', text) 130 | text = re.sub(r'\[KO\](.*?)\[KO\]', 131 | lambda x: korean_to_ipa(x.group(1)) + ' ', text) 132 | text = 
re.sub(r'\[EN\](.*?)\[EN\]', 133 | lambda x: english_to_ipa2(x.group(1)) + ' ', text) 134 | text = re.sub(r'\s+$', '', text) 135 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 136 | return text 137 | 138 | 139 | def thai_cleaners(text): 140 | from .thai import num_to_thai, latin_to_thai 141 | text = num_to_thai(text) 142 | text = latin_to_thai(text) 143 | return text 144 | 145 | 146 | def shanghainese_cleaners(text): 147 | from .shanghainese import shanghainese_to_ipa 148 | text = shanghainese_to_ipa(text) 149 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 150 | return text 151 | 152 | 153 | def chinese_dialect_cleaners(text): 154 | from .mandarin import chinese_to_ipa2 155 | from .japanese import japanese_to_ipa3 156 | from .shanghainese import shanghainese_to_ipa 157 | from .cantonese import cantonese_to_ipa 158 | from .english import english_to_lazy_ipa2 159 | from .ngu_dialect import ngu_dialect_to_ipa 160 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 161 | lambda x: chinese_to_ipa2(x.group(1)) + ' ', text) 162 | text = re.sub(r'\[JA\](.*?)\[JA\]', 163 | lambda x: japanese_to_ipa3(x.group(1)).replace('Q', 'ʔ') + ' ', text) 164 | text = re.sub(r'\[SH\](.*?)\[SH\]', lambda x: shanghainese_to_ipa(x.group(1)).replace('1', '˥˧').replace('5', 165 | '˧˧˦').replace( 166 | '6', '˩˩˧').replace('7', '˥').replace('8', '˩˨').replace('ᴀ', 'ɐ').replace('ᴇ', 'e') + ' ', text) 167 | text = re.sub(r'\[GD\](.*?)\[GD\]', 168 | lambda x: cantonese_to_ipa(x.group(1)) + ' ', text) 169 | text = re.sub(r'\[EN\](.*?)\[EN\]', 170 | lambda x: english_to_lazy_ipa2(x.group(1)) + ' ', text) 171 | text = re.sub(r'\[([A-Z]{2})\](.*?)\[\1\]', lambda x: ngu_dialect_to_ipa(x.group(2), x.group( 172 | 1)).replace('ʣ', 'dz').replace('ʥ', 'dʑ').replace('ʦ', 'ts').replace('ʨ', 'tɕ') + ' ', text) 173 | text = re.sub(r'\s+$', '', text) 174 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 175 | return text 176 | -------------------------------------------------------------------------------- /modules/safe.py: -------------------------------------------------------------------------------- 1 | # This file is modified from stable-diffusion-webui 2 | import pickle 3 | import collections 4 | import sys 5 | import traceback 6 | 7 | import torch 8 | import numpy 9 | import _codecs 10 | import zipfile 11 | import re 12 | 13 | # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage 14 | TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage 15 | 16 | 17 | def encode(*args): 18 | out = _codecs.encode(*args) 19 | return out 20 | 21 | 22 | class RestrictedUnpickler(pickle.Unpickler): 23 | extra_handler = None 24 | 25 | def persistent_load(self, saved_id): 26 | assert saved_id[0] == 'storage' 27 | return TypedStorage() 28 | 29 | def find_class(self, module, name): 30 | if self.extra_handler is not None: 31 | res = self.extra_handler(module, name) 32 | if res is not None: 33 | return res 34 | 35 | if module == 'collections' and name == 'OrderedDict': 36 | return getattr(collections, name) 37 | if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']: 38 | return getattr(torch._utils, name) 39 | if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']: 40 | return getattr(torch, name) 41 | if module == 'torch.nn.modules.container' and name in ['ParameterDict']: 42 | return getattr(torch.nn.modules.container, name) 43 | if 
module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']: 44 | return getattr(numpy.core.multiarray, name) 45 | if module == 'numpy' and name in ['dtype', 'ndarray']: 46 | return getattr(numpy, name) 47 | if module == '_codecs' and name == 'encode': 48 | return encode 49 | if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': 50 | import pytorch_lightning.callbacks 51 | return pytorch_lightning.callbacks.model_checkpoint 52 | if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': 53 | import pytorch_lightning.callbacks.model_checkpoint 54 | return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint 55 | if module == "__builtin__" and name == 'set': 56 | return set 57 | 58 | # Forbid everything else. 59 | raise Exception(f"global '{module}/{name}' is forbidden") 60 | 61 | 62 | # Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/' 63 | allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$") 64 | data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$") 65 | 66 | 67 | def check_zip_filenames(filename, names): 68 | for name in names: 69 | if allowed_zip_names_re.match(name): 70 | continue 71 | 72 | raise Exception(f"bad file inside {filename}: {name}") 73 | 74 | 75 | def check_pt(filename, extra_handler): 76 | try: 77 | 78 | # new pytorch format is a zip file 79 | with zipfile.ZipFile(filename) as z: 80 | check_zip_filenames(filename, z.namelist()) 81 | 82 | # find filename of data.pkl in zip file: '/data.pkl' 83 | data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)] 84 | if len(data_pkl_filenames) == 0: 85 | raise Exception(f"data.pkl not found in {filename}") 86 | if len(data_pkl_filenames) > 1: 87 | raise Exception(f"Multiple data.pkl found in {filename}") 88 | with z.open(data_pkl_filenames[0]) as file: 89 | unpickler = RestrictedUnpickler(file) 90 | unpickler.extra_handler = extra_handler 91 | unpickler.load() 92 | 93 | except zipfile.BadZipfile: 94 | 95 | # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle 96 | with open(filename, "rb") as file: 97 | unpickler = RestrictedUnpickler(file) 98 | unpickler.extra_handler = extra_handler 99 | for i in range(5): 100 | unpickler.load() 101 | 102 | 103 | def load(filename, *args, **kwargs): 104 | return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs) 105 | 106 | 107 | def load_with_extra(filename, extra_handler=None, *args, **kwargs): 108 | """ 109 | this function is intended to be used by extensions that want to load models with 110 | some extra classes in them that the usual unpickler would find suspicious. 111 | 112 | Use the extra_handler argument to specify a function that takes module and field name as text, 113 | and returns that field's value: 114 | 115 | ```python 116 | def extra(module, name): 117 | if module == 'collections' and name == 'OrderedDict': 118 | return collections.OrderedDict 119 | 120 | return None 121 | 122 | safe.load_with_extra('model.pt', extra_handler=extra) 123 | ``` 124 | 125 | The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is 126 | definitely unsafe. 
127 | """ 128 | from modules.options import cmd_opts 129 | try: 130 | if not cmd_opts.disable_safe_unpickle: 131 | check_pt(filename, extra_handler) 132 | 133 | except pickle.UnpicklingError: 134 | print(f"Error verifying pickled file from {filename}:", file=sys.stderr) 135 | print(traceback.format_exc(), file=sys.stderr) 136 | print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) 137 | print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) 138 | return None 139 | 140 | except Exception: 141 | print(f"Error verifying pickled file from {filename}:", file=sys.stderr) 142 | print(traceback.format_exc(), file=sys.stderr) 143 | print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) 144 | print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) 145 | return None 146 | 147 | return unsafe_torch_load(filename, *args, **kwargs) 148 | 149 | 150 | class Extra: 151 | """ 152 | A class for temporarily setting the global handler for when you can't explicitly call load_with_extra 153 | (because it's not your code making the torch.load call). The intended use is like this: 154 | 155 | ``` 156 | import torch 157 | from modules import safe 158 | 159 | def handler(module, name): 160 | if module == 'torch' and name in ['float64', 'float16']: 161 | return getattr(torch, name) 162 | 163 | return None 164 | 165 | with safe.Extra(handler): 166 | x = torch.load('model.pt') 167 | ``` 168 | """ 169 | 170 | def __init__(self, handler): 171 | self.handler = handler 172 | 173 | def __enter__(self): 174 | global global_extra_handler 175 | 176 | assert global_extra_handler is None, 'already inside an Extra() block' 177 | global_extra_handler = self.handler 178 | 179 | def __exit__(self, exc_type, exc_val, exc_tb): 180 | global global_extra_handler 181 | 182 | global_extra_handler = None 183 | 184 | 185 | unsafe_torch_load = torch.load 186 | torch.load = load 187 | global_extra_handler = None 188 | -------------------------------------------------------------------------------- /vits/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | from scipy.io.wavfile import read 10 | import torch 11 | 12 | MATPLOTLIB_FLAG = False 13 | 14 | logging.basicConfig(stream=sys.stdout, level=logging.ERROR) 15 | logger = logging 16 | 17 | 18 | def load_checkpoint(checkpoint_path, model, optimizer=None): 19 | assert os.path.isfile(checkpoint_path) 20 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 21 | iteration = checkpoint_dict['iteration'] 22 | learning_rate = checkpoint_dict['learning_rate'] 23 | if optimizer is not None: 24 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 25 | saved_state_dict = checkpoint_dict['model'] 26 | if hasattr(model, 'module'): 27 | state_dict = model.module.state_dict() 28 | else: 29 | state_dict = model.state_dict() 30 | new_state_dict = {} 31 | for k, v in state_dict.items(): 32 | try: 33 | new_state_dict[k] = saved_state_dict[k] 34 | except: 35 | logger.info("%s is not in the checkpoint" % k) 36 | new_state_dict[k] = v 37 | if hasattr(model, 'module'): 38 | model.module.load_state_dict(new_state_dict) 39 | else: 40 | model.load_state_dict(new_state_dict) 41 | logger.info("Loaded checkpoint '{}' 
(iteration {})".format( 42 | checkpoint_path, iteration)) 43 | return model, optimizer, learning_rate, iteration 44 | 45 | 46 | def plot_spectrogram_to_numpy(spectrogram): 47 | global MATPLOTLIB_FLAG 48 | if not MATPLOTLIB_FLAG: 49 | import matplotlib 50 | matplotlib.use("Agg") 51 | MATPLOTLIB_FLAG = True 52 | mpl_logger = logging.getLogger('matplotlib') 53 | mpl_logger.setLevel(logging.WARNING) 54 | import matplotlib.pylab as plt 55 | import numpy as np 56 | 57 | fig, ax = plt.subplots(figsize=(10, 2)) 58 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 59 | interpolation='none') 60 | plt.colorbar(im, ax=ax) 61 | plt.xlabel("Frames") 62 | plt.ylabel("Channels") 63 | plt.tight_layout() 64 | 65 | fig.canvas.draw() 66 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 67 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 68 | plt.close() 69 | return data 70 | 71 | 72 | def plot_alignment_to_numpy(alignment, info=None): 73 | global MATPLOTLIB_FLAG 74 | if not MATPLOTLIB_FLAG: 75 | import matplotlib 76 | matplotlib.use("Agg") 77 | MATPLOTLIB_FLAG = True 78 | mpl_logger = logging.getLogger('matplotlib') 79 | mpl_logger.setLevel(logging.WARNING) 80 | import matplotlib.pylab as plt 81 | import numpy as np 82 | 83 | fig, ax = plt.subplots(figsize=(6, 4)) 84 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', 85 | interpolation='none') 86 | fig.colorbar(im, ax=ax) 87 | xlabel = 'Decoder timestep' 88 | if info is not None: 89 | xlabel += '\n\n' + info 90 | plt.xlabel(xlabel) 91 | plt.ylabel('Encoder timestep') 92 | plt.tight_layout() 93 | 94 | fig.canvas.draw() 95 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 96 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 97 | plt.close() 98 | return data 99 | 100 | 101 | def load_wav_to_torch(full_path): 102 | sampling_rate, data = read(full_path) 103 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 104 | 105 | 106 | def load_filepaths_and_text(filename, split="|"): 107 | with open(filename, encoding='utf-8') as f: 108 | filepaths_and_text = [line.strip().split(split) for line in f] 109 | return filepaths_and_text 110 | 111 | 112 | def get_hparams(init=True): 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument('-c', '--config', type=str, default="./configs/base.json", 115 | help='JSON file for configuration') 116 | parser.add_argument('-m', '--model', type=str, required=True, 117 | help='Model name') 118 | 119 | args = parser.parse_args() 120 | model_dir = os.path.join("./logs", args.model) 121 | 122 | if not os.path.exists(model_dir): 123 | os.makedirs(model_dir) 124 | 125 | config_path = args.config 126 | config_save_path = os.path.join(model_dir, "config.json") 127 | if init: 128 | with open(config_path, "r") as f: 129 | data = f.read() 130 | with open(config_save_path, "w") as f: 131 | f.write(data) 132 | else: 133 | with open(config_save_path, "r") as f: 134 | data = f.read() 135 | config = json.loads(data) 136 | 137 | hparams = HParams(**config) 138 | hparams.model_dir = model_dir 139 | return hparams 140 | 141 | 142 | def get_hparams_from_dir(model_dir): 143 | config_save_path = os.path.join(model_dir, "config.json") 144 | with open(config_save_path, "r") as f: 145 | data = f.read() 146 | config = json.loads(data) 147 | 148 | hparams = HParams(**config) 149 | hparams.model_dir = model_dir 150 | return hparams 151 | 152 | 153 | def get_hparams_from_file(config_path): 154 | with open(config_path, "r", 
encoding="utf-8") as f: 155 | data = f.read() 156 | config = json.loads(data) 157 | 158 | hparams = HParams(**config) 159 | return hparams 160 | 161 | 162 | def check_git_hash(model_dir): 163 | source_dir = os.path.dirname(os.path.realpath(__file__)) 164 | if not os.path.exists(os.path.join(source_dir, ".git")): 165 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( 166 | source_dir 167 | )) 168 | return 169 | 170 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 171 | 172 | path = os.path.join(model_dir, "githash") 173 | if os.path.exists(path): 174 | saved_hash = open(path).read() 175 | if saved_hash != cur_hash: 176 | logger.warn("git hash values are different. {}(saved) != {}(current)".format( 177 | saved_hash[:8], cur_hash[:8])) 178 | else: 179 | open(path, "w").write(cur_hash) 180 | 181 | 182 | def get_logger(model_dir, filename="train.log"): 183 | global logger 184 | logger = logging.getLogger(os.path.basename(model_dir)) 185 | logger.setLevel(logging.DEBUG) 186 | 187 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 188 | if not os.path.exists(model_dir): 189 | os.makedirs(model_dir) 190 | h = logging.FileHandler(os.path.join(model_dir, filename)) 191 | h.setLevel(logging.DEBUG) 192 | h.setFormatter(formatter) 193 | logger.addHandler(h) 194 | return logger 195 | 196 | 197 | class HParams(): 198 | def __init__(self, **kwargs): 199 | for k, v in kwargs.items(): 200 | if type(v) == dict: 201 | v = HParams(**v) 202 | self[k] = v 203 | 204 | def keys(self): 205 | return self.__dict__.keys() 206 | 207 | def items(self): 208 | return self.__dict__.items() 209 | 210 | def values(self): 211 | return self.__dict__.values() 212 | 213 | def __len__(self): 214 | return len(self.__dict__) 215 | 216 | def __getitem__(self, key): 217 | return getattr(self, key) 218 | 219 | def __setitem__(self, key, value): 220 | return setattr(self, key, value) 221 | 222 | def __contains__(self, key): 223 | return key in self.__dict__ 224 | 225 | def __repr__(self): 226 | return self.__dict__.__repr__() 227 | -------------------------------------------------------------------------------- /modules/process.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os.path 3 | import re 4 | import shutil 5 | import time 6 | from pathlib import Path 7 | from typing import Tuple, List 8 | 9 | import librosa 10 | import numpy as np 11 | import scipy.io.wavfile as wavfile 12 | import soundfile 13 | import tqdm 14 | from torch import no_grad, LongTensor 15 | 16 | import modules.sovits_model as sovits_model 17 | import modules.vits_model as vits_model 18 | from modules.devices import device, torch_gc 19 | from modules.options import cmd_opts 20 | from modules.sovits_model import Svc as SovitsSvc 21 | from modules.utils import windows_filename 22 | from modules.vits_model import VITSModel 23 | from repositories.sovits.inference import slicer 24 | from vits import commons 25 | from vits.text import text_to_sequence 26 | 27 | 28 | class Text2SpeechTask: 29 | origin: str 30 | speaker: str 31 | method: str 32 | pre_processed: List[Tuple[int, str]] 33 | 34 | def __init__(self, origin: str, speaker: str, method: str): 35 | self.origin = origin 36 | self.speaker = speaker 37 | self.method = method 38 | self.pre_processed = [] 39 | 40 | def preprocess(self): 41 | model = vits_model.curr_vits_model 42 | if self.method == "Simple": 43 | speaker_id = model.speakers.index(self.speaker) 44 | 
self.pre_processed.append((speaker_id, self.origin)) 45 | elif self.method == "Multi Speakers": 46 | match = re.findall(r"\[(.*)] (.*)", self.origin) 47 | for m in match: 48 | if m[0] not in model.speakers: 49 | err = f"Error: Unknown speaker {m[0]}, check your input." 50 | print(err) 51 | return err 52 | speaker_id = model.speakers.index(m[0]) 53 | self.pre_processed.append((speaker_id, m[1])) 54 | elif self.method == "Batch Process": 55 | spl = self.origin.split("\n") 56 | speaker_id = model.speakers.index(self.speaker) 57 | for line in spl: 58 | self.pre_processed.append((speaker_id, line)) 59 | 60 | 61 | class SovitsTask: 62 | # Nothing seems to need preprocessing for now, so this is left here as a placeholder. 63 | pass 64 | 65 | 66 | def text2speech(text: str, speaker: str, speed, method="Simple"): 67 | if text == "": 68 | return "Fail: You need to input text.", None 69 | model = vits_model.get_model() 70 | if not model: 71 | return "Fail: No vits model loaded. Please select a model to load first.", None 72 | task = Text2SpeechTask(origin=text, speaker=speaker, method=method) 73 | err = task.preprocess() 74 | if err: 75 | return err, None 76 | ti = int(time.time()) 77 | save_path = "" 78 | output_info = "Success saved to " 79 | outputs = [] 80 | for t in task.pre_processed: 81 | sample_rate, data = process_vits(model=model, 82 | text=t[1], speaker_id=t[0], speed=speed) 83 | outputs.append(data) 84 | save_path = f"outputs/vits/{str(ti)}-{windows_filename(t[1])}.wav" 85 | wavfile.write(save_path, sample_rate, data) 86 | output_info += f"\n{save_path}" 87 | ti += 1 88 | 89 | torch_gc() 90 | 91 | if len(outputs) > 1: 92 | batch_file_path = f"outputs/vits-batch/{str(int(time.time()))}.wav" 93 | wavfile.write(batch_file_path, vits_model.curr_vits_model.hps.data.sampling_rate, np.concatenate(outputs)) 94 | return f"{output_info}\n{batch_file_path}", batch_file_path 95 | return output_info, save_path 96 | 97 | 98 | def sovits_process(audio_path, speaker: str, vc_transform: int, slice_db: int): 99 | if not audio_path: 100 | return "Fail: You need to input an audio.", None 101 | model = sovits_model.get_model() 102 | if not model: 103 | return "Fail: No so-vits model loaded.
Please select a model to load first.", None 104 | ti = int(time.time()) 105 | save_path = "" 106 | output_info = "Success saved to " 107 | outputs = [] 108 | try: 109 | for af in audio_path: 110 | data, sampling_rate = process_so_vits(svc_model=sovits_model.get_model(), 111 | sid=speaker, 112 | input_audio=af, 113 | vc_transform=vc_transform, 114 | slice_db=slice_db) 115 | outputs.extend(data) 116 | save_path = f"outputs/sovits/{str(ti)}.wav" 117 | soundfile.write(save_path, data, sampling_rate, format="wav") 118 | ti += 1 119 | output_info += f"\n{save_path}" 120 | 121 | if len(outputs) > 1: 122 | batch_file_path = f"outputs/sovits-batch/{str(int(time.time()))}.wav" 123 | soundfile.write(batch_file_path, outputs, sovits_model.get_model().target_sample) 124 | return f"{output_info}\n{batch_file_path}", batch_file_path 125 | 126 | finally: 127 | torch_gc() 128 | 129 | return output_info, save_path 130 | 131 | 132 | def text_processing(text, model: VITSModel): 133 | hps = model.hps 134 | _use_symbols = model.symbols 135 | # 留了点屎山 以后再处理吧 136 | if hasattr(hps, "symbols_zh"): 137 | _use_symbols = hps.symbols_zh 138 | text_norm = text_to_sequence(text, _use_symbols, hps.data.text_cleaners) 139 | if hps.data.add_blank: 140 | text_norm = commons.intersperse(text_norm, 0) 141 | text_norm = LongTensor(text_norm) 142 | return text_norm 143 | 144 | 145 | def process_vits(model: VITSModel, text: str, 146 | speaker_id, speed, 147 | noise_scale=0.667, 148 | noise_scale_w=0.8) -> Tuple[int, np.array]: 149 | stn_tst = text_processing(text, model) 150 | with no_grad(): 151 | x_tst = stn_tst.unsqueeze(0).to(device) 152 | x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device) 153 | sid = LongTensor([speaker_id]).to(device) 154 | audio = model.model.infer(x_tst, x_tst_lengths, sid=sid, 155 | noise_scale=noise_scale, noise_scale_w=noise_scale_w, 156 | length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy() 157 | del stn_tst, x_tst, x_tst_lengths, sid 158 | return model.hps.data.sampling_rate, audio 159 | 160 | 161 | def process_so_vits(svc_model: SovitsSvc, sid, input_audio, vc_transform, slice_db): 162 | audio_path = input_audio.name 163 | wav_path = os.path.join("temp", str(int(time.time())) + ".wav") 164 | if Path(audio_path).suffix != '.wav': 165 | raw_audio, raw_sample_rate = librosa.load(audio_path, mono=True, sr=None) 166 | soundfile.write(wav_path, raw_audio, raw_sample_rate) 167 | else: 168 | shutil.copy(audio_path, wav_path) 169 | chunks = slicer.cut(wav_path, db_thresh=slice_db) 170 | audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks) 171 | 172 | audio = [] 173 | for (slice_tag, data) in tqdm.tqdm(audio_data): 174 | if cmd_opts.debug: 175 | print(f'segment start, {round(len(data) / audio_sr, 3)}s') 176 | length = int(np.ceil(len(data) / audio_sr * svc_model.target_sample)) 177 | raw_path = io.BytesIO() 178 | soundfile.write(raw_path, data, audio_sr, format="wav") 179 | raw_path.seek(0) 180 | if slice_tag: 181 | if cmd_opts.debug: 182 | print('jump empty segment') 183 | _audio = np.zeros(length) 184 | else: 185 | out_audio, out_sr = svc_model.infer(sid, vc_transform, raw_path) 186 | _audio = out_audio.cpu().numpy() 187 | audio.extend(list(_audio)) 188 | os.remove(wav_path) 189 | return audio, svc_model.target_sample 190 | -------------------------------------------------------------------------------- /vits/text/mandarin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | from pypinyin import lazy_pinyin, 
BOPOMOFO 5 | import jieba 6 | import cn2an 7 | import logging 8 | 9 | logging.getLogger('jieba').setLevel(logging.WARNING) 10 | jieba.initialize() 11 | 12 | 13 | # List of (Latin alphabet, bopomofo) pairs: 14 | _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 15 | ('a', 'ㄟˉ'), 16 | ('b', 'ㄅㄧˋ'), 17 | ('c', 'ㄙㄧˉ'), 18 | ('d', 'ㄉㄧˋ'), 19 | ('e', 'ㄧˋ'), 20 | ('f', 'ㄝˊㄈㄨˋ'), 21 | ('g', 'ㄐㄧˋ'), 22 | ('h', 'ㄝˇㄑㄩˋ'), 23 | ('i', 'ㄞˋ'), 24 | ('j', 'ㄐㄟˋ'), 25 | ('k', 'ㄎㄟˋ'), 26 | ('l', 'ㄝˊㄛˋ'), 27 | ('m', 'ㄝˊㄇㄨˋ'), 28 | ('n', 'ㄣˉ'), 29 | ('o', 'ㄡˉ'), 30 | ('p', 'ㄆㄧˉ'), 31 | ('q', 'ㄎㄧㄡˉ'), 32 | ('r', 'ㄚˋ'), 33 | ('s', 'ㄝˊㄙˋ'), 34 | ('t', 'ㄊㄧˋ'), 35 | ('u', 'ㄧㄡˉ'), 36 | ('v', 'ㄨㄧˉ'), 37 | ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'), 38 | ('x', 'ㄝˉㄎㄨˋㄙˋ'), 39 | ('y', 'ㄨㄞˋ'), 40 | ('z', 'ㄗㄟˋ') 41 | ]] 42 | 43 | # List of (bopomofo, romaji) pairs: 44 | _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [ 45 | ('ㄅㄛ', 'p⁼wo'), 46 | ('ㄆㄛ', 'pʰwo'), 47 | ('ㄇㄛ', 'mwo'), 48 | ('ㄈㄛ', 'fwo'), 49 | ('ㄅ', 'p⁼'), 50 | ('ㄆ', 'pʰ'), 51 | ('ㄇ', 'm'), 52 | ('ㄈ', 'f'), 53 | ('ㄉ', 't⁼'), 54 | ('ㄊ', 'tʰ'), 55 | ('ㄋ', 'n'), 56 | ('ㄌ', 'l'), 57 | ('ㄍ', 'k⁼'), 58 | ('ㄎ', 'kʰ'), 59 | ('ㄏ', 'h'), 60 | ('ㄐ', 'ʧ⁼'), 61 | ('ㄑ', 'ʧʰ'), 62 | ('ㄒ', 'ʃ'), 63 | ('ㄓ', 'ʦ`⁼'), 64 | ('ㄔ', 'ʦ`ʰ'), 65 | ('ㄕ', 's`'), 66 | ('ㄖ', 'ɹ`'), 67 | ('ㄗ', 'ʦ⁼'), 68 | ('ㄘ', 'ʦʰ'), 69 | ('ㄙ', 's'), 70 | ('ㄚ', 'a'), 71 | ('ㄛ', 'o'), 72 | ('ㄜ', 'ə'), 73 | ('ㄝ', 'e'), 74 | ('ㄞ', 'ai'), 75 | ('ㄟ', 'ei'), 76 | ('ㄠ', 'au'), 77 | ('ㄡ', 'ou'), 78 | ('ㄧㄢ', 'yeNN'), 79 | ('ㄢ', 'aNN'), 80 | ('ㄧㄣ', 'iNN'), 81 | ('ㄣ', 'əNN'), 82 | ('ㄤ', 'aNg'), 83 | ('ㄧㄥ', 'iNg'), 84 | ('ㄨㄥ', 'uNg'), 85 | ('ㄩㄥ', 'yuNg'), 86 | ('ㄥ', 'əNg'), 87 | ('ㄦ', 'əɻ'), 88 | ('ㄧ', 'i'), 89 | ('ㄨ', 'u'), 90 | ('ㄩ', 'ɥ'), 91 | ('ˉ', '→'), 92 | ('ˊ', '↑'), 93 | ('ˇ', '↓↑'), 94 | ('ˋ', '↓'), 95 | ('˙', ''), 96 | (',', ','), 97 | ('。', '.'), 98 | ('!', '!'), 99 | ('?', '?'), 100 | ('—', '-') 101 | ]] 102 | 103 | # List of (romaji, ipa) pairs: 104 | _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 105 | ('ʃy', 'ʃ'), 106 | ('ʧʰy', 'ʧʰ'), 107 | ('ʧ⁼y', 'ʧ⁼'), 108 | ('NN', 'n'), 109 | ('Ng', 'ŋ'), 110 | ('y', 'j'), 111 | ('h', 'x') 112 | ]] 113 | 114 | # List of (bopomofo, ipa) pairs: 115 | _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 116 | ('ㄅㄛ', 'p⁼wo'), 117 | ('ㄆㄛ', 'pʰwo'), 118 | ('ㄇㄛ', 'mwo'), 119 | ('ㄈㄛ', 'fwo'), 120 | ('ㄅ', 'p⁼'), 121 | ('ㄆ', 'pʰ'), 122 | ('ㄇ', 'm'), 123 | ('ㄈ', 'f'), 124 | ('ㄉ', 't⁼'), 125 | ('ㄊ', 'tʰ'), 126 | ('ㄋ', 'n'), 127 | ('ㄌ', 'l'), 128 | ('ㄍ', 'k⁼'), 129 | ('ㄎ', 'kʰ'), 130 | ('ㄏ', 'x'), 131 | ('ㄐ', 'tʃ⁼'), 132 | ('ㄑ', 'tʃʰ'), 133 | ('ㄒ', 'ʃ'), 134 | ('ㄓ', 'ts`⁼'), 135 | ('ㄔ', 'ts`ʰ'), 136 | ('ㄕ', 's`'), 137 | ('ㄖ', 'ɹ`'), 138 | ('ㄗ', 'ts⁼'), 139 | ('ㄘ', 'tsʰ'), 140 | ('ㄙ', 's'), 141 | ('ㄚ', 'a'), 142 | ('ㄛ', 'o'), 143 | ('ㄜ', 'ə'), 144 | ('ㄝ', 'ɛ'), 145 | ('ㄞ', 'aɪ'), 146 | ('ㄟ', 'eɪ'), 147 | ('ㄠ', 'ɑʊ'), 148 | ('ㄡ', 'oʊ'), 149 | ('ㄧㄢ', 'jɛn'), 150 | ('ㄩㄢ', 'ɥæn'), 151 | ('ㄢ', 'an'), 152 | ('ㄧㄣ', 'in'), 153 | ('ㄩㄣ', 'ɥn'), 154 | ('ㄣ', 'ən'), 155 | ('ㄤ', 'ɑŋ'), 156 | ('ㄧㄥ', 'iŋ'), 157 | ('ㄨㄥ', 'ʊŋ'), 158 | ('ㄩㄥ', 'jʊŋ'), 159 | ('ㄥ', 'əŋ'), 160 | ('ㄦ', 'əɻ'), 161 | ('ㄧ', 'i'), 162 | ('ㄨ', 'u'), 163 | ('ㄩ', 'ɥ'), 164 | ('ˉ', '→'), 165 | ('ˊ', '↑'), 166 | ('ˇ', '↓↑'), 167 | ('ˋ', '↓'), 168 | ('˙', ''), 169 | (',', ','), 170 | ('。', '.'), 171 | ('!', '!'), 172 | ('?', '?'), 173 | ('—', '-') 174 | ]] 175 | 176 | # List of (bopomofo, ipa2) pairs: 177 | _bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in 
[ 178 | ('ㄅㄛ', 'pwo'), 179 | ('ㄆㄛ', 'pʰwo'), 180 | ('ㄇㄛ', 'mwo'), 181 | ('ㄈㄛ', 'fwo'), 182 | ('ㄅ', 'p'), 183 | ('ㄆ', 'pʰ'), 184 | ('ㄇ', 'm'), 185 | ('ㄈ', 'f'), 186 | ('ㄉ', 't'), 187 | ('ㄊ', 'tʰ'), 188 | ('ㄋ', 'n'), 189 | ('ㄌ', 'l'), 190 | ('ㄍ', 'k'), 191 | ('ㄎ', 'kʰ'), 192 | ('ㄏ', 'h'), 193 | ('ㄐ', 'tɕ'), 194 | ('ㄑ', 'tɕʰ'), 195 | ('ㄒ', 'ɕ'), 196 | ('ㄓ', 'tʂ'), 197 | ('ㄔ', 'tʂʰ'), 198 | ('ㄕ', 'ʂ'), 199 | ('ㄖ', 'ɻ'), 200 | ('ㄗ', 'ts'), 201 | ('ㄘ', 'tsʰ'), 202 | ('ㄙ', 's'), 203 | ('ㄚ', 'a'), 204 | ('ㄛ', 'o'), 205 | ('ㄜ', 'ɤ'), 206 | ('ㄝ', 'ɛ'), 207 | ('ㄞ', 'aɪ'), 208 | ('ㄟ', 'eɪ'), 209 | ('ㄠ', 'ɑʊ'), 210 | ('ㄡ', 'oʊ'), 211 | ('ㄧㄢ', 'jɛn'), 212 | ('ㄩㄢ', 'yæn'), 213 | ('ㄢ', 'an'), 214 | ('ㄧㄣ', 'in'), 215 | ('ㄩㄣ', 'yn'), 216 | ('ㄣ', 'ən'), 217 | ('ㄤ', 'ɑŋ'), 218 | ('ㄧㄥ', 'iŋ'), 219 | ('ㄨㄥ', 'ʊŋ'), 220 | ('ㄩㄥ', 'jʊŋ'), 221 | ('ㄥ', 'ɤŋ'), 222 | ('ㄦ', 'əɻ'), 223 | ('ㄧ', 'i'), 224 | ('ㄨ', 'u'), 225 | ('ㄩ', 'y'), 226 | ('ˉ', '˥'), 227 | ('ˊ', '˧˥'), 228 | ('ˇ', '˨˩˦'), 229 | ('ˋ', '˥˩'), 230 | ('˙', ''), 231 | (',', ','), 232 | ('。', '.'), 233 | ('!', '!'), 234 | ('?', '?'), 235 | ('—', '-') 236 | ]] 237 | 238 | 239 | def number_to_chinese(text): 240 | numbers = re.findall(r'\d+(?:\.?\d+)?', text) 241 | for number in numbers: 242 | text = text.replace(number, cn2an.an2cn(number), 1) 243 | return text 244 | 245 | 246 | def chinese_to_bopomofo(text): 247 | text = text.replace('、', ',').replace(';', ',').replace(':', ',') 248 | words = jieba.lcut(text, cut_all=False) 249 | text = '' 250 | for word in words: 251 | bopomofos = lazy_pinyin(word, BOPOMOFO) 252 | if not re.search('[\u4e00-\u9fff]', word): 253 | text += word 254 | continue 255 | for i in range(len(bopomofos)): 256 | bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i]) 257 | if text != '': 258 | text += ' ' 259 | text += ''.join(bopomofos) 260 | return text 261 | 262 | 263 | def latin_to_bopomofo(text): 264 | for regex, replacement in _latin_to_bopomofo: 265 | text = re.sub(regex, replacement, text) 266 | return text 267 | 268 | 269 | def bopomofo_to_romaji(text): 270 | for regex, replacement in _bopomofo_to_romaji: 271 | text = re.sub(regex, replacement, text) 272 | return text 273 | 274 | 275 | def bopomofo_to_ipa(text): 276 | for regex, replacement in _bopomofo_to_ipa: 277 | text = re.sub(regex, replacement, text) 278 | return text 279 | 280 | 281 | def bopomofo_to_ipa2(text): 282 | for regex, replacement in _bopomofo_to_ipa2: 283 | text = re.sub(regex, replacement, text) 284 | return text 285 | 286 | 287 | def chinese_to_romaji(text): 288 | text = number_to_chinese(text) 289 | text = chinese_to_bopomofo(text) 290 | text = latin_to_bopomofo(text) 291 | text = bopomofo_to_romaji(text) 292 | text = re.sub('i([aoe])', r'y\1', text) 293 | text = re.sub('u([aoəe])', r'w\1', text) 294 | text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', 295 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') 296 | text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) 297 | return text 298 | 299 | 300 | def chinese_to_lazy_ipa(text): 301 | text = chinese_to_romaji(text) 302 | for regex, replacement in _romaji_to_ipa: 303 | text = re.sub(regex, replacement, text) 304 | return text 305 | 306 | 307 | def chinese_to_ipa(text): 308 | text = number_to_chinese(text) 309 | text = chinese_to_bopomofo(text) 310 | text = latin_to_bopomofo(text) 311 | text = bopomofo_to_ipa(text) 312 | text = re.sub('i([aoe])', r'j\1', text) 313 | text = re.sub('u([aoəe])', r'w\1', text) 314 | text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', 315 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') 316 | text = 
re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) 317 | return text 318 | 319 | 320 | def chinese_to_ipa2(text): 321 | text = number_to_chinese(text) 322 | text = chinese_to_bopomofo(text) 323 | text = latin_to_bopomofo(text) 324 | text = bopomofo_to_ipa2(text) 325 | text = re.sub(r'i([aoe])', r'j\1', text) 326 | text = re.sub(r'u([aoəe])', r'w\1', text) 327 | text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text) 328 | text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text) 329 | return text 330 | -------------------------------------------------------------------------------- /vits/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import numpy as np 5 | 6 | DEFAULT_MIN_BIN_WIDTH = 1e-3 7 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 8 | DEFAULT_MIN_DERIVATIVE = 1e-3 9 | 10 | 11 | def piecewise_rational_quadratic_transform(inputs, 12 | unnormalized_widths, 13 | unnormalized_heights, 14 | unnormalized_derivatives, 15 | inverse=False, 16 | tails=None, 17 | tail_bound=1., 18 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 19 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 20 | min_derivative=DEFAULT_MIN_DERIVATIVE): 21 | if tails is None: 22 | spline_fn = rational_quadratic_spline 23 | spline_kwargs = {} 24 | else: 25 | spline_fn = unconstrained_rational_quadratic_spline 26 | spline_kwargs = { 27 | 'tails': tails, 28 | 'tail_bound': tail_bound 29 | } 30 | 31 | outputs, logabsdet = spline_fn( 32 | inputs=inputs, 33 | unnormalized_widths=unnormalized_widths, 34 | unnormalized_heights=unnormalized_heights, 35 | unnormalized_derivatives=unnormalized_derivatives, 36 | inverse=inverse, 37 | min_bin_width=min_bin_width, 38 | min_bin_height=min_bin_height, 39 | min_derivative=min_derivative, 40 | **spline_kwargs 41 | ) 42 | return outputs, logabsdet 43 | 44 | 45 | def searchsorted(bin_locations, inputs, eps=1e-6): 46 | bin_locations[..., -1] += eps 47 | return torch.sum( 48 | inputs[..., None] >= bin_locations, 49 | dim=-1 50 | ) - 1 51 | 52 | 53 | def unconstrained_rational_quadratic_spline(inputs, 54 | unnormalized_widths, 55 | unnormalized_heights, 56 | unnormalized_derivatives, 57 | inverse=False, 58 | tails='linear', 59 | tail_bound=1., 60 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 61 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 62 | min_derivative=DEFAULT_MIN_DERIVATIVE): 63 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 64 | outside_interval_mask = ~inside_interval_mask 65 | 66 | outputs = torch.zeros_like(inputs) 67 | logabsdet = torch.zeros_like(inputs) 68 | 69 | if tails == 'linear': 70 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 71 | constant = np.log(np.exp(1 - min_derivative) - 1) 72 | unnormalized_derivatives[..., 0] = constant 73 | unnormalized_derivatives[..., -1] = constant 74 | 75 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 76 | logabsdet[outside_interval_mask] = 0 77 | else: 78 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 79 | 80 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 81 | inputs=inputs[inside_interval_mask], 82 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 83 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 84 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 85 | inverse=inverse, 86 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 87 | min_bin_width=min_bin_width, 
88 | min_bin_height=min_bin_height, 89 | min_derivative=min_derivative 90 | ) 91 | 92 | return outputs, logabsdet 93 | 94 | 95 | def rational_quadratic_spline(inputs, 96 | unnormalized_widths, 97 | unnormalized_heights, 98 | unnormalized_derivatives, 99 | inverse=False, 100 | left=0., right=1., bottom=0., top=1., 101 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 102 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 103 | min_derivative=DEFAULT_MIN_DERIVATIVE): 104 | if torch.min(inputs) < left or torch.max(inputs) > right: 105 | raise ValueError('Input to a transform is not within its domain') 106 | 107 | num_bins = unnormalized_widths.shape[-1] 108 | 109 | if min_bin_width * num_bins > 1.0: 110 | raise ValueError('Minimal bin width too large for the number of bins') 111 | if min_bin_height * num_bins > 1.0: 112 | raise ValueError('Minimal bin height too large for the number of bins') 113 | 114 | widths = F.softmax(unnormalized_widths, dim=-1) 115 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 116 | cumwidths = torch.cumsum(widths, dim=-1) 117 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 118 | cumwidths = (right - left) * cumwidths + left 119 | cumwidths[..., 0] = left 120 | cumwidths[..., -1] = right 121 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 122 | 123 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 124 | 125 | heights = F.softmax(unnormalized_heights, dim=-1) 126 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 127 | cumheights = torch.cumsum(heights, dim=-1) 128 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 129 | cumheights = (top - bottom) * cumheights + bottom 130 | cumheights[..., 0] = bottom 131 | cumheights[..., -1] = top 132 | heights = cumheights[..., 1:] - cumheights[..., :-1] 133 | 134 | if inverse: 135 | bin_idx = searchsorted(cumheights, inputs)[..., None] 136 | else: 137 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 138 | 139 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 140 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 141 | 142 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 143 | delta = heights / widths 144 | input_delta = delta.gather(-1, bin_idx)[..., 0] 145 | 146 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 147 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 148 | 149 | input_heights = heights.gather(-1, bin_idx)[..., 0] 150 | 151 | if inverse: 152 | a = (((inputs - input_cumheights) * (input_derivatives 153 | + input_derivatives_plus_one 154 | - 2 * input_delta) 155 | + input_heights * (input_delta - input_derivatives))) 156 | b = (input_heights * input_derivatives 157 | - (inputs - input_cumheights) * (input_derivatives 158 | + input_derivatives_plus_one 159 | - 2 * input_delta)) 160 | c = - input_delta * (inputs - input_cumheights) 161 | 162 | discriminant = b.pow(2) - 4 * a * c 163 | assert (discriminant >= 0).all() 164 | 165 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 166 | outputs = root * input_bin_widths + input_cumwidths 167 | 168 | theta_one_minus_theta = root * (1 - root) 169 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 170 | * theta_one_minus_theta) 171 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 172 | + 2 * input_delta * theta_one_minus_theta 173 | + input_derivatives * (1 - root).pow(2)) 174 | logabsdet = torch.log(derivative_numerator) - 2 * 
torch.log(denominator) 175 | 176 | return outputs, -logabsdet 177 | else: 178 | theta = (inputs - input_cumwidths) / input_bin_widths 179 | theta_one_minus_theta = theta * (1 - theta) 180 | 181 | numerator = input_heights * (input_delta * theta.pow(2) 182 | + input_derivatives * theta_one_minus_theta) 183 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 184 | * theta_one_minus_theta) 185 | outputs = input_cumheights + numerator / denominator 186 | 187 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 188 | + 2 * input_delta * theta_one_minus_theta 189 | + input_derivatives * (1 - theta).pow(2)) 190 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 191 | 192 | return outputs, logabsdet 193 | -------------------------------------------------------------------------------- /modules/sovits_model.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import time 4 | from pathlib import Path 5 | from typing import Dict 6 | 7 | import librosa 8 | import numpy as np 9 | import parselmouth 10 | import soundfile 11 | import torch 12 | import torchaudio 13 | 14 | import modules.utils as utils 15 | from modules.devices import device 16 | from modules.model import ModelInfo 17 | from modules.model import refresh_model_list 18 | from modules.options import cmd_opts 19 | from modules.utils import model_hash 20 | from repositories.sovits.hubert import hubert_model 21 | from repositories.sovits.models import SynthesizerTrn 22 | 23 | MODEL_PATH = os.path.join(os.path.join(os.getcwd(), "models"), "sovits") 24 | 25 | 26 | class Svc(object): 27 | def __init__(self, net_g_path, config_path, hubert_path="models/hubert/hubert-soft-0d54a1f4.pt"): 28 | self.net_g_path = net_g_path 29 | self.hubert_path = hubert_path 30 | self.dev = device 31 | self.net_g_ms = None 32 | self.hps_ms = utils.get_hparams_from_file(config_path) 33 | self.target_sample = self.hps_ms.data.sampling_rate 34 | self.hop_size = self.hps_ms.data.hop_length 35 | self.speakers = {} 36 | for spk, sid in self.hps_ms.spk.items(): 37 | self.speakers[sid] = spk 38 | self.spk2id = self.hps_ms.spk 39 | 40 | def load_model(self): 41 | self.hubert_soft = hubert_model.hubert_soft(self.hubert_path) 42 | if self.dev != torch.device("cpu"): 43 | self.hubert_soft = self.hubert_soft.cuda() 44 | self.net_g_ms = SynthesizerTrn( 45 | self.hps_ms.data.filter_length // 2 + 1, 46 | self.hps_ms.train.segment_size // self.hps_ms.data.hop_length, 47 | **self.hps_ms.model) 48 | _ = load_checkpoint(self.net_g_path, self.net_g_ms, None) 49 | if "half" in self.net_g_path and self.dev != torch.device("cpu"): 50 | _ = self.net_g_ms.half().eval().to(self.dev) 51 | else: 52 | _ = self.net_g_ms.eval().to(self.dev) 53 | 54 | def get_units(self, source, sr): 55 | source = source.unsqueeze(0).to(self.dev) 56 | with torch.inference_mode(): 57 | start = time.time() 58 | units = self.hubert_soft.units(source) 59 | use_time = time.time() - start 60 | if cmd_opts.debug: 61 | print("hubert use time: {}".format(use_time)) 62 | return units 63 | 64 | def get_unit_pitch(self, in_path, tran): 65 | source, sr = torchaudio.load(in_path) 66 | source = torchaudio.functional.resample(source, sr, 16000) 67 | if len(source.shape) == 2 and source.shape[1] >= 2: 68 | source = torch.mean(source, dim=0).unsqueeze(0) 69 | soft = self.get_units(source, sr).squeeze(0).cpu().numpy() 70 | f0_coarse, f0 = 
get_f0(source.cpu().numpy()[0], soft.shape[0] * 2, tran) 71 | return soft, f0 72 | 73 | def infer(self, speaker_id, tran, raw_path): 74 | if type(speaker_id) == str: 75 | speaker_id = self.spk2id[speaker_id] 76 | sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0) 77 | soft, pitch = self.get_unit_pitch(raw_path, tran) 78 | f0 = torch.FloatTensor(clean_pitch(pitch)).unsqueeze(0).to(self.dev) 79 | if "half" in self.net_g_path and torch.cuda.is_available(): 80 | stn_tst = torch.HalfTensor(soft) 81 | else: 82 | stn_tst = torch.FloatTensor(soft) 83 | with torch.no_grad(): 84 | x_tst = stn_tst.unsqueeze(0).to(self.dev) 85 | start = time.time() 86 | x_tst = torch.repeat_interleave(x_tst, repeats=2, dim=1).transpose(1, 2) 87 | audio = self.net_g_ms.infer(x_tst, f0=f0, g=sid)[0, 0].data.float() 88 | use_time = time.time() - start 89 | if cmd_opts.debug: 90 | print("vits use time: {}".format(use_time)) 91 | return audio, audio.shape[-1] 92 | 93 | 94 | class SovitsModel(Svc): 95 | def __init__(self, info: ModelInfo): 96 | super(SovitsModel, self).__init__(info.checkpoint_path, info.config_path) 97 | self.model_name = info.model_name 98 | 99 | 100 | sovits_model_list: Dict[str, ModelInfo] = {} 101 | curr_model: SovitsModel = None 102 | 103 | 104 | def get_model() -> Svc: 105 | return curr_model 106 | 107 | 108 | def get_model_name(): 109 | return curr_model.model_name if curr_model is not None else None 110 | 111 | 112 | def get_model_list(): 113 | return [k for k, _ in sovits_model_list.items()] 114 | 115 | 116 | def get_speakers(): 117 | if curr_model is None: 118 | return ["None"] 119 | return [spk for sid, spk in curr_model.speakers.items() if spk != "None"] 120 | 121 | 122 | def refresh_list(): 123 | sovits_model_list.clear() 124 | model_list = refresh_model_list(model_path=MODEL_PATH) 125 | for m in model_list: 126 | d = m["dir"] 127 | p = os.path.join(MODEL_PATH, m["dir"]) 128 | pth_path = m["pth"] 129 | config_path = m["config"] 130 | sovits_model_list[d] = ModelInfo( 131 | model_name=d, 132 | model_folder=p, 133 | model_hash=model_hash(pth_path), 134 | checkpoint_path=pth_path, 135 | config_path=config_path 136 | ) 137 | if len(sovits_model_list.items()) == 0: 138 | print("No so-vits model found. 
Please put a model in models/sovits") 139 | 140 | 141 | # def init_model(): 142 | # global curr_sovits_model 143 | # info = next(iter(sovits_model_list.values())) 144 | # # load_model(info.model_name) 145 | # curr_sovits_model = SovitsModel(info) 146 | 147 | 148 | def load_model(model_name: str): 149 | global curr_model, sovits_model_list 150 | if curr_model and curr_model.model_name == model_name: 151 | return 152 | info = sovits_model_list[model_name] 153 | print(f"Loading so-vits weights [{info.model_hash}] from {info.checkpoint_path}...") 154 | m = SovitsModel(info) 155 | m.load_model() 156 | curr_model = m 157 | print("so-vits model loaded.") 158 | 159 | 160 | def load_checkpoint(checkpoint_path, model, optimizer=None): 161 | assert os.path.isfile(checkpoint_path) 162 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 163 | iteration = checkpoint_dict['iteration'] 164 | learning_rate = checkpoint_dict['learning_rate'] 165 | if iteration is None: 166 | iteration = 1 167 | if learning_rate is None: 168 | learning_rate = 0.0002 169 | if optimizer is not None and checkpoint_dict['optimizer'] is not None: 170 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 171 | saved_state_dict = checkpoint_dict['model'] 172 | if hasattr(model, 'module'): 173 | state_dict = model.module.state_dict() 174 | else: 175 | state_dict = model.state_dict() 176 | new_state_dict = {} 177 | for k, v in state_dict.items(): 178 | try: 179 | new_state_dict[k] = saved_state_dict[k] 180 | except: 181 | # logger.info("%s is not in the checkpoint" % k) 182 | new_state_dict[k] = v 183 | if hasattr(model, 'module'): 184 | model.module.load_state_dict(new_state_dict) 185 | else: 186 | model.load_state_dict(new_state_dict) 187 | print(f"Loaded checkpoint '{checkpoint_path}' (iteration {iteration})") 188 | return model, optimizer, learning_rate, iteration 189 | 190 | 191 | def format_wav(audio_path): 192 | if Path(audio_path).suffix == '.wav': 193 | return 194 | raw_audio, raw_sample_rate = librosa.load(audio_path, mono=True, sr=None) 195 | soundfile.write(Path(audio_path).with_suffix(".wav"), raw_audio, raw_sample_rate) 196 | 197 | 198 | def get_end_file(dir_path, end): 199 | file_lists = [] 200 | for root, dirs, files in os.walk(dir_path): 201 | files = [f for f in files if f[0] != '.'] 202 | dirs[:] = [d for d in dirs if d[0] != '.'] 203 | for f_file in files: 204 | if f_file.endswith(end): 205 | file_lists.append(os.path.join(root, f_file).replace("\\", "/")) 206 | return file_lists 207 | 208 | 209 | def get_md5(content): 210 | return hashlib.new("md5", content).hexdigest() 211 | 212 | 213 | def resize2d_f0(x, target_len): 214 | source = np.array(x) 215 | source[source < 0.001] = np.nan 216 | target = np.interp(np.arange(0, len(source) * target_len, len(source)) / target_len, np.arange(0, len(source)), 217 | source) 218 | res = np.nan_to_num(target) 219 | return res 220 | 221 | 222 | def get_f0(x, p_len, f0_up_key=0): 223 | time_step = 160 / 16000 * 1000 224 | f0_min = 50 225 | f0_max = 1100 226 | f0_mel_min = 1127 * np.log(1 + f0_min / 700) 227 | f0_mel_max = 1127 * np.log(1 + f0_max / 700) 228 | 229 | f0 = parselmouth.Sound(x, 16000).to_pitch_ac( 230 | time_step=time_step / 1000, voicing_threshold=0.6, 231 | pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency'] 232 | if len(f0) > p_len: 233 | f0 = f0[:p_len] 234 | pad_size = (p_len - len(f0) + 1) // 2 235 | if (pad_size > 0 or p_len - len(f0) - pad_size > 0): 236 | f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], 
mode='constant') 237 | 238 | f0 *= pow(2, f0_up_key / 12) 239 | f0_mel = 1127 * np.log(1 + f0 / 700) 240 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1 241 | f0_mel[f0_mel <= 1] = 1 242 | f0_mel[f0_mel > 255] = 255 243 | f0_coarse = np.rint(f0_mel).astype(np.int) 244 | return f0_coarse, f0 245 | 246 | 247 | def clean_pitch(input_pitch): 248 | num_nan = np.sum(input_pitch == 1) 249 | if num_nan / len(input_pitch) > 0.9: 250 | input_pitch[input_pitch != 1] = 1 251 | return input_pitch 252 | 253 | 254 | def plt_pitch(input_pitch): 255 | input_pitch = input_pitch.astype(float) 256 | input_pitch[input_pitch == 1] = np.nan 257 | return input_pitch 258 | 259 | 260 | def f0_to_pitch(ff): 261 | f0_pitch = 69 + 12 * np.log2(ff / 440) 262 | return f0_pitch 263 | 264 | 265 | def fill_a_to_b(a, b): 266 | if len(a) < len(b): 267 | for _ in range(0, len(b) - len(a)): 268 | a.append(a[0]) 269 | -------------------------------------------------------------------------------- /modules/ui.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import gradio as gr 4 | 5 | import modules.sovits_model as sovits_model 6 | import modules.vits_model as vits_model 7 | from modules.localization import gen_localization_js 8 | from modules.options import opts 9 | from modules.process import text2speech, sovits_process 10 | from modules.utils import open_folder 11 | from modules.vits_model import get_model_list, refresh_list 12 | 13 | refresh_symbol = "\U0001f504" # 🔄 14 | folder_symbol = '\U0001f4c2' # 📂 15 | 16 | _gradio_template_response_orig = gr.routes.templates.TemplateResponse 17 | script_path = "scripts" 18 | 19 | 20 | class ToolButton(gr.Button, gr.components.FormComponent): 21 | def __init__(self, **kwargs): 22 | super().__init__(variant="tool", **kwargs) 23 | 24 | def get_block_name(self): 25 | return "button" 26 | 27 | 28 | def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): 29 | def refresh(): 30 | refresh_method() 31 | args = refreshed_args() if callable(refreshed_args) else refreshed_args 32 | 33 | for k, v in args.items(): 34 | setattr(refresh_component, k, v) 35 | 36 | return gr.update(**(args or {})) 37 | 38 | refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) 39 | refresh_button.click( 40 | fn=refresh, 41 | inputs=[], 42 | outputs=[refresh_component] 43 | ) 44 | return refresh_button 45 | 46 | 47 | def create_setting_component(key): 48 | def fun(): 49 | return opts.data[key] if key in opts.data else opts.data_labels[key].default 50 | 51 | info = opts.data_labels[key] 52 | t = type(info.default) 53 | 54 | args = info.component_args() if callable(info.component_args) else info.component_args 55 | 56 | if info.component is not None: 57 | comp = info.component 58 | elif t == str: 59 | comp = gr.Textbox 60 | elif t == int: 61 | comp = gr.Number 62 | elif t == bool: 63 | comp = gr.Checkbox 64 | else: 65 | raise Exception(f'bad options item type: {str(t)} for key {key}') 66 | 67 | elem_id = "setting_" + key 68 | 69 | if info.refresh is not None: 70 | with gr.Row(): 71 | res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) 72 | create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) 73 | else: 74 | res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) 75 | 76 | return res 77 | 78 | 79 | def change_vits_model(model_name): 80 | vits_model.load_model(model_name) 81 | speakers = 
vits_model.get_speakers() 82 | return gr.update(choices=speakers, value=speakers[0]) 83 | 84 | 85 | def change_sovits_model(model_name): 86 | sovits_model.load_model(model_name) 87 | speakers = sovits_model.get_speakers() 88 | return gr.update(choices=speakers, value=speakers[0]) 89 | 90 | 91 | def create_ui(): 92 | css = "style.css" 93 | component_dict = {} 94 | reload_javascript() 95 | 96 | vits_model_list = vits_model.get_model_list() 97 | vits_speakers = vits_model.get_speakers() 98 | 99 | sovits_model_list = sovits_model.get_model_list() 100 | sovits_speakers = sovits_model.get_speakers() 101 | 102 | with gr.Blocks(analytics_enabled=False) as txt2img_interface: 103 | with gr.Row(elem_id="toprow"): 104 | with gr.Column(scale=6): 105 | with gr.Row(): 106 | with gr.Column(scale=80): 107 | with gr.Row(): 108 | input_text = gr.Textbox(label="Text", show_label=False, lines=4, 109 | placeholder="Text (press Ctrl+Enter or Alt+Enter to generate)") 110 | with gr.Column(scale=1): 111 | with gr.Row(): 112 | vits_submit_btn = gr.Button("Generate", elem_id=f"vits_generate", variant="primary") 113 | 114 | with gr.Row().style(equal_height=False): 115 | with gr.Column(variant="panel", elem_id="vits_settings"): 116 | with gr.Row(): 117 | vits_model_picker = gr.Dropdown(label="VITS Checkpoint", choices=vits_model_list, 118 | value=vits_model.get_model_name()) 119 | create_refresh_button(vits_model_picker, refresh_method=vits_model.refresh_list(), 120 | refreshed_args=lambda: {"choices": vits_model.get_model_list()}, 121 | elem_id="vits_model_refresh") 122 | with gr.Row(): 123 | process_method = gr.Radio(label="Process Method", 124 | choices=["Simple", "Batch Process", "Multi Speakers"], 125 | value="Simple") 126 | 127 | with gr.Row(): 128 | speaker_index = gr.Dropdown(label="Speakers", 129 | choices=vits_speakers, value=vits_speakers[0]) 130 | 131 | speed = gr.Slider(value=1, minimum=0.5, maximum=2, step=0.1, 132 | elem_id=f"vits_speed", 133 | label="Speed") 134 | 135 | with gr.Column(variant="panel", elem_id="vits_output"): 136 | tts_output1 = gr.Textbox(label="Output Message") 137 | tts_output2 = gr.Audio(label="Output Audio", elem_id=f"vits_audio") 138 | 139 | with gr.Column(): 140 | with gr.Row(elem_id=f"functional_buttons"): 141 | open_folder_button = gr.Button(f"{folder_symbol} Open Folder", elem_id=f'open_vits_folder') 142 | save_button = gr.Button('Save', elem_id=f'save') 143 | 144 | open_folder_button.click(fn=lambda: open_folder("outputs/vits")) 145 | 146 | vits_model_picker.change( 147 | fn=change_vits_model, 148 | inputs=[vits_model_picker], 149 | outputs=[speaker_index] 150 | ) 151 | 152 | vits_submit_btn.click( 153 | fn=text2speech, 154 | inputs=[ 155 | input_text, 156 | speaker_index, 157 | speed, 158 | process_method 159 | ], 160 | outputs=[ 161 | tts_output1, 162 | tts_output2 163 | ] 164 | ) 165 | 166 | with gr.Blocks(analytics_enabled=False) as sovits_interface: 167 | with gr.Row(): 168 | with gr.Column(scale=6, elem_id="sovits_audio_panel"): 169 | sovits_audio_input = gr.File(label="Upload Audio File", elem_id=f"sovits_input_audio", file_count="multiple") 170 | with gr.Column(scale=1): 171 | with gr.Row(): 172 | sovits_submit_btn = gr.Button("Generate", elem_id=f"sovits_generate", variant="primary") 173 | 174 | with gr.Row().style(equal_height=False): 175 | with gr.Column(variant="panel", elem_id="sovits_settings"): 176 | with gr.Row(): 177 | sovits_model_picker = gr.Dropdown(label="SO-VITS Checkpoint", choices=sovits_model_list, 178 | value=sovits_model.get_model_name()) 179 | 
create_refresh_button(sovits_model_picker, refresh_method=sovits_model.refresh_list(), 180 | refreshed_args=lambda: {"choices": sovits_model.get_model_list()}, 181 | elem_id="sovits_model_refresh") 182 | 183 | with gr.Row(): 184 | sovits_speaker_index = gr.Dropdown(label="Speakers", 185 | choices=sovits_speakers, value=sovits_speakers[0]) 186 | 187 | with gr.Row(): 188 | vc_transform = gr.Slider(value=0, minimum=-20, maximum=20, step=1, 189 | elem_id=f"vc_transform", 190 | label="VC Transform") 191 | slice_db = gr.Slider(value=-40, minimum=-100, maximum=0, step=5, 192 | elem_id=f"slice_db", 193 | label="Slice db") 194 | with gr.Column(variant="panel", elem_id="sovits_output"): 195 | sovits_output1 = gr.Textbox(label="Output Message") 196 | sovits_output2 = gr.Audio(label="Output Audio", elem_id=f"sovits_output_audio") 197 | 198 | sovits_submit_btn.click( 199 | fn=sovits_process, 200 | inputs=[sovits_audio_input, sovits_speaker_index, vc_transform, slice_db], 201 | outputs=[sovits_output1, sovits_output2] 202 | ) 203 | 204 | sovits_model_picker.change( 205 | fn=change_sovits_model, 206 | inputs=[sovits_model_picker], 207 | outputs=[sovits_speaker_index] 208 | ) 209 | 210 | with gr.Blocks(analytics_enabled=False) as settings_interface: 211 | settings_component = [] 212 | 213 | def run_settings(*args): 214 | changed = [] 215 | 216 | for key, value, comp in zip(opts.data_labels.keys(), args, settings_component): 217 | assert opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" 218 | 219 | for key, value, comp in zip(opts.data_labels.keys(), args, settings_component): 220 | if opts.set(key, value): 221 | changed.append(key) 222 | 223 | try: 224 | opts.save() 225 | except RuntimeError: 226 | return f'{len(changed)} settings changed without save: {", ".join(changed)}.' 227 | return f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' 
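        # Note: run_settings makes two passes over the submitted values. The first pass only
        # validates that every value matches the type of its setting's default, so one bad entry
        # aborts before anything is written; the second pass applies the changes, and opts.save()
        # persists them, reporting "changed without save" if saving fails.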
228 | 229 | with gr.Row(): 230 | with gr.Column(scale=6): 231 | settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") 232 | with gr.Column(): 233 | restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") 234 | 235 | settings_result = gr.HTML(elem_id="settings_result") 236 | 237 | for i, (k, item) in enumerate(opts.data_labels.items()): 238 | component = create_setting_component(k) 239 | component_dict[k] = component 240 | settings_component.append(component) 241 | 242 | settings_submit.click( 243 | fn=run_settings, 244 | inputs=settings_component, 245 | outputs=[settings_result], 246 | ) 247 | 248 | interfaces = [ 249 | (txt2img_interface, "VITS", "vits"), 250 | (sovits_interface, "SO-VITS", "sovits"), 251 | (settings_interface, "Settings", "settings") 252 | ] 253 | 254 | with gr.Blocks(css=css, analytics_enabled=False, title="VITS") as demo: 255 | with gr.Tabs(elem_id="tabs") as tabs: 256 | for interface, label, ifid in interfaces: 257 | with gr.TabItem(label, id=ifid, elem_id="tab_" + ifid): 258 | interface.render() 259 | 260 | component_keys = [k for k in opts.data_labels.keys() if k in component_dict] 261 | 262 | # def get_settings_values(): 263 | # return [getattr(opts, key) for key in component_keys] 264 | # 265 | # demo.load( 266 | # fn=get_settings_values, 267 | # inputs=[], 268 | # outputs=[component_dict[k] for k in component_keys], 269 | # ) 270 | 271 | return demo 272 | 273 | 274 | def reload_javascript(): 275 | scripts_list = [os.path.join(script_path, i) for i in os.listdir(script_path) if i.endswith(".js")] 276 | with open("script.js", "r", encoding="utf8") as jsfile: 277 | javascript = f'' 278 | 279 | for path in scripts_list: 280 | with open(path, "r", encoding="utf8") as jsfile: 281 | javascript += f"\n" 282 | 283 | javascript += gen_localization_js(opts.localization) 284 | 285 | # todo: theme 286 | # if cmd_opts.theme is not None: 287 | # javascript += f"\n\n" 288 | 289 | def template_response(*args, **kwargs): 290 | res = _gradio_template_response_orig(*args, **kwargs) 291 | res.body = res.body.replace( 292 | b'', f'{javascript}'.encode("utf8")) 293 | res.init_headers() 294 | return res 295 | 296 | gr.routes.templates.TemplateResponse = template_response 297 | -------------------------------------------------------------------------------- /style.css: -------------------------------------------------------------------------------- 1 | .container { 2 | max-width: 100%; 3 | } 4 | 5 | [id$=_generate] { 6 | min-height: 6.5em; 7 | margin-top: 0.5em; 8 | margin-right: 0.25em; 9 | } 10 | 11 | /*#sovits_audio_panel {*/ 12 | /* min-height: 5.5em;*/ 13 | /*}*/ 14 | 15 | #sh { 16 | min-width: 2em; 17 | min-height: 2em; 18 | max-width: 2em; 19 | max-height: 2em; 20 | flex-grow: 0; 21 | padding-left: 0.25em; 22 | padding-right: 0.25em; 23 | margin: 0.1em 0; 24 | opacity: 0%; 25 | cursor: default; 26 | } 27 | 28 | .output-html p { 29 | margin: 0 0.5em; 30 | } 31 | 32 | .row > *, 33 | .row > .gr-form > * { 34 | min-width: min(120px, 100%); 35 | flex: 1 1 0%; 36 | } 37 | 38 | .performance { 39 | font-size: 0.85em; 40 | color: #444; 41 | } 42 | 43 | .performance p { 44 | display: inline-block; 45 | } 46 | 47 | .performance .time { 48 | margin-right: 0; 49 | } 50 | 51 | .performance .vram { 52 | } 53 | 54 | .justify-center.overflow-x-scroll { 55 | justify-content: left; 56 | } 57 | 58 | .justify-center.overflow-x-scroll button:first-of-type { 59 | margin-left: auto; 60 | } 61 | 62 | 
.justify-center.overflow-x-scroll button:last-of-type { 63 | margin-right: auto; 64 | } 65 | 66 | [id$=_random_seed], [id$=_random_subseed], [id$=_reuse_seed], [id$=_reuse_subseed], #open_folder { 67 | min-width: 2.3em; 68 | height: 2.5em; 69 | flex-grow: 0; 70 | padding-left: 0.25em; 71 | padding-right: 0.25em; 72 | } 73 | 74 | #hidden_element { 75 | display: none; 76 | } 77 | 78 | [id$=_seed_row], [id$=_subseed_row] { 79 | gap: 0.5rem; 80 | padding: 0.6em; 81 | } 82 | 83 | [id$=_subseed_show_box] { 84 | min-width: auto; 85 | flex-grow: 0; 86 | } 87 | 88 | [id$=_subseed_show_box] > div { 89 | border: 0; 90 | height: 100%; 91 | } 92 | 93 | [id$=_subseed_show] { 94 | min-width: auto; 95 | flex-grow: 0; 96 | padding: 0; 97 | } 98 | 99 | [id$=_subseed_show] label { 100 | height: 100%; 101 | } 102 | 103 | #roll_col { 104 | min-width: unset !important; 105 | flex-grow: 0 !important; 106 | padding: 0.4em 0; 107 | } 108 | 109 | #roll_col > button { 110 | min-width: 2em; 111 | min-height: 2em; 112 | max-width: 2em; 113 | max-height: 2em; 114 | flex-grow: 0; 115 | padding-left: 0.25em; 116 | padding-right: 0.25em; 117 | margin: 0.1em 0; 118 | } 119 | 120 | .gr-form { 121 | background: transparent; 122 | } 123 | 124 | .my-4 { 125 | margin-top: 0; 126 | margin-bottom: 0; 127 | } 128 | 129 | #toprow div { 130 | border: none; 131 | gap: 0; 132 | background: transparent; 133 | } 134 | 135 | #resize_mode { 136 | flex: 1.5; 137 | } 138 | 139 | button { 140 | align-self: stretch !important; 141 | } 142 | 143 | .overflow-hidden, .gr-panel { 144 | overflow: visible !important; 145 | } 146 | 147 | #x_type, #y_type { 148 | max-width: 10em; 149 | } 150 | 151 | 152 | fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span { 153 | position: absolute; 154 | top: -0.7em; 155 | line-height: 1.2em; 156 | padding: 0; 157 | margin: 0 0.5em; 158 | 159 | background-color: white; 160 | box-shadow: 6px 0 6px 0px white, -6px 0 6px 0px white; 161 | 162 | z-index: 300; 163 | } 164 | 165 | .dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span { 166 | background-color: rgb(31, 41, 55); 167 | box-shadow: 6px 0 6px 0px rgb(31, 41, 55), -6px 0 6px 0px rgb(31, 41, 55); 168 | } 169 | 170 | #txt2img_column_batch, #img2img_column_batch { 171 | min-width: min(13.5em, 100%) !important; 172 | } 173 | 174 | #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span { 175 | position: relative; 176 | border: none; 177 | margin-right: 8em; 178 | } 179 | 180 | #settings .gr-panel div.flex-col div.justify-between div { 181 | position: relative; 182 | z-index: 200; 183 | } 184 | 185 | #settings { 186 | display: block; 187 | } 188 | 189 | #settings > div { 190 | border: none; 191 | margin-left: 10em; 192 | } 193 | 194 | #settings > div.flex-wrap { 195 | float: left; 196 | display: block; 197 | margin-left: 0; 198 | width: 10em; 199 | } 200 | 201 | #settings > div.flex-wrap button { 202 | display: block; 203 | border: none; 204 | text-align: left; 205 | } 206 | 207 | #settings_result { 208 | height: 1.4em; 209 | margin: 0 1.2em; 210 | } 211 | 212 | input[type="range"] { 213 | margin: 0.5em 0 -0.3em 0; 214 | } 215 | 216 | #mask_bug_info { 217 | text-align: center; 218 | display: block; 219 | margin-top: -0.75em; 220 | margin-bottom: -0.75em; 221 | } 222 | 223 | 224 | .transition.opacity-20 { 225 | opacity: 1 !important; 226 | } 227 | 228 | 229 | .min-h-\[4rem\] { 230 | min-height: unset !important; 231 | } 232 | 233 | 
.progressDiv { 234 | width: 100%; 235 | height: 20px; 236 | background: #b4c0cc; 237 | border-radius: 8px; 238 | } 239 | 240 | .dark .progressDiv { 241 | background: #424c5b; 242 | } 243 | 244 | .progressDiv .progress { 245 | width: 0%; 246 | height: 20px; 247 | background: #0060df; 248 | color: white; 249 | font-weight: bold; 250 | line-height: 20px; 251 | padding: 0 8px 0 0; 252 | text-align: right; 253 | border-radius: 8px; 254 | } 255 | 256 | #lightboxModal { 257 | display: none; 258 | position: fixed; 259 | z-index: 1001; 260 | padding-top: 100px; 261 | left: 0; 262 | top: 0; 263 | width: 100%; 264 | height: 100%; 265 | overflow: auto; 266 | background-color: rgba(20, 20, 20, 0.95); 267 | user-select: none; 268 | -webkit-user-select: none; 269 | } 270 | 271 | .modalControls { 272 | display: grid; 273 | grid-template-columns: 32px 32px 32px 1fr 32px; 274 | grid-template-areas: "zoom tile save space close"; 275 | position: absolute; 276 | top: 0; 277 | left: 0; 278 | right: 0; 279 | padding: 16px; 280 | gap: 16px; 281 | background-color: rgba(0, 0, 0, 0.2); 282 | } 283 | 284 | .modalClose { 285 | grid-area: close; 286 | } 287 | 288 | .modalZoom { 289 | grid-area: zoom; 290 | } 291 | 292 | .modalSave { 293 | grid-area: save; 294 | } 295 | 296 | .modalTileImage { 297 | grid-area: tile; 298 | } 299 | 300 | .modalClose, 301 | .modalZoom, 302 | .modalTileImage { 303 | color: white; 304 | font-size: 35px; 305 | font-weight: bold; 306 | cursor: pointer; 307 | } 308 | 309 | .modalSave { 310 | color: white; 311 | font-size: 28px; 312 | margin-top: 8px; 313 | font-weight: bold; 314 | cursor: pointer; 315 | } 316 | 317 | .modalClose:hover, 318 | .modalClose:focus, 319 | .modalSave:hover, 320 | .modalSave:focus, 321 | .modalZoom:hover, 322 | .modalZoom:focus { 323 | color: #999; 324 | text-decoration: none; 325 | cursor: pointer; 326 | } 327 | 328 | #modalImage { 329 | display: block; 330 | margin-left: auto; 331 | margin-right: auto; 332 | margin-top: auto; 333 | width: auto; 334 | } 335 | 336 | .modalImageFullscreen { 337 | object-fit: contain; 338 | height: 90%; 339 | } 340 | 341 | .modalPrev, 342 | .modalNext { 343 | cursor: pointer; 344 | position: absolute; 345 | top: 50%; 346 | width: auto; 347 | padding: 16px; 348 | margin-top: -50px; 349 | color: white; 350 | font-weight: bold; 351 | font-size: 20px; 352 | transition: 0.6s ease; 353 | border-radius: 0 3px 3px 0; 354 | user-select: none; 355 | -webkit-user-select: none; 356 | } 357 | 358 | .modalNext { 359 | right: 0; 360 | border-radius: 3px 0 0 3px; 361 | } 362 | 363 | .modalPrev:hover, 364 | .modalNext:hover { 365 | background-color: rgba(0, 0, 0, 0.8); 366 | } 367 | 368 | #imageARPreview { 369 | position: absolute; 370 | top: 0px; 371 | left: 0px; 372 | border: 2px solid red; 373 | background: rgba(255, 0, 0, 0.3); 374 | z-index: 900; 375 | pointer-events: none; 376 | display: none 377 | } 378 | 379 | .red { 380 | color: red; 381 | } 382 | 383 | .gallery-item { 384 | --tw-bg-opacity: 0 !important; 385 | } 386 | 387 | #context-menu { 388 | z-index: 9999; 389 | position: absolute; 390 | display: block; 391 | padding: 0px 0; 392 | border: 2px solid #a55000; 393 | border-radius: 8px; 394 | box-shadow: 1px 1px 2px #CE6400; 395 | width: 200px; 396 | } 397 | 398 | .context-menu-items { 399 | list-style: none; 400 | margin: 0; 401 | padding: 0; 402 | } 403 | 404 | .context-menu-items a { 405 | display: block; 406 | padding: 5px; 407 | cursor: pointer; 408 | } 409 | 410 | .context-menu-items a:hover { 411 | background: #a55000; 412 | } 413 | 414 
| #quicksettings { 415 | gap: 0.4em; 416 | } 417 | 418 | #quicksettings > div { 419 | border: none; 420 | background: none; 421 | flex: unset; 422 | gap: 0.5em; 423 | } 424 | 425 | #quicksettings > div > div { 426 | max-width: 32em; 427 | min-width: 24em; 428 | padding: 0; 429 | } 430 | 431 | canvas[key="mask"] { 432 | z-index: 12 !important; 433 | filter: invert(); 434 | mix-blend-mode: multiply; 435 | pointer-events: none; 436 | } 437 | 438 | 439 | /* gradio 3.4.1 stuff for editable scrollbar values */ 440 | .gr-box > div > div > input.gr-text-input { 441 | position: absolute; 442 | right: 0.5em; 443 | top: -0.6em; 444 | z-index: 400; 445 | width: 8em; 446 | } 447 | 448 | #quicksettings .gr-box > div > div > input.gr-text-input { 449 | top: -1.12em; 450 | } 451 | 452 | .row.gr-compact { 453 | overflow: visible; 454 | } 455 | 456 | .gr-form { 457 | background-color: white; 458 | } 459 | 460 | .dark .gr-form { 461 | background-color: rgb(31 41 55 / var(--tw-bg-opacity)); 462 | } 463 | 464 | .gr-button-tool { 465 | max-width: 2.5em; 466 | min-width: 2.5em !important; 467 | height: 2.4em; 468 | margin: 0.55em 0; 469 | } 470 | 471 | #quicksettings .gr-button-tool { 472 | margin: 0; 473 | } 474 | 475 | 476 | #img2img_settings > div.gr-form, #txt2img_settings > div.gr-form { 477 | padding-top: 0.9em; 478 | } 479 | 480 | #img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form, #train_tabs div.gr-form .gr-form { 481 | border: none; 482 | padding-bottom: 0.5em; 483 | } 484 | 485 | footer { 486 | display: none !important; 487 | } 488 | 489 | #footer { 490 | text-align: center; 491 | } 492 | 493 | #footer div { 494 | display: inline-block; 495 | } 496 | 497 | #footer .versions { 498 | font-size: 85%; 499 | opacity: 0.85; 500 | } 501 | 502 | /* The following handles localization for right-to-left (RTL) languages like Arabic. 503 | The rtl media type will only be activated by the logic in javascript/localization.js. 504 | If you change anything above, you need to make sure it is RTL compliant by just running 505 | your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/. 
506 | Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/ 507 | @media rtl { 508 | /* this part was added manually */ 509 | :host { 510 | direction: rtl; 511 | } 512 | 513 | select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress { 514 | direction: ltr; 515 | } 516 | 517 | #script_list > label > select, 518 | #x_type > label > select, 519 | #y_type > label > select { 520 | direction: rtl; 521 | } 522 | 523 | .gr-radio, .gr-checkbox { 524 | margin-left: 0.25em; 525 | } 526 | 527 | /* automatically generated with few manual modifications */ 528 | .performance .time { 529 | margin-right: unset; 530 | margin-left: 0; 531 | } 532 | 533 | .justify-center.overflow-x-scroll { 534 | justify-content: right; 535 | } 536 | 537 | .justify-center.overflow-x-scroll button:first-of-type { 538 | margin-left: unset; 539 | margin-right: auto; 540 | } 541 | 542 | .justify-center.overflow-x-scroll button:last-of-type { 543 | margin-right: unset; 544 | margin-left: auto; 545 | } 546 | 547 | #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span { 548 | margin-right: unset; 549 | margin-left: 8em; 550 | } 551 | 552 | #txt2img_progressbar, #img2img_progressbar, #ti_progressbar { 553 | right: unset; 554 | left: 0; 555 | } 556 | 557 | .progressDiv .progress { 558 | padding: 0 0 0 8px; 559 | text-align: left; 560 | } 561 | 562 | #lightboxModal { 563 | left: unset; 564 | right: 0; 565 | } 566 | 567 | .modalPrev, .modalNext { 568 | border-radius: 3px 0 0 3px; 569 | } 570 | 571 | .modalNext { 572 | right: unset; 573 | left: 0; 574 | border-radius: 0 3px 3px 0; 575 | } 576 | 577 | #imageARPreview { 578 | left: unset; 579 | right: 0px; 580 | } 581 | 582 | #txt2img_skip, #img2img_skip { 583 | right: unset; 584 | left: 0px; 585 | } 586 | 587 | #context-menu { 588 | box-shadow: -1px 1px 2px #CE6400; 589 | } 590 | 591 | .gr-box > div > div > input.gr-text-input { 592 | right: unset; 593 | left: 0.5em; 594 | } 595 | } -------------------------------------------------------------------------------- /vits/attentions.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from . 
import commons 7 | from .modules import LayerNorm 8 | 9 | 10 | class Encoder(nn.Module): 11 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, 12 | **kwargs): 13 | super().__init__() 14 | self.hidden_channels = hidden_channels 15 | self.filter_channels = filter_channels 16 | self.n_heads = n_heads 17 | self.n_layers = n_layers 18 | self.kernel_size = kernel_size 19 | self.p_dropout = p_dropout 20 | self.window_size = window_size 21 | 22 | self.drop = nn.Dropout(p_dropout) 23 | self.attn_layers = nn.ModuleList() 24 | self.norm_layers_1 = nn.ModuleList() 25 | self.ffn_layers = nn.ModuleList() 26 | self.norm_layers_2 = nn.ModuleList() 27 | for i in range(self.n_layers): 28 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, 29 | window_size=window_size)) 30 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 31 | self.ffn_layers.append( 32 | FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) 33 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 34 | 35 | def forward(self, x, x_mask): 36 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 37 | x = x * x_mask 38 | for i in range(self.n_layers): 39 | y = self.attn_layers[i](x, x, attn_mask) 40 | y = self.drop(y) 41 | x = self.norm_layers_1[i](x + y) 42 | 43 | y = self.ffn_layers[i](x, x_mask) 44 | y = self.drop(y) 45 | x = self.norm_layers_2[i](x + y) 46 | x = x * x_mask 47 | return x 48 | 49 | 50 | class Decoder(nn.Module): 51 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., 52 | proximal_bias=False, proximal_init=True, **kwargs): 53 | super().__init__() 54 | self.hidden_channels = hidden_channels 55 | self.filter_channels = filter_channels 56 | self.n_heads = n_heads 57 | self.n_layers = n_layers 58 | self.kernel_size = kernel_size 59 | self.p_dropout = p_dropout 60 | self.proximal_bias = proximal_bias 61 | self.proximal_init = proximal_init 62 | 63 | self.drop = nn.Dropout(p_dropout) 64 | self.self_attn_layers = nn.ModuleList() 65 | self.norm_layers_0 = nn.ModuleList() 66 | self.encdec_attn_layers = nn.ModuleList() 67 | self.norm_layers_1 = nn.ModuleList() 68 | self.ffn_layers = nn.ModuleList() 69 | self.norm_layers_2 = nn.ModuleList() 70 | for i in range(self.n_layers): 71 | self.self_attn_layers.append( 72 | MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, 73 | proximal_bias=proximal_bias, proximal_init=proximal_init)) 74 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 75 | self.encdec_attn_layers.append( 76 | MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) 77 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 78 | self.ffn_layers.append( 79 | FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) 80 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 81 | 82 | def forward(self, x, x_mask, h, h_mask): 83 | """ 84 | x: decoder input 85 | h: encoder output 86 | """ 87 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) 88 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 89 | x = x * x_mask 90 | for i in range(self.n_layers): 91 | y = self.self_attn_layers[i](x, x, self_attn_mask) 92 | y = self.drop(y) 93 | x = self.norm_layers_0[i](x + y) 94 | 95 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 96 | y = self.drop(y) 97 
| x = self.norm_layers_1[i](x + y) 98 | 99 | y = self.ffn_layers[i](x, x_mask) 100 | y = self.drop(y) 101 | x = self.norm_layers_2[i](x + y) 102 | x = x * x_mask 103 | return x 104 | 105 | 106 | class MultiHeadAttention(nn.Module): 107 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, 108 | block_length=None, proximal_bias=False, proximal_init=False): 109 | super().__init__() 110 | assert channels % n_heads == 0 111 | 112 | self.channels = channels 113 | self.out_channels = out_channels 114 | self.n_heads = n_heads 115 | self.p_dropout = p_dropout 116 | self.window_size = window_size 117 | self.heads_share = heads_share 118 | self.block_length = block_length 119 | self.proximal_bias = proximal_bias 120 | self.proximal_init = proximal_init 121 | self.attn = None 122 | 123 | self.k_channels = channels // n_heads 124 | self.conv_q = nn.Conv1d(channels, channels, 1) 125 | self.conv_k = nn.Conv1d(channels, channels, 1) 126 | self.conv_v = nn.Conv1d(channels, channels, 1) 127 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 128 | self.drop = nn.Dropout(p_dropout) 129 | 130 | if window_size is not None: 131 | n_heads_rel = 1 if heads_share else n_heads 132 | rel_stddev = self.k_channels ** -0.5 133 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 134 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 135 | 136 | nn.init.xavier_uniform_(self.conv_q.weight) 137 | nn.init.xavier_uniform_(self.conv_k.weight) 138 | nn.init.xavier_uniform_(self.conv_v.weight) 139 | if proximal_init: 140 | with torch.no_grad(): 141 | self.conv_k.weight.copy_(self.conv_q.weight) 142 | self.conv_k.bias.copy_(self.conv_q.bias) 143 | 144 | def forward(self, x, c, attn_mask=None): 145 | q = self.conv_q(x) 146 | k = self.conv_k(c) 147 | v = self.conv_v(c) 148 | 149 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 150 | 151 | x = self.conv_o(x) 152 | return x 153 | 154 | def attention(self, query, key, value, mask=None): 155 | # reshape [b, d, t] -> [b, n_h, t, d_k] 156 | b, d, t_s, t_t = (*key.size(), query.size(2)) 157 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 158 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 159 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 160 | 161 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 162 | if self.window_size is not None: 163 | assert t_s == t_t, "Relative attention is only available for self-attention." 164 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 165 | rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings) 166 | scores_local = self._relative_position_to_absolute_position(rel_logits) 167 | scores = scores + scores_local 168 | if self.proximal_bias: 169 | assert t_s == t_t, "Proximal bias is only available for self-attention." 170 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) 171 | if mask is not None: 172 | scores = scores.masked_fill(mask == 0, -1e4) 173 | if self.block_length is not None: 174 | assert t_s == t_t, "Local attention is only available for self-attention." 
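                # The band mask below keeps, for each query position i, only key positions j with
                # |i - j| <= block_length: triu(-block_length) zeroes entries below the band and
                # tril(block_length) zeroes entries above it, and everything outside the band is
                # filled with -1e4 so it is effectively ignored by the softmax.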
175 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) 176 | scores = scores.masked_fill(block_mask == 0, -1e4) 177 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 178 | p_attn = self.drop(p_attn) 179 | output = torch.matmul(p_attn, value) 180 | if self.window_size is not None: 181 | relative_weights = self._absolute_position_to_relative_position(p_attn) 182 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 183 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) 184 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] 185 | return output, p_attn 186 | 187 | def _matmul_with_relative_values(self, x, y): 188 | """ 189 | x: [b, h, l, m] 190 | y: [h or 1, m, d] 191 | ret: [b, h, l, d] 192 | """ 193 | ret = torch.matmul(x, y.unsqueeze(0)) 194 | return ret 195 | 196 | def _matmul_with_relative_keys(self, x, y): 197 | """ 198 | x: [b, h, l, d] 199 | y: [h or 1, m, d] 200 | ret: [b, h, l, m] 201 | """ 202 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 203 | return ret 204 | 205 | def _get_relative_embeddings(self, relative_embeddings, length): 206 | max_relative_position = 2 * self.window_size + 1 207 | # Pad first before slice to avoid using cond ops. 208 | pad_length = max(length - (self.window_size + 1), 0) 209 | slice_start_position = max((self.window_size + 1) - length, 0) 210 | slice_end_position = slice_start_position + 2 * length - 1 211 | if pad_length > 0: 212 | padded_relative_embeddings = F.pad( 213 | relative_embeddings, 214 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) 215 | else: 216 | padded_relative_embeddings = relative_embeddings 217 | used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] 218 | return used_relative_embeddings 219 | 220 | def _relative_position_to_absolute_position(self, x): 221 | """ 222 | x: [b, h, l, 2*l-1] 223 | ret: [b, h, l, l] 224 | """ 225 | batch, heads, length, _ = x.size() 226 | # Concat columns of pad to shift from relative to absolute indexing. 227 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) 228 | 229 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 230 | x_flat = x.view([batch, heads, length * 2 * length]) 231 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])) 232 | 233 | # Reshape and slice out the padded elements. 234 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:] 235 | return x_final 236 | 237 | def _absolute_position_to_relative_position(self, x): 238 | """ 239 | x: [b, h, l, l] 240 | ret: [b, h, l, 2*l-1] 241 | """ 242 | batch, heads, length, _ = x.size() 243 | # padd along column 244 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])) 245 | x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)]) 246 | # add 0's in the beginning that will skew the elements after reshape 247 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 248 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] 249 | return x_final 250 | 251 | def _attention_bias_proximal(self, length): 252 | """Bias for self-attention to encourage attention to close positions. 253 | Args: 254 | length: an integer scalar. 
255 | Returns: 256 | a Tensor with shape [1, 1, length, length] 257 | """ 258 | r = torch.arange(length, dtype=torch.float32) 259 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 260 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 261 | 262 | 263 | class FFN(nn.Module): 264 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, 265 | causal=False): 266 | super().__init__() 267 | self.in_channels = in_channels 268 | self.out_channels = out_channels 269 | self.filter_channels = filter_channels 270 | self.kernel_size = kernel_size 271 | self.p_dropout = p_dropout 272 | self.activation = activation 273 | self.causal = causal 274 | 275 | if causal: 276 | self.padding = self._causal_padding 277 | else: 278 | self.padding = self._same_padding 279 | 280 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 281 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 282 | self.drop = nn.Dropout(p_dropout) 283 | 284 | def forward(self, x, x_mask): 285 | x = self.conv_1(self.padding(x * x_mask)) 286 | if self.activation == "gelu": 287 | x = x * torch.sigmoid(1.702 * x) 288 | else: 289 | x = torch.relu(x) 290 | x = self.drop(x) 291 | x = self.conv_2(self.padding(x * x_mask)) 292 | return x * x_mask 293 | 294 | def _causal_padding(self, x): 295 | if self.kernel_size == 1: 296 | return x 297 | pad_l = self.kernel_size - 1 298 | pad_r = 0 299 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 300 | x = F.pad(x, commons.convert_pad_shape(padding)) 301 | return x 302 | 303 | def _same_padding(self, x): 304 | if self.kernel_size == 1: 305 | return x 306 | pad_l = (self.kernel_size - 1) // 2 307 | pad_r = self.kernel_size // 2 308 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 309 | x = F.pad(x, commons.convert_pad_shape(padding)) 310 | return x 311 | -------------------------------------------------------------------------------- /vits/modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import scipy 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 10 | from torch.nn.utils import weight_norm, remove_weight_norm 11 | 12 | from . import commons 13 | from .commons import init_weights, get_padding 14 | from .transforms import piecewise_rational_quadratic_transform 15 | 16 | LRELU_SLOPE = 0.1 17 | 18 | 19 | class LayerNorm(nn.Module): 20 | def __init__(self, channels, eps=1e-5): 21 | super().__init__() 22 | self.channels = channels 23 | self.eps = eps 24 | 25 | self.gamma = nn.Parameter(torch.ones(channels)) 26 | self.beta = nn.Parameter(torch.zeros(channels)) 27 | 28 | def forward(self, x): 29 | x = x.transpose(1, -1) 30 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 31 | return x.transpose(1, -1) 32 | 33 | 34 | class ConvReluNorm(nn.Module): 35 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): 36 | super().__init__() 37 | self.in_channels = in_channels 38 | self.hidden_channels = hidden_channels 39 | self.out_channels = out_channels 40 | self.kernel_size = kernel_size 41 | self.n_layers = n_layers 42 | self.p_dropout = p_dropout 43 | assert n_layers > 1, "Number of layers should be larger than 0." 
44 | 45 | self.conv_layers = nn.ModuleList() 46 | self.norm_layers = nn.ModuleList() 47 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) 48 | self.norm_layers.append(LayerNorm(hidden_channels)) 49 | self.relu_drop = nn.Sequential( 50 | nn.ReLU(), 51 | nn.Dropout(p_dropout)) 52 | for _ in range(n_layers - 1): 53 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) 54 | self.norm_layers.append(LayerNorm(hidden_channels)) 55 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 56 | self.proj.weight.data.zero_() 57 | self.proj.bias.data.zero_() 58 | 59 | def forward(self, x, x_mask): 60 | x_org = x 61 | for i in range(self.n_layers): 62 | x = self.conv_layers[i](x * x_mask) 63 | x = self.norm_layers[i](x) 64 | x = self.relu_drop(x) 65 | x = x_org + self.proj(x) 66 | return x * x_mask 67 | 68 | 69 | class DDSConv(nn.Module): 70 | """ 71 | Dilated and Depth-Separable Convolution 72 | """ 73 | 74 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): 75 | super().__init__() 76 | self.channels = channels 77 | self.kernel_size = kernel_size 78 | self.n_layers = n_layers 79 | self.p_dropout = p_dropout 80 | 81 | self.drop = nn.Dropout(p_dropout) 82 | self.convs_sep = nn.ModuleList() 83 | self.convs_1x1 = nn.ModuleList() 84 | self.norms_1 = nn.ModuleList() 85 | self.norms_2 = nn.ModuleList() 86 | for i in range(n_layers): 87 | dilation = kernel_size ** i 88 | padding = (kernel_size * dilation - dilation) // 2 89 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, 90 | groups=channels, dilation=dilation, padding=padding 91 | )) 92 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 93 | self.norms_1.append(LayerNorm(channels)) 94 | self.norms_2.append(LayerNorm(channels)) 95 | 96 | def forward(self, x, x_mask, g=None): 97 | if g is not None: 98 | x = x + g 99 | for i in range(self.n_layers): 100 | y = self.convs_sep[i](x * x_mask) 101 | y = self.norms_1[i](y) 102 | y = F.gelu(y) 103 | y = self.convs_1x1[i](y) 104 | y = self.norms_2[i](y) 105 | y = F.gelu(y) 106 | y = self.drop(y) 107 | x = x + y 108 | return x * x_mask 109 | 110 | 111 | class WN(torch.nn.Module): 112 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 113 | super(WN, self).__init__() 114 | assert (kernel_size % 2 == 1) 115 | self.hidden_channels = hidden_channels 116 | self.kernel_size = kernel_size, 117 | self.dilation_rate = dilation_rate 118 | self.n_layers = n_layers 119 | self.gin_channels = gin_channels 120 | self.p_dropout = p_dropout 121 | 122 | self.in_layers = torch.nn.ModuleList() 123 | self.res_skip_layers = torch.nn.ModuleList() 124 | self.drop = nn.Dropout(p_dropout) 125 | 126 | if gin_channels != 0: 127 | cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) 128 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 129 | 130 | for i in range(n_layers): 131 | dilation = dilation_rate ** i 132 | padding = int((kernel_size * dilation - dilation) / 2) 133 | in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, 134 | dilation=dilation, padding=padding) 135 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 136 | self.in_layers.append(in_layer) 137 | 138 | # last one is not necessary 139 | if i < n_layers - 1: 140 | res_skip_channels = 2 * hidden_channels 141 | else: 142 | res_skip_channels = hidden_channels 143 | 144 | res_skip_layer = 
torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 145 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 146 | self.res_skip_layers.append(res_skip_layer) 147 | 148 | def forward(self, x, x_mask, g=None, **kwargs): 149 | output = torch.zeros_like(x) 150 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 151 | 152 | if g is not None: 153 | g = self.cond_layer(g) 154 | 155 | for i in range(self.n_layers): 156 | x_in = self.in_layers[i](x) 157 | if g is not None: 158 | cond_offset = i * 2 * self.hidden_channels 159 | g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] 160 | else: 161 | g_l = torch.zeros_like(x_in) 162 | 163 | acts = commons.fused_add_tanh_sigmoid_multiply( 164 | x_in, 165 | g_l, 166 | n_channels_tensor) 167 | acts = self.drop(acts) 168 | 169 | res_skip_acts = self.res_skip_layers[i](acts) 170 | if i < self.n_layers - 1: 171 | res_acts = res_skip_acts[:, :self.hidden_channels, :] 172 | x = (x + res_acts) * x_mask 173 | output = output + res_skip_acts[:, self.hidden_channels:, :] 174 | else: 175 | output = output + res_skip_acts 176 | return output * x_mask 177 | 178 | def remove_weight_norm(self): 179 | if self.gin_channels != 0: 180 | torch.nn.utils.remove_weight_norm(self.cond_layer) 181 | for l in self.in_layers: 182 | torch.nn.utils.remove_weight_norm(l) 183 | for l in self.res_skip_layers: 184 | torch.nn.utils.remove_weight_norm(l) 185 | 186 | 187 | class ResBlock1(torch.nn.Module): 188 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 189 | super(ResBlock1, self).__init__() 190 | self.convs1 = nn.ModuleList([ 191 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 192 | padding=get_padding(kernel_size, dilation[0]))), 193 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 194 | padding=get_padding(kernel_size, dilation[1]))), 195 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 196 | padding=get_padding(kernel_size, dilation[2]))) 197 | ]) 198 | self.convs1.apply(init_weights) 199 | 200 | self.convs2 = nn.ModuleList([ 201 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 202 | padding=get_padding(kernel_size, 1))), 203 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 204 | padding=get_padding(kernel_size, 1))), 205 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 206 | padding=get_padding(kernel_size, 1))) 207 | ]) 208 | self.convs2.apply(init_weights) 209 | 210 | def forward(self, x, x_mask=None): 211 | for c1, c2 in zip(self.convs1, self.convs2): 212 | xt = F.leaky_relu(x, LRELU_SLOPE) 213 | if x_mask is not None: 214 | xt = xt * x_mask 215 | xt = c1(xt) 216 | xt = F.leaky_relu(xt, LRELU_SLOPE) 217 | if x_mask is not None: 218 | xt = xt * x_mask 219 | xt = c2(xt) 220 | x = xt + x 221 | if x_mask is not None: 222 | x = x * x_mask 223 | return x 224 | 225 | def remove_weight_norm(self): 226 | for l in self.convs1: 227 | remove_weight_norm(l) 228 | for l in self.convs2: 229 | remove_weight_norm(l) 230 | 231 | 232 | class ResBlock2(torch.nn.Module): 233 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 234 | super(ResBlock2, self).__init__() 235 | self.convs = nn.ModuleList([ 236 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 237 | padding=get_padding(kernel_size, dilation[0]))), 238 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 239 | padding=get_padding(kernel_size, dilation[1]))) 
240 | ]) 241 | self.convs.apply(init_weights) 242 | 243 | def forward(self, x, x_mask=None): 244 | for c in self.convs: 245 | xt = F.leaky_relu(x, LRELU_SLOPE) 246 | if x_mask is not None: 247 | xt = xt * x_mask 248 | xt = c(xt) 249 | x = xt + x 250 | if x_mask is not None: 251 | x = x * x_mask 252 | return x 253 | 254 | def remove_weight_norm(self): 255 | for l in self.convs: 256 | remove_weight_norm(l) 257 | 258 | 259 | class Log(nn.Module): 260 | def forward(self, x, x_mask, reverse=False, **kwargs): 261 | if not reverse: 262 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 263 | logdet = torch.sum(-y, [1, 2]) 264 | return y, logdet 265 | else: 266 | x = torch.exp(x) * x_mask 267 | return x 268 | 269 | 270 | class Flip(nn.Module): 271 | def forward(self, x, *args, reverse=False, **kwargs): 272 | x = torch.flip(x, [1]) 273 | if not reverse: 274 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 275 | return x, logdet 276 | else: 277 | return x 278 | 279 | 280 | class ElementwiseAffine(nn.Module): 281 | def __init__(self, channels): 282 | super().__init__() 283 | self.channels = channels 284 | self.m = nn.Parameter(torch.zeros(channels, 1)) 285 | self.logs = nn.Parameter(torch.zeros(channels, 1)) 286 | 287 | def forward(self, x, x_mask, reverse=False, **kwargs): 288 | if not reverse: 289 | y = self.m + torch.exp(self.logs) * x 290 | y = y * x_mask 291 | logdet = torch.sum(self.logs * x_mask, [1, 2]) 292 | return y, logdet 293 | else: 294 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 295 | return x 296 | 297 | 298 | class ResidualCouplingLayer(nn.Module): 299 | def __init__(self, 300 | channels, 301 | hidden_channels, 302 | kernel_size, 303 | dilation_rate, 304 | n_layers, 305 | p_dropout=0, 306 | gin_channels=0, 307 | mean_only=False): 308 | assert channels % 2 == 0, "channels should be divisible by 2" 309 | super().__init__() 310 | self.channels = channels 311 | self.hidden_channels = hidden_channels 312 | self.kernel_size = kernel_size 313 | self.dilation_rate = dilation_rate 314 | self.n_layers = n_layers 315 | self.half_channels = channels // 2 316 | self.mean_only = mean_only 317 | 318 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 319 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, 320 | gin_channels=gin_channels) 321 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 322 | self.post.weight.data.zero_() 323 | self.post.bias.data.zero_() 324 | 325 | def forward(self, x, x_mask, g=None, reverse=False): 326 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 327 | h = self.pre(x0) * x_mask 328 | h = self.enc(h, x_mask, g=g) 329 | stats = self.post(h) * x_mask 330 | if not self.mean_only: 331 | m, logs = torch.split(stats, [self.half_channels] * 2, 1) 332 | else: 333 | m = stats 334 | logs = torch.zeros_like(m) 335 | 336 | if not reverse: 337 | x1 = m + x1 * torch.exp(logs) * x_mask 338 | x = torch.cat([x0, x1], 1) 339 | logdet = torch.sum(logs, [1, 2]) 340 | return x, logdet 341 | else: 342 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 343 | x = torch.cat([x0, x1], 1) 344 | return x 345 | 346 | 347 | class ConvFlow(nn.Module): 348 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): 349 | super().__init__() 350 | self.in_channels = in_channels 351 | self.filter_channels = filter_channels 352 | self.kernel_size = kernel_size 353 | self.n_layers = n_layers 354 | self.num_bins = num_bins 355 | self.tail_bound = tail_bound 
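        # ConvFlow is a coupling layer: the first half of the channels conditions the DDSConv stack
        # below, whose projection yields num_bins widths, num_bins heights and num_bins - 1
        # derivative values per element; these parameterize the rational-quadratic spline that
        # forward() applies to the second half of the channels.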
356 | self.half_channels = in_channels // 2 357 | 358 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 359 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) 360 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) 361 | self.proj.weight.data.zero_() 362 | self.proj.bias.data.zero_() 363 | 364 | def forward(self, x, x_mask, g=None, reverse=False): 365 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 366 | h = self.pre(x0) 367 | h = self.convs(h, x_mask, g=g) 368 | h = self.proj(h) * x_mask 369 | 370 | b, c, t = x0.shape 371 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 372 | 373 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) 374 | unnormalized_heights = h[..., self.num_bins:2 * self.num_bins] / math.sqrt(self.filter_channels) 375 | unnormalized_derivatives = h[..., 2 * self.num_bins:] 376 | 377 | x1, logabsdet = piecewise_rational_quadratic_transform(x1, 378 | unnormalized_widths, 379 | unnormalized_heights, 380 | unnormalized_derivatives, 381 | inverse=reverse, 382 | tails='linear', 383 | tail_bound=self.tail_bound 384 | ) 385 | 386 | x = torch.cat([x0, x1], 1) * x_mask 387 | logdet = torch.sum(logabsdet * x_mask, [1, 2]) 388 | if not reverse: 389 | return x, logdet 390 | else: 391 | return x 392 | -------------------------------------------------------------------------------- /vits/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from . import commons 7 | from . import modules 8 | from . import attentions 9 | from . import monotonic_align 10 | 11 | from torch.nn import Conv1d, ConvTranspose1d, Conv2d 12 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 13 | from .commons import init_weights, get_padding 14 | 15 | 16 | class StochasticDurationPredictor(nn.Module): 17 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): 18 | super().__init__() 19 | filter_channels = in_channels # it needs to be removed from future version. 
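        # The stochastic duration predictor is itself a small normalizing flow over durations:
        # during training forward() returns a negative log-likelihood term (nll + logq), while
        # with reverse=True it samples noise, runs the flows backwards and returns logw, the
        # predicted log-durations used at inference.
        # Rough usage sketch (shapes assumed, not taken from this repo): given x [b, in_channels, t]
        # and x_mask [b, 1, t],
        #   logw = sdp(x, x_mask, reverse=True, noise_scale=0.8)  # [b, 1, t]
        #   w = torch.exp(logw) * x_mask                          # per-token durations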
20 | self.in_channels = in_channels 21 | self.filter_channels = filter_channels 22 | self.kernel_size = kernel_size 23 | self.p_dropout = p_dropout 24 | self.n_flows = n_flows 25 | self.gin_channels = gin_channels 26 | 27 | self.log_flow = modules.Log() 28 | self.flows = nn.ModuleList() 29 | self.flows.append(modules.ElementwiseAffine(2)) 30 | for i in range(n_flows): 31 | self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 32 | self.flows.append(modules.Flip()) 33 | 34 | self.post_pre = nn.Conv1d(1, filter_channels, 1) 35 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) 36 | self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 37 | self.post_flows = nn.ModuleList() 38 | self.post_flows.append(modules.ElementwiseAffine(2)) 39 | for i in range(4): 40 | self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 41 | self.post_flows.append(modules.Flip()) 42 | 43 | self.pre = nn.Conv1d(in_channels, filter_channels, 1) 44 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1) 45 | self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 46 | if gin_channels != 0: 47 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1) 48 | 49 | def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): 50 | x = torch.detach(x) 51 | x = self.pre(x) 52 | if g is not None: 53 | g = torch.detach(g) 54 | x = x + self.cond(g) 55 | x = self.convs(x, x_mask) 56 | x = self.proj(x) * x_mask 57 | 58 | if not reverse: 59 | flows = self.flows 60 | assert w is not None 61 | 62 | logdet_tot_q = 0 63 | h_w = self.post_pre(w) 64 | h_w = self.post_convs(h_w, x_mask) 65 | h_w = self.post_proj(h_w) * x_mask 66 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask 67 | z_q = e_q 68 | for flow in self.post_flows: 69 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) 70 | logdet_tot_q += logdet_q 71 | z_u, z1 = torch.split(z_q, [1, 1], 1) 72 | u = torch.sigmoid(z_u) * x_mask 73 | z0 = (w - u) * x_mask 74 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) 75 | logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q 76 | 77 | logdet_tot = 0 78 | z0, logdet = self.log_flow(z0, x_mask) 79 | logdet_tot += logdet 80 | z = torch.cat([z0, z1], 1) 81 | for flow in flows: 82 | z, logdet = flow(z, x_mask, g=x, reverse=reverse) 83 | logdet_tot = logdet_tot + logdet 84 | nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot 85 | return nll + logq # [b] 86 | else: 87 | flows = list(reversed(self.flows)) 88 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow 89 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale 90 | for flow in flows: 91 | z = flow(z, x_mask, g=x, reverse=reverse) 92 | z0, z1 = torch.split(z, [1, 1], 1) 93 | logw = z0 94 | return logw 95 | 96 | 97 | class DurationPredictor(nn.Module): 98 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): 99 | super().__init__() 100 | 101 | self.in_channels = in_channels 102 | self.filter_channels = filter_channels 103 | self.kernel_size = kernel_size 104 | self.p_dropout = p_dropout 105 | self.gin_channels = gin_channels 106 | 107 | self.drop = nn.Dropout(p_dropout) 108 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) 109 | self.norm_1 = 
modules.LayerNorm(filter_channels) 110 | self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) 111 | self.norm_2 = modules.LayerNorm(filter_channels) 112 | self.proj = nn.Conv1d(filter_channels, 1, 1) 113 | 114 | if gin_channels != 0: 115 | self.cond = nn.Conv1d(gin_channels, in_channels, 1) 116 | 117 | def forward(self, x, x_mask, g=None): 118 | x = torch.detach(x) 119 | if g is not None: 120 | g = torch.detach(g) 121 | x = x + self.cond(g) 122 | x = self.conv_1(x * x_mask) 123 | x = torch.relu(x) 124 | x = self.norm_1(x) 125 | x = self.drop(x) 126 | x = self.conv_2(x * x_mask) 127 | x = torch.relu(x) 128 | x = self.norm_2(x) 129 | x = self.drop(x) 130 | x = self.proj(x * x_mask) 131 | return x * x_mask 132 | 133 | 134 | class TextEncoder(nn.Module): 135 | def __init__(self, 136 | n_vocab, 137 | out_channels, 138 | hidden_channels, 139 | filter_channels, 140 | n_heads, 141 | n_layers, 142 | kernel_size, 143 | p_dropout, 144 | emotion_embedding): 145 | super().__init__() 146 | self.n_vocab = n_vocab 147 | self.out_channels = out_channels 148 | self.hidden_channels = hidden_channels 149 | self.filter_channels = filter_channels 150 | self.n_heads = n_heads 151 | self.n_layers = n_layers 152 | self.kernel_size = kernel_size 153 | self.p_dropout = p_dropout 154 | self.emotion_embedding = emotion_embedding 155 | 156 | if self.n_vocab != 0: 157 | self.emb = nn.Embedding(n_vocab, hidden_channels) 158 | if emotion_embedding: 159 | self.emo_proj = nn.Linear(1024, hidden_channels) 160 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) 161 | 162 | self.encoder = attentions.Encoder( 163 | hidden_channels, 164 | filter_channels, 165 | n_heads, 166 | n_layers, 167 | kernel_size, 168 | p_dropout) 169 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 170 | 171 | def forward(self, x, x_lengths, emotion_embedding=None): 172 | if self.n_vocab != 0: 173 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] 174 | if emotion_embedding is not None: 175 | x = x + self.emo_proj(emotion_embedding.unsqueeze(1)) 176 | x = torch.transpose(x, 1, -1) # [b, h, t] 177 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 178 | 179 | x = self.encoder(x * x_mask, x_mask) 180 | stats = self.proj(x) * x_mask 181 | 182 | m, logs = torch.split(stats, self.out_channels, dim=1) 183 | return x, m, logs, x_mask 184 | 185 | 186 | class ResidualCouplingBlock(nn.Module): 187 | def __init__(self, 188 | channels, 189 | hidden_channels, 190 | kernel_size, 191 | dilation_rate, 192 | n_layers, 193 | n_flows=4, 194 | gin_channels=0): 195 | super().__init__() 196 | self.channels = channels 197 | self.hidden_channels = hidden_channels 198 | self.kernel_size = kernel_size 199 | self.dilation_rate = dilation_rate 200 | self.n_layers = n_layers 201 | self.n_flows = n_flows 202 | self.gin_channels = gin_channels 203 | 204 | self.flows = nn.ModuleList() 205 | for i in range(n_flows): 206 | self.flows.append( 207 | modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, 208 | gin_channels=gin_channels, mean_only=True)) 209 | self.flows.append(modules.Flip()) 210 | 211 | def forward(self, x, x_mask, g=None, reverse=False): 212 | if not reverse: 213 | for flow in self.flows: 214 | x, _ = flow(x, x_mask, g=g, reverse=reverse) 215 | else: 216 | for flow in reversed(self.flows): 217 | x = flow(x, x_mask, g=g, reverse=reverse) 218 | return x 219 | 220 | 221 | class PosteriorEncoder(nn.Module): 222 | def 
223 |                  in_channels,
224 |                  out_channels,
225 |                  hidden_channels,
226 |                  kernel_size,
227 |                  dilation_rate,
228 |                  n_layers,
229 |                  gin_channels=0):
230 |         super().__init__()
231 |         self.in_channels = in_channels
232 |         self.out_channels = out_channels
233 |         self.hidden_channels = hidden_channels
234 |         self.kernel_size = kernel_size
235 |         self.dilation_rate = dilation_rate
236 |         self.n_layers = n_layers
237 |         self.gin_channels = gin_channels
238 | 
239 |         self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
240 |         self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
241 |         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
242 | 
243 |     def forward(self, x, x_lengths, g=None):
244 |         x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
245 |         x = self.pre(x) * x_mask
246 |         x = self.enc(x, x_mask, g=g)
247 |         stats = self.proj(x) * x_mask
248 |         m, logs = torch.split(stats, self.out_channels, dim=1)
249 |         z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
250 |         return z, m, logs, x_mask
251 | 
252 | 
253 | class Generator(torch.nn.Module):
254 |     def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
255 |                  upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
256 |         super(Generator, self).__init__()
257 |         self.num_kernels = len(resblock_kernel_sizes)
258 |         self.num_upsamples = len(upsample_rates)
259 |         self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
260 |         resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
261 | 
262 |         self.ups = nn.ModuleList()
263 |         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
264 |             self.ups.append(weight_norm(
265 |                 ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)),
266 |                                 k, u, padding=(k - u) // 2)))
267 | 
268 |         self.resblocks = nn.ModuleList()
269 |         for i in range(len(self.ups)):
270 |             ch = upsample_initial_channel // (2 ** (i + 1))
271 |             for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
272 |                 self.resblocks.append(resblock(ch, k, d))
273 | 
274 |         self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
275 |         self.ups.apply(init_weights)
276 | 
277 |         if gin_channels != 0:
278 |             self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
279 | 
280 |     def forward(self, x, g=None):
281 |         x = self.conv_pre(x)
282 |         if g is not None:
283 |             x = x + self.cond(g)
284 | 
285 |         for i in range(self.num_upsamples):
286 |             x = F.leaky_relu(x, modules.LRELU_SLOPE)
287 |             x = self.ups[i](x)
288 |             xs = None
289 |             for j in range(self.num_kernels):
290 |                 if xs is None:
291 |                     xs = self.resblocks[i * self.num_kernels + j](x)
292 |                 else:
293 |                     xs += self.resblocks[i * self.num_kernels + j](x)
294 |             x = xs / self.num_kernels
295 |         x = F.leaky_relu(x)
296 |         x = self.conv_post(x)
297 |         x = torch.tanh(x)
298 | 
299 |         return x
300 | 
301 |     def remove_weight_norm(self):
302 |         print('Removing weight norm...')
303 |         for l in self.ups:
304 |             remove_weight_norm(l)
305 |         for l in self.resblocks:
306 |             l.remove_weight_norm()
307 | 
308 | 
309 | class DiscriminatorP(torch.nn.Module):
310 |     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
311 |         super(DiscriminatorP, self).__init__()
312 |         self.period = period
313 |         self.use_spectral_norm = use_spectral_norm
314 |         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
315 |         self.convs = nn.ModuleList([
316 |             norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
317 |             norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
318 |             norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
319 |             norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
320 |             norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
321 |         ])
322 |         self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
323 | 
324 |     def forward(self, x):
325 |         fmap = []
326 | 
327 |         # 1d to 2d
328 |         b, c, t = x.shape
329 |         if t % self.period != 0:  # pad first
330 |             n_pad = self.period - (t % self.period)
331 |             x = F.pad(x, (0, n_pad), "reflect")
332 |             t = t + n_pad
333 |         x = x.view(b, c, t // self.period, self.period)
334 | 
335 |         for l in self.convs:
336 |             x = l(x)
337 |             x = F.leaky_relu(x, modules.LRELU_SLOPE)
338 |             fmap.append(x)
339 |         x = self.conv_post(x)
340 |         fmap.append(x)
341 |         x = torch.flatten(x, 1, -1)
342 | 
343 |         return x, fmap
344 | 
345 | 
346 | class DiscriminatorS(torch.nn.Module):
347 |     def __init__(self, use_spectral_norm=False):
348 |         super(DiscriminatorS, self).__init__()
349 |         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
350 |         self.convs = nn.ModuleList([
351 |             norm_f(Conv1d(1, 16, 15, 1, padding=7)),
352 |             norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
353 |             norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
354 |             norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
355 |             norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
356 |             norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
357 |         ])
358 |         self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
359 | 
360 |     def forward(self, x):
361 |         fmap = []
362 | 
363 |         for l in self.convs:
364 |             x = l(x)
365 |             x = F.leaky_relu(x, modules.LRELU_SLOPE)
366 |             fmap.append(x)
367 |         x = self.conv_post(x)
368 |         fmap.append(x)
369 |         x = torch.flatten(x, 1, -1)
370 | 
371 |         return x, fmap
372 | 
373 | 
374 | class MultiPeriodDiscriminator(torch.nn.Module):
375 |     def __init__(self, use_spectral_norm=False):
376 |         super(MultiPeriodDiscriminator, self).__init__()
377 |         periods = [2, 3, 5, 7, 11]
378 | 
379 |         discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
380 |         discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
381 |         self.discriminators = nn.ModuleList(discs)
382 | 
383 |     def forward(self, y, y_hat):
384 |         y_d_rs = []
385 |         y_d_gs = []
386 |         fmap_rs = []
387 |         fmap_gs = []
388 |         for i, d in enumerate(self.discriminators):
389 |             y_d_r, fmap_r = d(y)
390 |             y_d_g, fmap_g = d(y_hat)
391 |             y_d_rs.append(y_d_r)
392 |             y_d_gs.append(y_d_g)
393 |             fmap_rs.append(fmap_r)
394 |             fmap_gs.append(fmap_g)
395 | 
396 |         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
397 | 
398 | 
399 | class SynthesizerTrn(nn.Module):
400 |     """
401 |     Synthesizer for Training
402 |     """
403 | 
404 |     def __init__(self,
405 |                  n_vocab,
406 |                  spec_channels,
407 |                  segment_size,
408 |                  inter_channels,
409 |                  hidden_channels,
410 |                  filter_channels,
411 |                  n_heads,
412 |                  n_layers,
413 |                  kernel_size,
414 |                  p_dropout,
415 |                  resblock,
416 |                  resblock_kernel_sizes,
417 |                  resblock_dilation_sizes,
418 |                  upsample_rates,
419 |                  upsample_initial_channel,
420 |                  upsample_kernel_sizes,
421 |                  n_speakers=0,
422 |                  gin_channels=0,
423 |                  use_sdp=True,
424 |                  emotion_embedding=False,
425 |                  **kwargs):
426 | 
427 |         super().__init__()
428 |         self.n_vocab = n_vocab
429 |         self.spec_channels = spec_channels
430 |         self.inter_channels = inter_channels
431 |         self.hidden_channels = hidden_channels
432 |         self.filter_channels = filter_channels
433 |         self.n_heads = n_heads
434 |         self.n_layers = n_layers
435 |         self.kernel_size = kernel_size
436 |         self.p_dropout = p_dropout
437 |         self.resblock = resblock
438 |         self.resblock_kernel_sizes = resblock_kernel_sizes
439 |         self.resblock_dilation_sizes = resblock_dilation_sizes
440 |         self.upsample_rates = upsample_rates
441 |         self.upsample_initial_channel = upsample_initial_channel
442 |         self.upsample_kernel_sizes = upsample_kernel_sizes
443 |         self.segment_size = segment_size
444 |         self.n_speakers = n_speakers
445 |         self.gin_channels = gin_channels
446 | 
447 |         self.use_sdp = use_sdp
448 | 
449 |         self.enc_p = TextEncoder(n_vocab,
450 |                                  inter_channels,
451 |                                  hidden_channels,
452 |                                  filter_channels,
453 |                                  n_heads,
454 |                                  n_layers,
455 |                                  kernel_size,
456 |                                  p_dropout,
457 |                                  emotion_embedding)
458 |         self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
459 |                              upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
460 |         self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
461 |                                       gin_channels=gin_channels)
462 |         self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
463 | 
464 |         if use_sdp:
465 |             self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
466 |         else:
467 |             self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
468 | 
469 |         if n_speakers > 1:
470 |             self.emb_g = nn.Embedding(n_speakers, gin_channels)
471 | 
472 |     def forward(self, x, x_lengths, y, y_lengths, sid=None, emotion_embedding=None):
473 | 
474 |         x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, emotion_embedding)
475 |         if self.n_speakers > 1:
476 |             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
477 |         else:
478 |             g = None
479 | 
480 |         z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
481 |         z_p = self.flow(z, y_mask, g=g)
482 | 
483 |         with torch.no_grad():
484 |             # negative cross-entropy
485 |             s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
486 |             neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True)  # [b, 1, t_s]
487 |             neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2),
488 |                                      s_p_sq_r)  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
489 |             neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r))  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
490 |             neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True)  # [b, 1, t_s]
491 |             neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
492 | 
493 |             attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
494 |             attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
495 | 
496 |         w = attn.sum(2)
497 |         if self.use_sdp:
498 |             l_length = self.dp(x, x_mask, w, g=g)
499 |             l_length = l_length / torch.sum(x_mask)
500 |         else:
501 |             logw_ = torch.log(w + 1e-6) * x_mask
502 |             logw = self.dp(x, x_mask, g=g)
503 |             l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)  # for averaging
504 | 
505 |         # expand prior
506 |         m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
507 |         logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
508 | 
509 |         z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
510 |         o = self.dec(z_slice, g=g)
511 |         return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
512 | 
513 |     def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None,
514 |               emotion_embedding=None):
515 |         x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, emotion_embedding)
516 |         if self.n_speakers > 1:
517 |             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
518 |         else:
519 |             g = None
520 | 
521 |         if self.use_sdp:
522 |             logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
523 |         else:
524 |             logw = self.dp(x, x_mask, g=g)
525 |         w = torch.exp(logw) * x_mask * length_scale
526 |         w_ceil = torch.ceil(w)
527 |         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
528 |         y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
529 |         attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
530 |         attn = commons.generate_path(w_ceil, attn_mask)
531 | 
532 |         m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
533 |         logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1,
534 |                                                                                  2)  # [b, t', t], [b, t, d] -> [b, d, t']
535 | 
536 |         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
537 |         z = self.flow(z_p, y_mask, g=g, reverse=True)
538 |         o = self.dec((z * y_mask)[:, :, :max_len], g=g)
539 |         return o, attn, y_mask, (z, z_p, m_p, logs_p)
540 | 
541 |     def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
542 |         assert self.n_speakers > 1, "n_speakers have to be larger than 1."
543 |         g_src = self.emb_g(sid_src).unsqueeze(-1)
544 |         g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
545 |         z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
546 |         z_p = self.flow(z, y_mask, g=g_src)
547 |         z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
548 |         o_hat = self.dec(z_hat * y_mask, g=g_tgt)
549 |         return o_hat, y_mask, (z, z_p, z_hat)
550 | 
--------------------------------------------------------------------------------
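SynthesizerTrn.infer above is the text-to-speech entry point that the web UI ultimately drives. The snippet below is a minimal, illustrative sketch of that call path and is not code from this repository: it assumes VITS-style helpers utils.get_hparams_from_file, utils.load_checkpoint and text_to_sequence with their common signatures (these differ slightly between forks), and the config/checkpoint paths are placeholders.

# Illustrative sketch only -- helper signatures and import paths vary between VITS forks.
import torch
from vits import utils
from vits.models import SynthesizerTrn
from vits.text import text_to_sequence

hps = utils.get_hparams_from_file("config.json")      # placeholder config path
net_g = SynthesizerTrn(
    len(hps.symbols),                                 # assumes the symbol list is stored in the config
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model)
net_g.eval()
utils.load_checkpoint("G_latest.pth", net_g, None)    # placeholder checkpoint path

seq = text_to_sequence("Hello world", hps.data.text_cleaners)  # some forks also take the symbol list
x = torch.LongTensor(seq).unsqueeze(0)                # [1, t]
x_lengths = torch.LongTensor([x.size(1)])
sid = torch.LongTensor([0])                           # speaker id, used only by multi-speaker models

with torch.no_grad():
    # infer returns (audio, attn, y_mask, (z, z_p, m_p, logs_p)); audio has shape [b, 1, t']
    audio = net_g.infer(x, x_lengths, sid=sid, noise_scale=0.667,
                        noise_scale_w=0.8, length_scale=1.0)[0][0, 0].cpu().numpy()

noise_scale and noise_scale_w control the sampling temperature of the prior and of the stochastic duration predictor respectively, while length_scale stretches the predicted durations; the values shown are the commonly used defaults, not settings mandated by this repository.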