├── models
│   ├── hubert
│   │   └── Put hubert checkpoint here.txt
│   └── sovits
│       └── Put so-vits checkpoint in a folder here.txt
├── crowdin.yml
├── .gitignore
├── README.md
├── localizations
│   ├── zh_CN.json
│   └── en.json
├── requirements.txt
├── modules
│   ├── path.py
│   ├── localization.py
│   ├── devices.py
│   ├── model.py
│   ├── utils.py
│   ├── vits_model.py
│   ├── options.py
│   ├── safe.py
│   ├── process.py
│   ├── sovits_model.py
│   └── ui.py
├── vits
│   ├── monotonic_align
│   │   ├── __init__.py
│   │   └── core.py
│   ├── text
│   │   ├── symbols.py
│   │   ├── __init__.py
│   │   ├── LICENSE
│   │   ├── thai.py
│   │   ├── ngu_dialect.py
│   │   ├── sanskrit.py
│   │   ├── cantonese.py
│   │   ├── shanghainese.py
│   │   ├── japanese.py
│   │   ├── english.py
│   │   ├── korean.py
│   │   ├── cleaners.py
│   │   └── mandarin.py
│   ├── mel_processing.py
│   ├── commons.py
│   ├── utils.py
│   ├── transforms.py
│   ├── attentions.py
│   ├── modules.py
│   └── models.py
├── webui.py
├── script.js
├── scripts
│   └── localization.js
└── style.css
/models/hubert/Put hubert checkpoint here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/sovits/Put so-vits checkpoint in a folder here.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/crowdin.yml:
--------------------------------------------------------------------------------
1 | files:
2 | - source: /localizations/en.json
3 | translation: /localizations/%locale_with_underscore%.json
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .vscode
3 |
4 | __pycache__
5 | venv
6 |
7 | models/vits/**
8 | models/sovits/**
9 | models/hubert/**
10 |
11 | outputs/*
12 | temp/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VITS web UI
2 |
3 | A browser interface based on the Gradio library for VITS/SO-VITS.
4 |
5 | ## Installation and Running
6 |
7 | Just run it. Trust yourself. You can make it!
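8 | 
9 | A minimal sketch, assuming Python 3 with pip, and a VITS checkpoint (a `.pth` file plus its `.json` config) placed in a subfolder of `models/vits`:
10 | 
11 | ```
12 | pip install -r requirements.txt
13 | python webui.py
14 | ```
15 | 
16 | Command-line flags such as `--listen`, `--port`, `--share` and `--use-cpu` are defined in `modules/options.py`; hubert and so-vits checkpoints go under `models/hubert` and `models/sovits` (see the placeholder files in those folders).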
--------------------------------------------------------------------------------
/localizations/zh_CN.json:
--------------------------------------------------------------------------------
1 | {
2 | "Text (press Ctrl+Enter or Alt+Enter to generate)": "文字 (按下 Ctrl+Enter 或 Alt+Enter 开始生成)",
3 | "Generate": "生成",
4 | "Speakers": "说话人",
5 | "VITS Checkpoint": "VITS 模型",
6 | "Process Method": "处理方式",
7 | "Speed": "语速",
8 | "Output Message": "输出信息",
9 | "Output Audio": "输出音频",
10 | "Open Folder": "打开文件夹",
11 | "Save": "保存"
12 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numba
2 | librosa
3 | matplotlib
4 | numpy
5 | phonemizer
6 | scipy
7 | tensorboard
8 | torch
9 | torchvision
10 | torchaudio
11 | unidecode
12 | pyopenjtalk>=0.3.0
13 | jamo
14 | pypinyin
15 | ko_pron
16 | jieba
17 | cn2an
18 | protobuf
19 | inflect
20 | eng_to_ipa
21 | indic_transliteration
22 | num_thai
23 | opencc
24 | gradio==3.15
25 | praat-parselmouth
26 | tqdm
--------------------------------------------------------------------------------
/modules/path.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | webui_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
5 | sys.path.insert(0, webui_path)
6 |
7 | paths = [
8 | {
9 | "t": 0,
10 | "p": os.path.join(webui_path, "repositories/sovits")
11 | }
12 | ]
13 |
14 |
15 | def insert_repositories_path():
16 | for p in paths:
17 | if p["t"] == 0:
18 | sys.path.insert(0, p["p"])
19 | else:
20 | sys.path.append(p["p"])
21 |
22 |
23 | # insert_repositories_path()
24 |
--------------------------------------------------------------------------------
/modules/localization.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | localization_dir = "localizations"
4 | localization_files = os.listdir(localization_dir)
5 |
6 |
7 | def gen_localization_js(name):
8 | if name not in localization_files:
9 | print(f"Load localization file {name} failed. Try set another localization file in settings panel.")
10 | return ""
11 | with open(os.path.join("localizations", name), "r", encoding="utf8") as lf:
12 | localization_file = lf.read()
13 | js = f"\n"
14 | return js
15 |
--------------------------------------------------------------------------------
/localizations/en.json:
--------------------------------------------------------------------------------
1 | {
2 | "Drop Audio Here": "Drop Audio Here",
3 | "Click to Upload": "Click to Upload",
4 |
5 | "Text (press Ctrl+Enter or Alt+Enter to generate)": "Text (press Ctrl+Enter or Alt+Enter to generate)",
6 | "Generate": "Generate",
7 | "Speakers": "Speakers",
8 | "VITS Checkpoint": "VITS Checkpoint",
9 | "Process Method": "Process Method",
10 | "Speed": "Speed",
11 | "Output Message": "Output Message",
12 | "Output Audio": "Output Audio",
13 | "Open Folder": "Open Folder",
14 | "Save": "Save",
15 | "Simple": "Simple",
16 | "Batch Process": "Batch Process",
17 | "Multi Speakers": "Multi Speakers",
18 | "Apply settings": "Apply settings",
19 | "Reload UI": "Reload UI"
20 | }
--------------------------------------------------------------------------------
/vits/monotonic_align/__init__.py:
--------------------------------------------------------------------------------
1 | from numpy import zeros, int32, float32
2 | from torch import from_numpy
3 |
4 | from .core import maximum_path_jit
5 |
6 |
7 | def maximum_path(neg_cent, mask):
8 | """ numba optimized version.
9 | neg_cent: [b, t_t, t_s]
10 | mask: [b, t_t, t_s]
11 | """
12 | device = neg_cent.device
13 | dtype = neg_cent.dtype
14 | neg_cent = neg_cent.data.cpu().numpy().astype(float32)
15 | path = zeros(neg_cent.shape, dtype=int32)
16 |
17 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
18 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
19 | maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
20 | return from_numpy(path).to(device=device, dtype=dtype)
21 |
22 |
--------------------------------------------------------------------------------
/vits/text/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Defines the set of symbols used in text input to the model.
5 | '''
6 | _pad = '_'
7 | _punctuation = ';:,.!?¡¿—…"«»“” '
8 |
9 | _punctuation_zh = ';:,。!?-“”《》、()BP…—~.\·『』・ '
10 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
11 |
12 | _numbers = '1234567890'
13 | _others = ''
14 |
15 | _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
16 |
17 | # Export all symbols:
18 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
19 |
20 | symbols_zh = [_pad] + list(_punctuation_zh) + list(_letters) + list(_numbers)
21 |
22 | # Special symbol ids
23 | SPACE_ID = symbols.index(" ")
24 |
--------------------------------------------------------------------------------
/modules/devices.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from modules import options
3 |
4 | cpu = torch.device("cpu")
5 | cuda_available = torch.cuda.is_available()
6 |
7 |
8 | def get_cuda_device():
9 | if options.cmd_opts.device_id is not None:
10 | return f"cuda:{options.cmd_opts.device_id}"
11 |
12 | return "cuda"
13 |
14 |
15 | def get_optimal_device():
16 | if cuda_available:
17 | return torch.device(get_cuda_device())
18 | return cpu
19 |
20 |
21 | def torch_gc():
22 | if cuda_available:
23 | with torch.cuda.device(get_cuda_device()):
24 | torch.cuda.empty_cache()
25 | torch.cuda.ipc_collect()
26 |
27 |
28 | device = cpu if options.cmd_opts.use_cpu else get_optimal_device()
29 |
30 | if not cuda_available:
31 | print("CUDA is not available, using cpu mode...")
32 |
--------------------------------------------------------------------------------
/vits/text/__init__.py:
--------------------------------------------------------------------------------
1 | from . import cleaners
2 |
3 |
4 | def text_to_sequence(text, symbols, cleaner_names: str, cleaner=None):
5 | """
6 | modified t2s: custom symbols and cleaner
7 | """
8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
9 |
10 | sequence = []
11 | if cleaner:
12 | clean_text = cleaner(text)
13 | else:
14 | clean_text = _clean_text(text, cleaner_names)
15 | for symbol in clean_text:
16 | if symbol not in _symbol_to_id.keys():
17 | continue
18 | symbol_id = _symbol_to_id[symbol]
19 | sequence += [symbol_id]
20 | return sequence
21 |
22 |
23 | def _clean_text(text, cleaner_names):
24 | for name in cleaner_names:
25 | cleaner = getattr(cleaners, name)
26 | if not cleaner:
27 | raise Exception('Unknown cleaner: %s' % name)
28 | text = cleaner(text)
29 | return text
30 |
--------------------------------------------------------------------------------
/vits/text/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 Keith Ito
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/vits/text/thai.py:
--------------------------------------------------------------------------------
1 | import re
2 | from num_thai.thainumbers import NumThai
3 |
4 |
5 | num = NumThai()
6 |
7 | # List of (Latin alphabet, Thai) pairs:
8 | _latin_to_thai = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
9 | ('a', 'เอ'),
10 | ('b','บี'),
11 | ('c','ซี'),
12 | ('d','ดี'),
13 | ('e','อี'),
14 | ('f','เอฟ'),
15 | ('g','จี'),
16 | ('h','เอช'),
17 | ('i','ไอ'),
18 | ('j','เจ'),
19 | ('k','เค'),
20 | ('l','แอล'),
21 | ('m','เอ็ม'),
22 | ('n','เอ็น'),
23 | ('o','โอ'),
24 | ('p','พี'),
25 | ('q','คิว'),
26 | ('r','แอร์'),
27 | ('s','เอส'),
28 | ('t','ที'),
29 | ('u','ยู'),
30 | ('v','วี'),
31 | ('w','ดับเบิลยู'),
32 | ('x','เอ็กซ์'),
33 | ('y','วาย'),
34 | ('z','ซี')
35 | ]]
36 |
37 |
38 | def num_to_thai(text):
39 | return re.sub(r'(?:\d+(?:,?\d+)?)+(?:\.\d+(?:,?\d+)?)?', lambda x: ''.join(num.NumberToTextThai(float(x.group(0).replace(',', '')))), text)
40 |
41 | def latin_to_thai(text):
42 | for regex, replacement in _latin_to_thai:
43 | text = re.sub(regex, replacement, text)
44 | return text
45 |
--------------------------------------------------------------------------------
/vits/text/ngu_dialect.py:
--------------------------------------------------------------------------------
1 | import re
2 | import opencc
3 |
4 |
5 | dialects = {'SZ': 'suzhou', 'WX': 'wuxi', 'CZ': 'changzhou', 'HZ': 'hangzhou',
6 | 'SX': 'shaoxing', 'NB': 'ningbo', 'JJ': 'jingjiang', 'YX': 'yixing',
7 | 'JD': 'jiading', 'ZR': 'zhenru', 'PH': 'pinghu', 'TX': 'tongxiang',
8 | 'JS': 'jiashan', 'HN': 'xiashi', 'LP': 'linping', 'XS': 'xiaoshan',
9 | 'FY': 'fuyang', 'RA': 'ruao', 'CX': 'cixi', 'SM': 'sanmen',
10 | 'TT': 'tiantai', 'WZ': 'wenzhou', 'SC': 'suichang', 'YB': 'youbu'}
11 |
12 | converters = {}
13 |
14 | for dialect in dialects.values():
15 | try:
16 | converters[dialect] = opencc.OpenCC(dialect)
17 | except:
18 | pass
19 |
20 |
21 | def ngu_dialect_to_ipa(text, dialect):
22 | dialect = dialects[dialect]
23 | text = converters[dialect].convert(text).replace('-','').replace('$',' ')
24 | text = re.sub(r'[、；：]', '，', text)
25 | text = re.sub(r'\s*，\s*', ', ', text)
26 | text = re.sub(r'\s*。\s*', '. ', text)
27 | text = re.sub(r'\s*？\s*', '? ', text)
28 | text = re.sub(r'\s*！\s*', '! ', text)
29 | text = re.sub(r'\s*$', '', text)
30 | return text
31 |
--------------------------------------------------------------------------------
/vits/monotonic_align/core.py:
--------------------------------------------------------------------------------
1 | import numba
2 |
3 |
4 | @numba.jit(numba.void(numba.int32[:, :, ::1], numba.float32[:, :, ::1], numba.int32[::1], numba.int32[::1]),
5 | nopython=True, nogil=True)
6 | def maximum_path_jit(paths, values, t_ys, t_xs):
7 | b = paths.shape[0]
8 | max_neg_val = -1e9
9 | for i in range(int(b)):
10 | path = paths[i]
11 | value = values[i]
12 | t_y = t_ys[i]
13 | t_x = t_xs[i]
14 |
15 | v_prev = v_cur = 0.0
16 | index = t_x - 1
17 |
18 | for y in range(t_y):
19 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
20 | if x == y:
21 | v_cur = max_neg_val
22 | else:
23 | v_cur = value[y - 1, x]
24 | if x == 0:
25 | if y == 0:
26 | v_prev = 0.
27 | else:
28 | v_prev = max_neg_val
29 | else:
30 | v_prev = value[y - 1, x - 1]
31 | value[y, x] += max(v_prev, v_cur)
32 |
33 | for y in range(t_y - 1, -1, -1):
34 | path[y, index] = 1
35 | if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
36 | index = index - 1
37 |
--------------------------------------------------------------------------------
/modules/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from modules.utils import search_ext_file
3 |
4 |
5 | class ModelInfo:
6 | model_name: str
7 | model_folder: str
8 | model_hash: str
9 | checkpoint_path: str
10 | config_path: str
11 |
12 | def __init__(self, model_name, model_folder, model_hash, checkpoint_path, config_path):
13 | self.model_name = model_name
14 | self.model_folder = model_folder
15 | self.model_hash = model_hash
16 | self.checkpoint_path = checkpoint_path
17 | self.config_path = config_path
18 | self.custom_symbols = None
19 |
20 |
21 | def refresh_model_list(model_path):
22 | dirs = os.listdir(model_path)
23 | models = []
24 | for d in dirs:
25 | p = os.path.join(model_path, d)
26 | if not os.path.isdir(p):
27 | continue
28 | pth_path = search_ext_file(p, ".pth")
29 | if not pth_path:
30 | print(f"Path {p} does not have a pth file, pass")
31 | continue
32 | config_path = search_ext_file(p, ".json")
33 | if not config_path:
34 | print(f"Path {p} does not have a config file, pass")
35 | continue
36 | models.append({
37 | "dir": d,
38 | "pth": pth_path,
39 | "config": config_path
40 | })
41 | return models
42 |
--------------------------------------------------------------------------------
/modules/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import re
4 | import subprocess
5 | from typing import Optional
6 |
7 | # for export
8 | from vits.utils import get_hparams_from_file, HParams
9 |
10 |
11 | def search_ext_file(path: str, ext: str) -> Optional[str]:
12 | files = os.listdir(path)
13 | for f in files:
14 | if f.endswith(ext):
15 | return os.path.join(path, f)
16 | return None
17 |
18 |
19 | def model_hash(filename):
20 | try:
21 | with open(filename, "rb") as file:
22 | import hashlib
23 | m = hashlib.sha256()
24 |
25 | file.seek(0x100000)
26 | m.update(file.read(0x10000))
27 | return m.hexdigest()[0:8]
28 | except FileNotFoundError:
29 | return 'FileNotFound'
30 |
31 |
32 | def open_folder(f):
33 | if not os.path.exists(f):
34 | print(f'Folder "{f}" does not exist.')
35 | return
36 | elif not os.path.isdir(f):
37 | return
38 |
39 | path = os.path.normpath(f)
40 | if platform.system() == "Windows":
41 | os.startfile(path)
42 | elif platform.system() == "Darwin":
43 | subprocess.Popen(["open", path])
44 | elif "microsoft-standard-WSL2" in platform.uname().release:
45 | subprocess.Popen(["wsl-open", path])
46 | else:
47 | subprocess.Popen(["xdg-open", path])
48 |
49 |
50 | def windows_filename(s: str):
51 | return re.sub('[<>:"\/\\|?*\n\t\r]+', "", s)
52 |
--------------------------------------------------------------------------------
/webui.py:
--------------------------------------------------------------------------------
1 | # 転がる岩、君に朝が降る
2 | # Code with Love by @Akegarasu
3 |
4 | import os
5 | import sys
6 |
7 | # must be imported before other modules that load models.
8 | import modules.safe
9 | import modules.path
10 |
11 | from modules.ui import create_ui
12 | import modules.vits_model as vits_model
13 | import modules.sovits_model as sovits_model
14 | from modules.options import cmd_opts
15 |
16 |
17 | def init():
18 | print(f"Launching webui with arguments: {' '.join(sys.argv[1:])}")
19 | ensure_output_dirs()
20 | vits_model.refresh_list()
21 | sovits_model.refresh_list()
22 | if cmd_opts.ui_debug_mode:
23 | return
24 | # todo: autoload last model
25 | # load_last_model()
26 |
27 |
28 | def ensure_output_dirs():
29 | folders = ["outputs/vits", "outputs/vits-batch", "outputs/sovits", "outputs/sovits", "outputs/sovits-batch", "temp"]
30 |
31 | def check_and_create(p):
32 | if not os.path.exists(p):
33 | os.makedirs(p)
34 |
35 | for i in folders:
36 | check_and_create(i)
37 |
38 |
39 | def run():
40 | init()
41 | app = create_ui()
42 | if cmd_opts.server_name:
43 | server_name = cmd_opts.server_name
44 | else:
45 | server_name = "0.0.0.0" if cmd_opts.listen else None
46 |
47 | app.queue(default_enabled=False).launch(
48 | share=cmd_opts.share,
49 | server_name=server_name,
50 | server_port=cmd_opts.port,
51 | inbrowser=cmd_opts.autolaunch,
52 | show_api=False
53 | )
54 |
55 |
56 | if __name__ == "__main__":
57 | run()
58 |
--------------------------------------------------------------------------------
/vits/text/sanskrit.py:
--------------------------------------------------------------------------------
1 | import re
2 | from indic_transliteration import sanscript
3 |
4 |
5 | # List of (iast, ipa) pairs:
6 | _iast_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
7 | ('a', 'ə'),
8 | ('ā', 'aː'),
9 | ('ī', 'iː'),
10 | ('ū', 'uː'),
11 | ('ṛ', 'ɹ`'),
12 | ('ṝ', 'ɹ`ː'),
13 | ('ḷ', 'l`'),
14 | ('ḹ', 'l`ː'),
15 | ('e', 'eː'),
16 | ('o', 'oː'),
17 | ('k', 'k⁼'),
18 | ('k⁼h', 'kʰ'),
19 | ('g', 'g⁼'),
20 | ('g⁼h', 'gʰ'),
21 | ('ṅ', 'ŋ'),
22 | ('c', 'ʧ⁼'),
23 | ('ʧ⁼h', 'ʧʰ'),
24 | ('j', 'ʥ⁼'),
25 | ('ʥ⁼h', 'ʥʰ'),
26 | ('ñ', 'n^'),
27 | ('ṭ', 't`⁼'),
28 | ('t`⁼h', 't`ʰ'),
29 | ('ḍ', 'd`⁼'),
30 | ('d`⁼h', 'd`ʰ'),
31 | ('ṇ', 'n`'),
32 | ('t', 't⁼'),
33 | ('t⁼h', 'tʰ'),
34 | ('d', 'd⁼'),
35 | ('d⁼h', 'dʰ'),
36 | ('p', 'p⁼'),
37 | ('p⁼h', 'pʰ'),
38 | ('b', 'b⁼'),
39 | ('b⁼h', 'bʰ'),
40 | ('y', 'j'),
41 | ('ś', 'ʃ'),
42 | ('ṣ', 's`'),
43 | ('r', 'ɾ'),
44 | ('l̤', 'l`'),
45 | ('h', 'ɦ'),
46 | ("'", ''),
47 | ('~', '^'),
48 | ('ṃ', '^')
49 | ]]
50 |
51 |
52 | def devanagari_to_ipa(text):
53 | text = text.replace('ॐ', 'ओम्')
54 | text = re.sub(r'\s*।\s*$', '.', text)
55 | text = re.sub(r'\s*।\s*', ', ', text)
56 | text = re.sub(r'\s*॥', '.', text)
57 | text = sanscript.transliterate(text, sanscript.DEVANAGARI, sanscript.IAST)
58 | for regex, replacement in _iast_to_ipa:
59 | text = re.sub(regex, replacement, text)
60 | text = re.sub('(.)[`ː]*ḥ', lambda x: x.group(0)
61 | [:-1]+'h'+x.group(1)+'*', text)
62 | return text
63 |
--------------------------------------------------------------------------------
/vits/text/cantonese.py:
--------------------------------------------------------------------------------
1 | import re
2 | import cn2an
3 | import opencc
4 |
5 |
6 | converter = opencc.OpenCC('jyutjyu')
7 |
8 | # List of (Latin alphabet, ipa) pairs:
9 | _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
10 | ('A', 'ei˥'),
11 | ('B', 'biː˥'),
12 | ('C', 'siː˥'),
13 | ('D', 'tiː˥'),
14 | ('E', 'iː˥'),
15 | ('F', 'e˥fuː˨˩'),
16 | ('G', 'tsiː˥'),
17 | ('H', 'ɪk̚˥tsʰyː˨˩'),
18 | ('I', 'ɐi˥'),
19 | ('J', 'tsei˥'),
20 | ('K', 'kʰei˥'),
21 | ('L', 'e˥llou˨˩'),
22 | ('M', 'ɛːm˥'),
23 | ('N', 'ɛːn˥'),
24 | ('O', 'ou˥'),
25 | ('P', 'pʰiː˥'),
26 | ('Q', 'kʰiːu˥'),
27 | ('R', 'aː˥lou˨˩'),
28 | ('S', 'ɛː˥siː˨˩'),
29 | ('T', 'tʰiː˥'),
30 | ('U', 'juː˥'),
31 | ('V', 'wiː˥'),
32 | ('W', 'tʊk̚˥piː˥juː˥'),
33 | ('X', 'ɪk̚˥siː˨˩'),
34 | ('Y', 'waːi˥'),
35 | ('Z', 'iː˨sɛːt̚˥')
36 | ]]
37 |
38 |
39 | def number_to_cantonese(text):
40 | return re.sub(r'\d+(?:\.?\d+)?', lambda x: cn2an.an2cn(x.group()), text)
41 |
42 |
43 | def latin_to_ipa(text):
44 | for regex, replacement in _latin_to_ipa:
45 | text = re.sub(regex, replacement, text)
46 | return text
47 |
48 |
49 | def cantonese_to_ipa(text):
50 | text = number_to_cantonese(text.upper())
51 | text = converter.convert(text).replace('-','').replace('$',' ')
52 | text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group())+' ', text)
53 | text = re.sub(r'[、；：]', '，', text)
54 | text = re.sub(r'\s*，\s*', ', ', text)
55 | text = re.sub(r'\s*。\s*', '. ', text)
56 | text = re.sub(r'\s*？\s*', '? ', text)
57 | text = re.sub(r'\s*！\s*', '! ', text)
58 | text = re.sub(r'\s*$', '', text)
59 | return text
60 |
--------------------------------------------------------------------------------
/vits/text/shanghainese.py:
--------------------------------------------------------------------------------
1 | import re
2 | import cn2an
3 | import opencc
4 |
5 |
6 | converter = opencc.OpenCC('zaonhe')
7 |
8 | # List of (Latin alphabet, ipa) pairs:
9 | _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
10 | ('A', 'ᴇ'),
11 | ('B', 'bi'),
12 | ('C', 'si'),
13 | ('D', 'di'),
14 | ('E', 'i'),
15 | ('F', 'ᴇf'),
16 | ('G', 'dʑi'),
17 | ('H', 'ᴇtɕʰ'),
18 | ('I', 'ᴀi'),
19 | ('J', 'dʑᴇ'),
20 | ('K', 'kʰᴇ'),
21 | ('L', 'ᴇl'),
22 | ('M', 'ᴇm'),
23 | ('N', 'ᴇn'),
24 | ('O', 'o'),
25 | ('P', 'pʰi'),
26 | ('Q', 'kʰiu'),
27 | ('R', 'ᴀl'),
28 | ('S', 'ᴇs'),
29 | ('T', 'tʰi'),
30 | ('U', 'ɦiu'),
31 | ('V', 'vi'),
32 | ('W', 'dᴀbɤliu'),
33 | ('X', 'ᴇks'),
34 | ('Y', 'uᴀi'),
35 | ('Z', 'zᴇ')
36 | ]]
37 |
38 |
39 | def _number_to_shanghainese(num):
40 | num = cn2an.an2cn(num).replace('一十','十').replace('二十', '廿').replace('二', '两')
41 | return re.sub(r'((?:^|[^三四五六七八九])十|廿)两', r'\1二', num)
42 |
43 |
44 | def number_to_shanghainese(text):
45 | return re.sub(r'\d+(?:\.?\d+)?', lambda x: _number_to_shanghainese(x.group()), text)
46 |
47 |
48 | def latin_to_ipa(text):
49 | for regex, replacement in _latin_to_ipa:
50 | text = re.sub(regex, replacement, text)
51 | return text
52 |
53 |
54 | def shanghainese_to_ipa(text):
55 | text = number_to_shanghainese(text.upper())
56 | text = converter.convert(text).replace('-','').replace('$',' ')
57 | text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group())+' ', text)
58 | text = re.sub(r'[、；：]', '，', text)
59 | text = re.sub(r'\s*，\s*', ', ', text)
60 | text = re.sub(r'\s*。\s*', '. ', text)
61 | text = re.sub(r'\s*？\s*', '? ', text)
62 | text = re.sub(r'\s*！\s*', '! ', text)
63 | text = re.sub(r'\s*$', '', text)
64 | return text
65 |
--------------------------------------------------------------------------------
/script.js:
--------------------------------------------------------------------------------
1 | let uiUpdateCallbacks = []
2 | let uiTabChangeCallbacks = []
3 | let uiCurrentTab = null;
4 |
5 |
6 | function gradioApp() {
7 | const gradioShadowRoot = document.getElementsByTagName('gradio-app')[0].shadowRoot
8 | return !!gradioShadowRoot ? gradioShadowRoot : document;
9 | }
10 |
11 | function get_uiCurrentTab() {
12 | return gradioApp().querySelector('.tabs button:not(.border-transparent)')
13 | }
14 |
15 | function get_uiCurrentTabContent() {
16 | return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])')
17 | }
18 |
19 | function onUiUpdate(callback) {
20 | uiUpdateCallbacks.push(callback)
21 | }
22 |
23 | function onUiTabChange(callback) {
24 | uiTabChangeCallbacks.push(callback)
25 | }
26 |
27 | function runCallback(x, m) {
28 | try {
29 | x(m)
30 | } catch (e) {
31 | (console.error || console.log).call(console, e.message, e);
32 | }
33 | }
34 |
35 | function executeCallbacks(queue, m) {
36 | queue.forEach(function (x) {
37 | runCallback(x, m)
38 | })
39 | }
40 |
41 | document.addEventListener("DOMContentLoaded", function () {
42 | let mutationObserver = new MutationObserver(function (m) {
43 | executeCallbacks(uiUpdateCallbacks, m);
44 | const newTab = get_uiCurrentTab();
45 | if (newTab && (newTab !== uiCurrentTab)) {
46 | uiCurrentTab = newTab;
47 | executeCallbacks(uiTabChangeCallbacks);
48 | }
49 | });
50 | mutationObserver.observe(gradioApp(), {childList: true, subtree: true})
51 | });
52 |
53 | document.addEventListener('keydown', function (e) {
54 | let handled = false;
55 | if (e.key !== undefined) {
56 | if ((e.key === "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
57 | } else if (e.keyCode !== undefined) {
58 | if ((e.keyCode === 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
59 | }
60 | if (handled) {
61 | let button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
62 | if (button) {
63 | button.click();
64 | }
65 | e.preventDefault();
66 | }
67 | })
68 |
--------------------------------------------------------------------------------
/scripts/localization.js:
--------------------------------------------------------------------------------
1 | // localization = {} -- the dict with translations is created by the backend
2 |
3 | ignore_ids_for_localization = {}
4 |
5 | re_num = /^[\.\d]+$/
6 | re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
7 |
8 | original_lines = {}
9 | translated_lines = {}
10 |
11 | function textNodesUnder(el) {
12 | var n, a = [], walk = document.createTreeWalker(el, NodeFilter.SHOW_TEXT, null, false);
13 | while (n = walk.nextNode()) a.push(n);
14 | return a;
15 | }
16 |
17 | function canBeTranslated(node, text) {
18 | if (!text) return false;
19 | if (!node.parentElement) return false;
20 |
21 | parentType = node.parentElement.nodeName
22 | if (parentType == 'SCRIPT' || parentType == 'STYLE' || parentType == 'TEXTAREA') return false;
23 |
24 | if (parentType == 'OPTION' || parentType == 'SPAN') {
25 | pnode = node
26 | for (var level = 0; level < 4; level++) {
27 | pnode = pnode.parentElement
28 | if (!pnode) break;
29 |
30 | if (ignore_ids_for_localization[pnode.id] == parentType) return false;
31 | }
32 | }
33 |
34 | if (re_num.test(text)) return false;
35 | if (re_emoji.test(text)) return false;
36 | return true
37 | }
38 |
39 | function getTranslation(text) {
40 | if (!text) return undefined
41 |
42 | if (translated_lines[text] === undefined) {
43 | original_lines[text] = 1
44 | }
45 |
46 | tl = localization[text]
47 | if (tl !== undefined) {
48 | translated_lines[tl] = 1
49 | }
50 |
51 | return tl
52 | }
53 |
54 | function processTextNode(node) {
55 | text = node.textContent.trim()
56 |
57 | if (!canBeTranslated(node, text)) return
58 |
59 | tl = getTranslation(text)
60 | if (tl !== undefined) {
61 | node.textContent = tl
62 | }
63 | }
64 |
65 | function processNode(node) {
66 | if (node.nodeType == 3) {
67 | processTextNode(node)
68 | return
69 | }
70 |
71 | if (node.title) {
72 | tl = getTranslation(node.title)
73 | if (tl !== undefined) {
74 | node.title = tl
75 | }
76 | }
77 |
78 | if (node.placeholder) {
79 | tl = getTranslation(node.placeholder)
80 | if (tl !== undefined) {
81 | node.placeholder = tl
82 | }
83 | }
84 |
85 | textNodesUnder(node).forEach(function (node) {
86 | processTextNode(node)
87 | })
88 | }
89 |
90 | onUiUpdate(function (m) {
91 | m.forEach(function (mutation) {
92 | mutation.addedNodes.forEach(function (node) {
93 | processNode(node)
94 | })
95 | });
96 | })
97 |
98 |
99 | document.addEventListener("DOMContentLoaded", function () {
100 | processNode(gradioApp())
101 |
102 | if (localization.rtl) { // if the language is from right to left,
103 | (new MutationObserver((mutations, observer) => { // wait for the style to load
104 | mutations.forEach(mutation => {
105 | mutation.addedNodes.forEach(node => {
106 | if (node.tagName === 'STYLE') {
107 | observer.disconnect();
108 |
109 | for (const x of node.sheet.rules) { // find all rtl media rules
110 | if (Array.from(x.media || []).includes('rtl')) {
111 | x.media.appendMedium('all'); // enable them
112 | }
113 | }
114 | }
115 | })
116 | });
117 | })).observe(gradioApp(), {childList: true});
118 | }
119 | })
120 |
--------------------------------------------------------------------------------
/vits/mel_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | from librosa.filters import mel as librosa_mel_fn
4 |
5 | MAX_WAV_VALUE = 32768.0
6 |
7 |
8 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
9 | """
10 | PARAMS
11 | ------
12 | C: compression factor
13 | """
14 | return torch.log(torch.clamp(x, min=clip_val) * C)
15 |
16 |
17 | def dynamic_range_decompression_torch(x, C=1):
18 | """
19 | PARAMS
20 | ------
21 | C: compression factor used to compress
22 | """
23 | return torch.exp(x) / C
24 |
25 |
26 | def spectral_normalize_torch(magnitudes):
27 | output = dynamic_range_compression_torch(magnitudes)
28 | return output
29 |
30 |
31 | def spectral_de_normalize_torch(magnitudes):
32 | output = dynamic_range_decompression_torch(magnitudes)
33 | return output
34 |
35 |
36 | mel_basis = {}
37 | hann_window = {}
38 |
39 |
40 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
41 | if torch.min(y) < -1.:
42 | print('min value is ', torch.min(y))
43 | if torch.max(y) > 1.:
44 | print('max value is ', torch.max(y))
45 |
46 | global hann_window
47 | dtype_device = str(y.dtype) + '_' + str(y.device)
48 | wnsize_dtype_device = str(win_size) + '_' + dtype_device
49 | if wnsize_dtype_device not in hann_window:
50 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
51 |
52 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
53 | mode='reflect')
54 | y = y.squeeze(1)
55 |
56 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
57 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
58 |
59 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
60 | return spec
61 |
62 |
63 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
64 | global mel_basis
65 | dtype_device = str(spec.dtype) + '_' + str(spec.device)
66 | fmax_dtype_device = str(fmax) + '_' + dtype_device
67 | if fmax_dtype_device not in mel_basis:
68 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)  # keyword args for newer librosa
69 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
70 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
71 | spec = spectral_normalize_torch(spec)
72 | return spec
73 |
74 |
75 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
76 | if torch.min(y) < -1.:
77 | print('min value is ', torch.min(y))
78 | if torch.max(y) > 1.:
79 | print('max value is ', torch.max(y))
80 |
81 | global mel_basis, hann_window
82 | dtype_device = str(y.dtype) + '_' + str(y.device)
83 | fmax_dtype_device = str(fmax) + '_' + dtype_device
84 | wnsize_dtype_device = str(win_size) + '_' + dtype_device
85 | if fmax_dtype_device not in mel_basis:
86 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)  # keyword args for newer librosa
87 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
88 | if wnsize_dtype_device not in hann_window:
89 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
90 |
91 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
92 | mode='reflect')
93 | y = y.squeeze(1)
94 |
95 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
96 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
97 |
98 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
99 |
100 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
101 | spec = spectral_normalize_torch(spec)
102 |
103 | return spec
104 |
--------------------------------------------------------------------------------
/vits/text/japanese.py:
--------------------------------------------------------------------------------
1 | import re
2 | from unidecode import unidecode
3 | import pyopenjtalk
4 |
5 |
6 | # Regular expression matching Japanese without punctuation marks:
7 | _japanese_characters = re.compile(
8 | r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
9 |
10 | # Regular expression matching non-Japanese characters or punctuation marks:
11 | _japanese_marks = re.compile(
12 | r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
13 |
14 | # List of (symbol, Japanese) pairs for marks:
15 | _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
16 | ('%', 'パーセント')
17 | ]]
18 |
19 | # List of (romaji, ipa) pairs for marks:
20 | _romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
21 | ('ts', 'ʦ'),
22 | ('u', 'ɯ'),
23 | ('j', 'ʥ'),
24 | ('y', 'j'),
25 | ('ni', 'n^i'),
26 | ('nj', 'n^'),
27 | ('hi', 'çi'),
28 | ('hj', 'ç'),
29 | ('f', 'ɸ'),
30 | ('I', 'i*'),
31 | ('U', 'ɯ*'),
32 | ('r', 'ɾ')
33 | ]]
34 |
35 | # List of (romaji, ipa2) pairs for marks:
36 | _romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
37 | ('u', 'ɯ'),
38 | ('ʧ', 'tʃ'),
39 | ('j', 'dʑ'),
40 | ('y', 'j'),
41 | ('ni', 'n^i'),
42 | ('nj', 'n^'),
43 | ('hi', 'çi'),
44 | ('hj', 'ç'),
45 | ('f', 'ɸ'),
46 | ('I', 'i*'),
47 | ('U', 'ɯ*'),
48 | ('r', 'ɾ')
49 | ]]
50 |
51 | # List of (consonant, sokuon) pairs:
52 | _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
53 | (r'Q([↑↓]*[kg])', r'k#\1'),
54 | (r'Q([↑↓]*[tdjʧ])', r't#\1'),
55 | (r'Q([↑↓]*[sʃ])', r's\1'),
56 | (r'Q([↑↓]*[pb])', r'p#\1')
57 | ]]
58 |
59 | # List of (consonant, hatsuon) pairs:
60 | _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
61 | (r'N([↑↓]*[pbm])', r'm\1'),
62 | (r'N([↑↓]*[ʧʥj])', r'n^\1'),
63 | (r'N([↑↓]*[tdn])', r'n\1'),
64 | (r'N([↑↓]*[kg])', r'ŋ\1')
65 | ]]
66 |
67 |
68 | def symbols_to_japanese(text):
69 | for regex, replacement in _symbols_to_japanese:
70 | text = re.sub(regex, replacement, text)
71 | return text
72 |
73 |
74 | def japanese_to_romaji_with_accent(text):
75 | '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
76 | text = symbols_to_japanese(text)
77 | sentences = re.split(_japanese_marks, text)
78 | marks = re.findall(_japanese_marks, text)
79 | text = ''
80 | for i, sentence in enumerate(sentences):
81 | if re.match(_japanese_characters, sentence):
82 | if text != '':
83 | text += ' '
84 | labels = pyopenjtalk.extract_fullcontext(sentence)
85 | for n, label in enumerate(labels):
86 | phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
87 | if phoneme not in ['sil', 'pau']:
88 | text += phoneme.replace('ch', 'ʧ').replace('sh',
89 | 'ʃ').replace('cl', 'Q')
90 | else:
91 | continue
92 | # n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
93 | a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
94 | a2 = int(re.search(r"\+(\d+)\+", label).group(1))
95 | a3 = int(re.search(r"\+(\d+)/", label).group(1))
96 | if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
97 | a2_next = -1
98 | else:
99 | a2_next = int(
100 | re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
101 | # Accent phrase boundary
102 | if a3 == 1 and a2_next == 1:
103 | text += ' '
104 | # Falling
105 | elif a1 == 0 and a2_next == a2 + 1:
106 | text += '↓'
107 | # Rising
108 | elif a2 == 1 and a2_next == 2:
109 | text += '↑'
110 | if i < len(marks):
111 | text += unidecode(marks[i]).replace(' ', '')
112 | return text
113 |
114 |
115 | def get_real_sokuon(text):
116 | for regex, replacement in _real_sokuon:
117 | text = re.sub(regex, replacement, text)
118 | return text
119 |
120 |
121 | def get_real_hatsuon(text):
122 | for regex, replacement in _real_hatsuon:
123 | text = re.sub(regex, replacement, text)
124 | return text
125 |
126 |
127 | def japanese_to_ipa(text):
128 | text = japanese_to_romaji_with_accent(text).replace('...', '…')
129 | text = re.sub(
130 | r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
131 | text = get_real_sokuon(text)
132 | text = get_real_hatsuon(text)
133 | for regex, replacement in _romaji_to_ipa:
134 | text = re.sub(regex, replacement, text)
135 | return text
136 |
137 |
138 | def japanese_to_ipa2(text):
139 | text = japanese_to_romaji_with_accent(text).replace('...', '…')
140 | text = get_real_sokuon(text)
141 | text = get_real_hatsuon(text)
142 | for regex, replacement in _romaji_to_ipa2:
143 | text = re.sub(regex, replacement, text)
144 | return text
145 |
146 |
147 | def japanese_to_ipa3(text):
148 | text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace(
149 | 'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a')
150 | text = re.sub(
151 | r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
152 | text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
153 | return text
154 |
--------------------------------------------------------------------------------
/modules/vits_model.py:
--------------------------------------------------------------------------------
1 | import importlib.util
2 | import os.path
3 | from typing import Dict, List, Optional
4 |
5 | import torch
6 |
7 | import modules.devices as devices
8 | import vits.utils
9 | from modules.model import ModelInfo, refresh_model_list
10 | from modules.utils import search_ext_file, model_hash
11 | from vits.models import SynthesizerTrn
12 | from vits.text.symbols import symbols as builtin_symbols
13 | from vits.utils import HParams
14 |
15 | # todo: cmdline here
16 | MODEL_PATH = os.path.join(os.getcwd(), "models", "vits")
17 |
18 |
19 | class VITSModel:
20 | model: SynthesizerTrn
21 | hps: HParams
22 | symbols: List[str]
23 |
24 | model_name: str
25 | model_folder: str
26 | checkpoint_path: str
27 | config_path: str
28 |
29 | speakers: List[str]
30 |
31 | def __init__(self, info: ModelInfo):
32 | self.model_name = info.model_name
33 | self.model_folder = info.model_folder
34 | self.checkpoint_path = info.checkpoint_path
35 | self.config_path = info.config_path
36 | self.custom_symbols = None
37 | # self.state = "" # maybe for multiprocessing
38 |
39 | def load_model(self):
40 | hps = vits.utils.get_hparams_from_file(self.config_path)
41 | self.load_custom_symbols(f"{self.model_folder}/symbols.py")
42 | if self.custom_symbols:
43 | _symbols = self.custom_symbols.symbols
44 | elif "symbols" in hps:
45 | _symbols = hps.symbols
46 | else:
47 | _symbols = builtin_symbols
48 |
49 | if hasattr(self.custom_symbols, "symbols_zh"):
50 | hps["symbols_zh"] = self.custom_symbols.symbols_zh
51 |
52 | model = SynthesizerTrn(
53 | len(_symbols),
54 | hps.data.filter_length // 2 + 1,
55 | hps.train.segment_size // hps.data.hop_length,
56 | n_speakers=hps.data.n_speakers,
57 | **hps.model)
58 | model, _, _, _ = load_checkpoint(checkpoint_path=self.checkpoint_path,
59 | model=model, optimizer=None)
60 | model.eval().to(devices.device)
61 |
62 | self.speakers = [name for sid, name in enumerate(hps.speakers) if name != "None"]
63 | self.model = model
64 | self.hps = hps
65 | self.symbols = _symbols
66 |
67 | def load_custom_symbols(self, symbol_path):
68 | if os.path.exists(symbol_path):
69 | spec = importlib.util.spec_from_file_location('symbols', symbol_path)
70 | _sym = importlib.util.module_from_spec(spec)
71 | spec.loader.exec_module(_sym)
72 | if not hasattr(_sym, "symbols"):
73 | print(f"Loading symbol file {symbol_path} failed, so such attr")
74 | return
75 | self.custom_symbols = _sym
76 |
77 |
78 | vits_model_list: Dict[str, ModelInfo] = {}
79 | curr_vits_model: Optional[VITSModel] = None
80 |
81 |
82 | def get_model() -> VITSModel:
83 | return curr_vits_model
84 |
85 |
86 | def get_model_name():
87 | return curr_vits_model.model_name if curr_vits_model is not None else None
88 |
89 |
90 | def get_model_list():
91 | return [k for k, _ in vits_model_list.items()]
92 |
93 |
94 | def get_speakers():
95 | return curr_vits_model.speakers if curr_vits_model is not None else ["None"]
96 |
97 |
98 | def refresh_list():
99 | vits_model_list.clear()
100 | model_list = refresh_model_list(model_path=MODEL_PATH)
101 | for m in model_list:
102 | d = m["dir"]
103 | p = os.path.join(MODEL_PATH, m["dir"])
104 | pth_path = m["pth"]
105 | config_path = m["config"]
106 | vits_model_list[d] = ModelInfo(
107 | model_name=d,
108 | model_folder=p,
109 | model_hash=model_hash(pth_path),
110 | checkpoint_path=pth_path,
111 | config_path=config_path
112 | )
113 | if len(vits_model_list.items()) == 0:
114 | print("No vits model found. Please put a model in models/vits")
115 |
116 |
117 | def init_load_model():
118 | info = next(iter(vits_model_list.values()))
119 | load_model(info.model_name)
120 |
121 |
122 | def init_model():
123 | global curr_vits_model
124 | info = next(iter(vits_model_list.values()))
125 | curr_vits_model = VITSModel(info)
126 | # load_model(info.model_name)
127 |
128 |
129 | def load_model(model_name: str):
130 | global curr_vits_model, vits_model_list
131 | if curr_vits_model and curr_vits_model.model_name == model_name:
132 | return
133 | info = vits_model_list[model_name]
134 | print(f"Loading weights [{info.model_hash}] from {info.checkpoint_path}...")
135 | m = VITSModel(info)
136 | m.load_model()
137 | curr_vits_model = m
138 | print("Model loaded.")
139 |
140 |
141 | def load_checkpoint(checkpoint_path, model, optimizer=None):
142 | assert os.path.isfile(checkpoint_path)
143 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
144 | iteration = checkpoint_dict['iteration']
145 | learning_rate = checkpoint_dict['learning_rate']
146 | if optimizer is not None:
147 | optimizer.load_state_dict(checkpoint_dict['optimizer'])
148 | saved_state_dict = checkpoint_dict['model']
149 | if hasattr(model, 'module'):
150 | state_dict = model.module.state_dict()
151 | else:
152 | state_dict = model.state_dict()
153 | new_state_dict = {}
154 | for k, v in state_dict.items():
155 | try:
156 | new_state_dict[k] = saved_state_dict[k]
157 | except KeyError:  # keep the model's current weights for keys missing from the checkpoint
158 | new_state_dict[k] = v
159 | if hasattr(model, 'module'):
160 | model.module.load_state_dict(new_state_dict)
161 | else:
162 | model.load_state_dict(new_state_dict)
163 | return model, optimizer, learning_rate, iteration
164 |
--------------------------------------------------------------------------------
/vits/commons.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.nn import functional as F
4 | import torch.jit
5 |
6 |
7 | def script_method(fn, _rcb=None):
8 | return fn
9 |
10 |
11 | def script(obj, optimize=True, _frames_up=0, _rcb=None):
12 | return obj
13 |
14 |
15 | torch.jit.script_method = script_method
16 | torch.jit.script = script
17 |
18 |
19 | def init_weights(m, mean=0.0, std=0.01):
20 | classname = m.__class__.__name__
21 | if classname.find("Conv") != -1:
22 | m.weight.data.normal_(mean, std)
23 |
24 |
25 | def get_padding(kernel_size, dilation=1):
26 | return int((kernel_size * dilation - dilation) / 2)
27 |
28 |
29 | def convert_pad_shape(pad_shape):
30 | l = pad_shape[::-1]
31 | pad_shape = [item for sublist in l for item in sublist]
32 | return pad_shape
33 |
34 |
35 | def intersperse(lst, item):
36 | result = [item] * (len(lst) * 2 + 1)
37 | result[1::2] = lst
38 | return result
39 |
40 |
41 | def kl_divergence(m_p, logs_p, m_q, logs_q):
42 | """KL(P||Q)"""
43 | kl = (logs_q - logs_p) - 0.5
44 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2. * logs_q)
45 | return kl
46 |
47 |
48 | def rand_gumbel(shape):
49 | """Sample from the Gumbel distribution, protect from overflows."""
50 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
51 | return -torch.log(-torch.log(uniform_samples))
52 |
53 |
54 | def rand_gumbel_like(x):
55 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
56 | return g
57 |
58 |
59 | def slice_segments(x, ids_str, segment_size=4):
60 | ret = torch.zeros_like(x[:, :, :segment_size])
61 | for i in range(x.size(0)):
62 | idx_str = ids_str[i]
63 | idx_end = idx_str + segment_size
64 | ret[i] = x[i, :, idx_str:idx_end]
65 | return ret
66 |
67 |
68 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
69 | b, d, t = x.size()
70 | if x_lengths is None:
71 | x_lengths = t
72 | ids_str_max = x_lengths - segment_size + 1
73 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
74 | ret = slice_segments(x, ids_str, segment_size)
75 | return ret, ids_str
76 |
77 |
78 | def get_timing_signal_1d(
79 | length, channels, min_timescale=1.0, max_timescale=1.0e4):
80 | position = torch.arange(length, dtype=torch.float)
81 | num_timescales = channels // 2
82 | log_timescale_increment = (
83 | math.log(float(max_timescale) / float(min_timescale)) /
84 | (num_timescales - 1))
85 | inv_timescales = min_timescale * torch.exp(
86 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
87 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
88 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
89 | signal = F.pad(signal, [0, 0, 0, channels % 2])
90 | signal = signal.view(1, channels, length)
91 | return signal
92 |
93 |
94 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
95 | b, channels, length = x.size()
96 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
97 | return x + signal.to(dtype=x.dtype, device=x.device)
98 |
99 |
100 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
101 | b, channels, length = x.size()
102 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
103 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
104 |
105 |
106 | def subsequent_mask(length):
107 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
108 | return mask
109 |
110 |
111 | @torch.jit.script
112 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
113 | n_channels_int = n_channels[0]
114 | in_act = input_a + input_b
115 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
116 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
117 | acts = t_act * s_act
118 | return acts
119 |
120 |
121 | def convert_pad_shape(pad_shape):
122 | l = pad_shape[::-1]
123 | pad_shape = [item for sublist in l for item in sublist]
124 | return pad_shape
125 |
126 |
127 | def shift_1d(x):
128 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
129 | return x
130 |
131 |
132 | def sequence_mask(length, max_length=None):
133 | if max_length is None:
134 | max_length = length.max()
135 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
136 | return x.unsqueeze(0) < length.unsqueeze(1)
137 |
138 |
139 | def generate_path(duration, mask):
140 | """
141 | duration: [b, 1, t_x]
142 | mask: [b, 1, t_y, t_x]
143 | """
144 | device = duration.device
145 |
146 | b, _, t_y, t_x = mask.shape
147 | cum_duration = torch.cumsum(duration, -1)
148 |
149 | cum_duration_flat = cum_duration.view(b * t_x)
150 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
151 | path = path.view(b, t_x, t_y)
152 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
153 | path = path.unsqueeze(1).transpose(2, 3) * mask
154 | return path
155 |
156 |
157 | def clip_grad_value_(parameters, clip_value, norm_type=2):
158 | if isinstance(parameters, torch.Tensor):
159 | parameters = [parameters]
160 | parameters = list(filter(lambda p: p.grad is not None, parameters))
161 | norm_type = float(norm_type)
162 | if clip_value is not None:
163 | clip_value = float(clip_value)
164 |
165 | total_norm = 0
166 | for p in parameters:
167 | param_norm = p.grad.data.norm(norm_type)
168 | total_norm += param_norm.item() ** norm_type
169 | if clip_value is not None:
170 | p.grad.data.clamp_(min=-clip_value, max=clip_value)
171 | total_norm = total_norm ** (1. / norm_type)
172 | return total_norm
173 |
--------------------------------------------------------------------------------
/modules/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import gradio as gr
5 |
6 | from modules.localization import localization_files
7 |
8 | config_file = "config.json"
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name")
12 | parser.add_argument("--port", type=int, help="launch gradio with given server port, defaults to 8860 if available", default="8860")
13 | parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
14 | parser.add_argument("--server-name", type=str, help="sets hostname of server", default=None)
15 | parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
16 | parser.add_argument("--device-id", type=str, help="select the default CUDA device to use", default=None)
17 | parser.add_argument("--use-cpu", action='store_true', help="use cpu")
18 | parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable safe unpickle")
19 | parser.add_argument("--freeze-settings", action='store_true', help="freeze settings")
20 | parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
21 |
22 | # todo: use logger
23 | parser.add_argument("--debug", action='store_true', help="Output debug log")
24 |
25 | cmd_opts = parser.parse_args()
26 |
27 |
28 | class OptionInfo:
29 | def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
30 | self.default = default
31 | self.label = label
32 | self.component = component
33 | self.component_args = component_args
34 | self.onchange = onchange
35 | self.section = section
36 | self.refresh = refresh
37 |
38 |
39 | options_templates = {}
40 |
41 | # options_templates.update({
42 | # "vits_model": OptionInfo(None, "VITS checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
43 | # })
44 |
45 | options_templates.update({
46 | "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + localization_files}),
47 | })
48 |
49 |
50 | class Options:
51 | data = None
52 | data_labels = options_templates
53 | type_map = {int: float}
54 |
55 | def __init__(self):
56 | self.data = {k: v.default for k, v in self.data_labels.items()}
57 |
58 | def __setattr__(self, key, value):
59 | if self.data is not None:
60 | if key in self.data or key in self.data_labels:
61 | assert not cmd_opts.freeze_settings, "changing settings is disabled"
62 |
63 | info = opts.data_labels.get(key, None)
64 | comp_args = info.component_args if info else None
65 | if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
66 | raise RuntimeError(f"not possible to set {key} because it is restricted")
67 |
68 | self.data[key] = value
69 | return
70 |
71 | return super(Options, self).__setattr__(key, value)
72 |
73 | def __getattr__(self, item):
74 | if self.data is not None:
75 | if item in self.data:
76 | return self.data[item]
77 |
78 | if item in self.data_labels:
79 | return self.data_labels[item].default
80 |
81 | return super(Options, self).__getattribute__(item)
82 |
83 | def set(self, key, value):
84 | oldval = self.data.get(key, None)
85 | if oldval == value:
86 | return False
87 |
88 | try:
89 | setattr(self, key, value)
90 | except RuntimeError:
91 | return False
92 |
93 | if self.data_labels[key].onchange is not None:
94 | try:
95 | self.data_labels[key].onchange()
96 | except Exception as e:
97 | print(e)
98 | print(f"Error when handling onchange event: changing setting {key} to {value}")
99 | setattr(self, key, oldval)
100 | return False
101 |
102 | return True
103 |
104 | def save(self, filename=config_file):
105 | assert not cmd_opts.freeze_settings, "saving settings is disabled"
106 |
107 | with open(filename, "w", encoding="utf8") as file:
108 | json.dump(self.data, file, indent=4, ensure_ascii=False)
109 |
110 | def same_type(self, x, y):
111 | if x is None or y is None:
112 | return True
113 |
114 | type_x = self.type_map.get(type(x), type(x))
115 | type_y = self.type_map.get(type(y), type(y))
116 |
117 | return type_x == type_y
118 |
119 | def load(self, filename):
120 | with open(filename, "r", encoding="utf8") as file:
121 | self.data = json.load(file)
122 |
123 | bad_settings = 0
124 | for k, v in self.data.items():
125 | info = self.data_labels.get(k, None)
126 | if info is not None and not self.same_type(info.default, v):
127 | print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})")
128 | bad_settings += 1
129 |
130 | if bad_settings > 0:
131 | print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.")
132 |
133 | def onchange(self, key, func, call=True):
134 | item = self.data_labels.get(key)
135 | item.onchange = func
136 |
137 | if call:
138 | func()
139 |
140 | def dump_json(self):
141 | d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
142 | return json.dumps(d)
143 |
144 | def add_option(self, key, info):
145 | self.data_labels[key] = info
146 |
147 |
148 | opts = Options()
149 | if os.path.exists(config_file):
150 | opts.load(config_file)
151 | else:
152 | print("Config not found, generating default config...")
153 | opts.save()
154 |
--------------------------------------------------------------------------------
/vits/text/english.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Cleaners are transformations that run over the input text at both training and eval time.
5 |
6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8 | 1. "english_cleaners" for English text
9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12 | the symbols in symbols.py to match your data).
13 | '''
14 |
15 |
16 | # Regular expression matching whitespace:
17 |
18 |
19 | import re
20 | import inflect
21 | from unidecode import unidecode
22 | import eng_to_ipa as ipa
23 | _inflect = inflect.engine()
24 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
25 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
26 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
27 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
28 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
29 | _number_re = re.compile(r'[0-9]+')
30 |
31 | # List of (regular expression, replacement) pairs for abbreviations:
32 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
33 | ('mrs', 'misess'),
34 | ('mr', 'mister'),
35 | ('dr', 'doctor'),
36 | ('st', 'saint'),
37 | ('co', 'company'),
38 | ('jr', 'junior'),
39 | ('maj', 'major'),
40 | ('gen', 'general'),
41 | ('drs', 'doctors'),
42 | ('rev', 'reverend'),
43 | ('lt', 'lieutenant'),
44 | ('hon', 'honorable'),
45 | ('sgt', 'sergeant'),
46 | ('capt', 'captain'),
47 | ('esq', 'esquire'),
48 | ('ltd', 'limited'),
49 | ('col', 'colonel'),
50 | ('ft', 'fort'),
51 | ]]
52 |
53 |
54 | # List of (ipa, lazy ipa) pairs:
55 | _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
56 | ('r', 'ɹ'),
57 | ('æ', 'e'),
58 | ('ɑ', 'a'),
59 | ('ɔ', 'o'),
60 | ('ð', 'z'),
61 | ('θ', 's'),
62 | ('ɛ', 'e'),
63 | ('ɪ', 'i'),
64 | ('ʊ', 'u'),
65 | ('ʒ', 'ʥ'),
66 | ('ʤ', 'ʥ'),
67 | ('ˈ', '↓'),
68 | ]]
69 |
70 | # List of (ipa, lazy ipa2) pairs:
71 | _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
72 | ('r', 'ɹ'),
73 | ('ð', 'z'),
74 | ('θ', 's'),
75 | ('ʒ', 'ʑ'),
76 | ('ʤ', 'dʑ'),
77 | ('ˈ', '↓'),
78 | ]]
79 |
80 | # List of (ipa, ipa2) pairs
81 | _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
82 | ('r', 'ɹ'),
83 | ('ʤ', 'dʒ'),
84 | ('ʧ', 'tʃ')
85 | ]]
86 |
87 |
88 | def expand_abbreviations(text):
89 | for regex, replacement in _abbreviations:
90 | text = re.sub(regex, replacement, text)
91 | return text
92 |
93 |
94 | def collapse_whitespace(text):
95 | return re.sub(r'\s+', ' ', text)
96 |
97 |
98 | def _remove_commas(m):
99 | return m.group(1).replace(',', '')
100 |
101 |
102 | def _expand_decimal_point(m):
103 | return m.group(1).replace('.', ' point ')
104 |
105 |
106 | def _expand_dollars(m):
107 | match = m.group(1)
108 | parts = match.split('.')
109 | if len(parts) > 2:
110 | return match + ' dollars' # Unexpected format
111 | dollars = int(parts[0]) if parts[0] else 0
112 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
113 | if dollars and cents:
114 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
115 | cent_unit = 'cent' if cents == 1 else 'cents'
116 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
117 | elif dollars:
118 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
119 | return '%s %s' % (dollars, dollar_unit)
120 | elif cents:
121 | cent_unit = 'cent' if cents == 1 else 'cents'
122 | return '%s %s' % (cents, cent_unit)
123 | else:
124 | return 'zero dollars'
125 |
126 |
127 | def _expand_ordinal(m):
128 | return _inflect.number_to_words(m.group(0))
129 |
130 |
131 | def _expand_number(m):
132 | num = int(m.group(0))
133 | if num > 1000 and num < 3000:
134 | if num == 2000:
135 | return 'two thousand'
136 | elif num > 2000 and num < 2010:
137 | return 'two thousand ' + _inflect.number_to_words(num % 100)
138 | elif num % 100 == 0:
139 | return _inflect.number_to_words(num // 100) + ' hundred'
140 | else:
141 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
142 | else:
143 | return _inflect.number_to_words(num, andword='')
144 |
145 |
146 | def normalize_numbers(text):
147 | text = re.sub(_comma_number_re, _remove_commas, text)
148 | text = re.sub(_pounds_re, r'\1 pounds', text)
149 | text = re.sub(_dollars_re, _expand_dollars, text)
150 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
151 | text = re.sub(_ordinal_re, _expand_ordinal, text)
152 | text = re.sub(_number_re, _expand_number, text)
153 | return text
154 |
155 |
156 | def mark_dark_l(text):
157 | return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
158 |
159 |
160 | def english_to_ipa(text):
161 | text = unidecode(text).lower()
162 | text = expand_abbreviations(text)
163 | text = normalize_numbers(text)
164 | phonemes = ipa.convert(text)
165 | phonemes = collapse_whitespace(phonemes)
166 | return phonemes
167 |
168 |
169 | def english_to_lazy_ipa(text):
170 | text = english_to_ipa(text)
171 | for regex, replacement in _lazy_ipa:
172 | text = re.sub(regex, replacement, text)
173 | return text
174 |
175 |
176 | def english_to_ipa2(text):
177 | text = english_to_ipa(text)
178 | text = mark_dark_l(text)
179 | for regex, replacement in _ipa_to_ipa2:
180 | text = re.sub(regex, replacement, text)
181 | return text.replace('...', '…')
182 |
183 |
184 | def english_to_lazy_ipa2(text):
185 | text = english_to_ipa(text)
186 | for regex, replacement in _lazy_ipa2:
187 | text = re.sub(regex, replacement, text)
188 | return text
189 |
--------------------------------------------------------------------------------
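
A hedged usage sketch for vits/text/english.py above (not part of the repository): it assumes the project root is on the import path and that the `unidecode`, `inflect` and `eng_to_ipa` packages from requirements.txt are installed; the sentence is made up.

```python
from vits.text.english import english_to_ipa2

# Abbreviations and currency are expanded first ("Dr." -> "doctor",
# "$3.50" -> "three dollars, fifty cents"), then eng_to_ipa converts the
# normalized text to phonemes and the _ipa_to_ipa2 rules adjust a few symbols.
print(english_to_ipa2("Dr. Smith paid $3.50 for 2 books."))
```
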
/vits/text/korean.py:
--------------------------------------------------------------------------------
1 | import re
2 | from jamo import h2j, j2hcj
3 | import ko_pron
4 |
5 |
6 | # This is a list of Korean classifiers preceded by pure Korean numerals.
7 | _korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통'
8 |
9 | # List of (hangul, hangul divided) pairs:
10 | _hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [
11 | ('ㄳ', 'ㄱㅅ'),
12 | ('ㄵ', 'ㄴㅈ'),
13 | ('ㄶ', 'ㄴㅎ'),
14 | ('ㄺ', 'ㄹㄱ'),
15 | ('ㄻ', 'ㄹㅁ'),
16 | ('ㄼ', 'ㄹㅂ'),
17 | ('ㄽ', 'ㄹㅅ'),
18 | ('ㄾ', 'ㄹㅌ'),
19 | ('ㄿ', 'ㄹㅍ'),
20 | ('ㅀ', 'ㄹㅎ'),
21 | ('ㅄ', 'ㅂㅅ'),
22 | ('ㅘ', 'ㅗㅏ'),
23 | ('ㅙ', 'ㅗㅐ'),
24 | ('ㅚ', 'ㅗㅣ'),
25 | ('ㅝ', 'ㅜㅓ'),
26 | ('ㅞ', 'ㅜㅔ'),
27 | ('ㅟ', 'ㅜㅣ'),
28 | ('ㅢ', 'ㅡㅣ'),
29 | ('ㅑ', 'ㅣㅏ'),
30 | ('ㅒ', 'ㅣㅐ'),
31 | ('ㅕ', 'ㅣㅓ'),
32 | ('ㅖ', 'ㅣㅔ'),
33 | ('ㅛ', 'ㅣㅗ'),
34 | ('ㅠ', 'ㅣㅜ')
35 | ]]
36 |
37 | # List of (Latin alphabet, hangul) pairs:
38 | _latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
39 | ('a', '에이'),
40 | ('b', '비'),
41 | ('c', '시'),
42 | ('d', '디'),
43 | ('e', '이'),
44 | ('f', '에프'),
45 | ('g', '지'),
46 | ('h', '에이치'),
47 | ('i', '아이'),
48 | ('j', '제이'),
49 | ('k', '케이'),
50 | ('l', '엘'),
51 | ('m', '엠'),
52 | ('n', '엔'),
53 | ('o', '오'),
54 | ('p', '피'),
55 | ('q', '큐'),
56 | ('r', '아르'),
57 | ('s', '에스'),
58 | ('t', '티'),
59 | ('u', '유'),
60 | ('v', '브이'),
61 | ('w', '더블유'),
62 | ('x', '엑스'),
63 | ('y', '와이'),
64 | ('z', '제트')
65 | ]]
66 |
67 | # List of (ipa, lazy ipa) pairs:
68 | _ipa_to_lazy_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
69 | ('t͡ɕ','ʧ'),
70 | ('d͡ʑ','ʥ'),
71 | ('ɲ','n^'),
72 | ('ɕ','ʃ'),
73 | ('ʷ','w'),
74 | ('ɭ','l`'),
75 | ('ʎ','ɾ'),
76 | ('ɣ','ŋ'),
77 | ('ɰ','ɯ'),
78 | ('ʝ','j'),
79 | ('ʌ','ə'),
80 | ('ɡ','g'),
81 | ('\u031a','#'),
82 | ('\u0348','='),
83 | ('\u031e',''),
84 | ('\u0320',''),
85 | ('\u0339','')
86 | ]]
87 |
88 |
89 | def latin_to_hangul(text):
90 | for regex, replacement in _latin_to_hangul:
91 | text = re.sub(regex, replacement, text)
92 | return text
93 |
94 |
95 | def divide_hangul(text):
96 | text = j2hcj(h2j(text))
97 | for regex, replacement in _hangul_divided:
98 | text = re.sub(regex, replacement, text)
99 | return text
100 |
101 |
102 | def hangul_number(num, sino=True):
103 | '''Reference https://github.com/Kyubyong/g2pK'''
104 | num = re.sub(',', '', num)
105 |
106 | if num == '0':
107 | return '영'
108 | if not sino and num == '20':
109 | return '스무'
110 |
111 | digits = '123456789'
112 | names = '일이삼사오육칠팔구'
113 | digit2name = {d: n for d, n in zip(digits, names)}
114 |
115 | modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉'
116 | decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔'
117 | digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())}
118 | digit2dec = {d: dec for d, dec in zip(digits, decimals.split())}
119 |
120 | spelledout = []
121 | for i, digit in enumerate(num):
122 | i = len(num) - i - 1
123 | if sino:
124 | if i == 0:
125 | name = digit2name.get(digit, '')
126 | elif i == 1:
127 | name = digit2name.get(digit, '') + '십'
128 | name = name.replace('일십', '십')
129 | else:
130 | if i == 0:
131 | name = digit2mod.get(digit, '')
132 | elif i == 1:
133 | name = digit2dec.get(digit, '')
134 | if digit == '0':
135 | if i % 4 == 0:
136 | last_three = spelledout[-min(3, len(spelledout)):]
137 | if ''.join(last_three) == '':
138 | spelledout.append('')
139 | continue
140 | else:
141 | spelledout.append('')
142 | continue
143 | if i == 2:
144 | name = digit2name.get(digit, '') + '백'
145 | name = name.replace('일백', '백')
146 | elif i == 3:
147 | name = digit2name.get(digit, '') + '천'
148 | name = name.replace('일천', '천')
149 | elif i == 4:
150 | name = digit2name.get(digit, '') + '만'
151 | name = name.replace('일만', '만')
152 | elif i == 5:
153 | name = digit2name.get(digit, '') + '십'
154 | name = name.replace('일십', '십')
155 | elif i == 6:
156 | name = digit2name.get(digit, '') + '백'
157 | name = name.replace('일백', '백')
158 | elif i == 7:
159 | name = digit2name.get(digit, '') + '천'
160 | name = name.replace('일천', '천')
161 | elif i == 8:
162 | name = digit2name.get(digit, '') + '억'
163 | elif i == 9:
164 | name = digit2name.get(digit, '') + '십'
165 | elif i == 10:
166 | name = digit2name.get(digit, '') + '백'
167 | elif i == 11:
168 | name = digit2name.get(digit, '') + '천'
169 | elif i == 12:
170 | name = digit2name.get(digit, '') + '조'
171 | elif i == 13:
172 | name = digit2name.get(digit, '') + '십'
173 | elif i == 14:
174 | name = digit2name.get(digit, '') + '백'
175 | elif i == 15:
176 | name = digit2name.get(digit, '') + '천'
177 | spelledout.append(name)
178 | return ''.join(elem for elem in spelledout)
179 |
180 |
181 | def number_to_hangul(text):
182 | '''Reference https://github.com/Kyubyong/g2pK'''
183 | tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text))
184 | for token in tokens:
185 | num, classifier = token
186 | if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers:
187 | spelledout = hangul_number(num, sino=False)
188 | else:
189 | spelledout = hangul_number(num, sino=True)
190 | text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}')
191 | # digit by digit for remaining digits
192 | digits = '0123456789'
193 | names = '영일이삼사오육칠팔구'
194 | for d, n in zip(digits, names):
195 | text = text.replace(d, n)
196 | return text
197 |
198 |
199 | def korean_to_lazy_ipa(text):
200 | text = latin_to_hangul(text)
201 | text = number_to_hangul(text)
202 | text=re.sub('[\uac00-\ud7af]+',lambda x:ko_pron.romanise(x.group(0),'ipa').split('] ~ [')[0],text)
203 | for regex, replacement in _ipa_to_lazy_ipa:
204 | text = re.sub(regex, replacement, text)
205 | return text
206 |
207 |
208 | def korean_to_ipa(text):
209 | text = korean_to_lazy_ipa(text)
210 | return text.replace('ʧ','tʃ').replace('ʥ','dʑ')
211 |
--------------------------------------------------------------------------------
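
A small sketch of the deterministic helpers in vits/text/korean.py above, with hypothetical inputs; it assumes `jamo` and `ko_pron` are installed and the repository root is importable.

```python
from vits.text.korean import latin_to_hangul, number_to_hangul

print(latin_to_hangul("tv"))         # '티브이': each Latin letter becomes its Hangul name
print(number_to_hangul("사과 3개"))   # '사과 세개': native numeral before the classifier '개'
```
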
/vits/text/cleaners.py:
--------------------------------------------------------------------------------
1 | import re
2 | import pyopenjtalk
3 |
4 | pyopenjtalk._lazy_init()
5 |
6 |
7 | def japanese_cleaners(text):
8 | from .japanese import japanese_to_romaji_with_accent
9 | text = japanese_to_romaji_with_accent(text)
10 | text = re.sub(r'([A-Za-z])$', r'\1.', text)
11 | return text
12 |
13 |
14 | def japanese_cleaners2(text):
15 | return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…')
16 |
17 |
18 | def korean_cleaners(text):
19 | """Pipeline for Korean text"""
20 | from .korean import latin_to_hangul, number_to_hangul, divide_hangul
21 | text = latin_to_hangul(text)
22 | text = number_to_hangul(text)
23 | text = divide_hangul(text)
24 | text = re.sub(r'([\u3131-\u3163])$', r'\1.', text)
25 | return text
26 |
27 |
28 | def chinese_cleaners(text):
29 | """Pipeline for Chinese text"""
30 | from .mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo
31 | text = number_to_chinese(text)
32 | text = chinese_to_bopomofo(text)
33 | text = latin_to_bopomofo(text)
34 | text = re.sub(r'([ˉˊˇˋ˙])$', r'\1。', text)
35 | return text
36 |
37 |
38 | def chinese_cleaners1(text):
39 | from pypinyin import Style, pinyin
40 |
41 | phones = [phone[0] for phone in pinyin(text, style=Style.TONE3)]
42 | return ' '.join(phones)
43 |
44 |
45 | def chinese_cleaners2(text):
46 | from pypinyin import Style, pinyin
47 | from pypinyin.style._utils import get_finals, get_initials
48 | return " ".join([
49 | p
50 | for phone in pinyin(text, style=Style.TONE3, v_to_u=True)
51 | for p in [
52 | get_initials(phone[0], strict=True),
53 | get_finals(phone[0][:-1], strict=True) + phone[0][-1]
54 | if phone[0][-1].isdigit()
55 | else get_finals(phone[0], strict=True)
56 | if phone[0][-1].isalnum()
57 | else phone[0],
58 | ]
59 | if len(p) != 0 and not p.isdigit()
60 | ])
61 |
62 |
63 | def zh_ja_mixture_cleaners(text):
64 | from .mandarin import chinese_to_romaji
65 | from .japanese import japanese_to_romaji_with_accent
66 | text = re.sub(r'\[ZH\](.*?)\[ZH\]',
67 | lambda x: chinese_to_romaji(x.group(1)) + ' ', text)
68 | text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_romaji_with_accent(
69 | x.group(1)).replace('ts', 'ʦ').replace('u', 'ɯ').replace('...', '…') + ' ', text)
70 | text = re.sub(r'\s+$', '', text)
71 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
72 | return text
73 |
74 |
75 | def sanskrit_cleaners(text):
76 | text = text.replace('॥', '।').replace('ॐ', 'ओम्')
77 | if text[-1] != '।':
78 | text += ' ।'
79 | return text
80 |
81 |
82 | def cjks_cleaners(text):
83 | from .mandarin import chinese_to_lazy_ipa
84 | from .japanese import japanese_to_ipa
85 | from .korean import korean_to_lazy_ipa
86 | from .sanskrit import devanagari_to_ipa
87 | from .english import english_to_lazy_ipa
88 | text = re.sub(r'\[ZH\](.*?)\[ZH\]',
89 | lambda x: chinese_to_lazy_ipa(x.group(1)) + ' ', text)
90 | text = re.sub(r'\[JA\](.*?)\[JA\]',
91 | lambda x: japanese_to_ipa(x.group(1)) + ' ', text)
92 | text = re.sub(r'\[KO\](.*?)\[KO\]',
93 | lambda x: korean_to_lazy_ipa(x.group(1)) + ' ', text)
94 | text = re.sub(r'\[SA\](.*?)\[SA\]',
95 | lambda x: devanagari_to_ipa(x.group(1)) + ' ', text)
96 | text = re.sub(r'\[EN\](.*?)\[EN\]',
97 | lambda x: english_to_lazy_ipa(x.group(1)) + ' ', text)
98 | text = re.sub(r'\s+$', '', text)
99 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
100 | return text
101 |
102 |
103 | def cjke_cleaners(text):
104 | from .mandarin import chinese_to_lazy_ipa
105 | from .japanese import japanese_to_ipa
106 | from .korean import korean_to_ipa
107 | from .english import english_to_ipa2
108 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', lambda x: chinese_to_lazy_ipa(x.group(1)).replace(
109 | 'ʧ', 'tʃ').replace('ʦ', 'ts').replace('ɥan', 'ɥæn') + ' ', text)
110 | text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_ipa(x.group(1)).replace('ʧ', 'tʃ').replace(
111 | 'ʦ', 'ts').replace('ɥan', 'ɥæn').replace('ʥ', 'dz') + ' ', text)
112 | text = re.sub(r'\[KO\](.*?)\[KO\]',
113 | lambda x: korean_to_ipa(x.group(1)) + ' ', text)
114 | text = re.sub(r'\[EN\](.*?)\[EN\]', lambda x: english_to_ipa2(x.group(1)).replace('ɑ', 'a').replace(
115 | 'ɔ', 'o').replace('ɛ', 'e').replace('ɪ', 'i').replace('ʊ', 'u') + ' ', text)
116 | text = re.sub(r'\s+$', '', text)
117 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
118 | return text
119 |
120 |
121 | def cjke_cleaners2(text):
122 | from .mandarin import chinese_to_ipa
123 | from .japanese import japanese_to_ipa2
124 | from .korean import korean_to_ipa
125 | from .english import english_to_ipa2
126 | text = re.sub(r'\[ZH\](.*?)\[ZH\]',
127 | lambda x: chinese_to_ipa(x.group(1)) + ' ', text)
128 | text = re.sub(r'\[JA\](.*?)\[JA\]',
129 | lambda x: japanese_to_ipa2(x.group(1)) + ' ', text)
130 | text = re.sub(r'\[KO\](.*?)\[KO\]',
131 | lambda x: korean_to_ipa(x.group(1)) + ' ', text)
132 | text = re.sub(r'\[EN\](.*?)\[EN\]',
133 | lambda x: english_to_ipa2(x.group(1)) + ' ', text)
134 | text = re.sub(r'\s+$', '', text)
135 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
136 | return text
137 |
138 |
139 | def thai_cleaners(text):
140 | from .thai import num_to_thai, latin_to_thai
141 | text = num_to_thai(text)
142 | text = latin_to_thai(text)
143 | return text
144 |
145 |
146 | def shanghainese_cleaners(text):
147 | from .shanghainese import shanghainese_to_ipa
148 | text = shanghainese_to_ipa(text)
149 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
150 | return text
151 |
152 |
153 | def chinese_dialect_cleaners(text):
154 | from .mandarin import chinese_to_ipa2
155 | from .japanese import japanese_to_ipa3
156 | from .shanghainese import shanghainese_to_ipa
157 | from .cantonese import cantonese_to_ipa
158 | from .english import english_to_lazy_ipa2
159 | from .ngu_dialect import ngu_dialect_to_ipa
160 | text = re.sub(r'\[ZH\](.*?)\[ZH\]',
161 | lambda x: chinese_to_ipa2(x.group(1)) + ' ', text)
162 | text = re.sub(r'\[JA\](.*?)\[JA\]',
163 | lambda x: japanese_to_ipa3(x.group(1)).replace('Q', 'ʔ') + ' ', text)
164 | text = re.sub(r'\[SH\](.*?)\[SH\]', lambda x: shanghainese_to_ipa(x.group(1)).replace('1', '˥˧').replace('5',
165 | '˧˧˦').replace(
166 | '6', '˩˩˧').replace('7', '˥').replace('8', '˩˨').replace('ᴀ', 'ɐ').replace('ᴇ', 'e') + ' ', text)
167 | text = re.sub(r'\[GD\](.*?)\[GD\]',
168 | lambda x: cantonese_to_ipa(x.group(1)) + ' ', text)
169 | text = re.sub(r'\[EN\](.*?)\[EN\]',
170 | lambda x: english_to_lazy_ipa2(x.group(1)) + ' ', text)
171 | text = re.sub(r'\[([A-Z]{2})\](.*?)\[\1\]', lambda x: ngu_dialect_to_ipa(x.group(2), x.group(
172 | 1)).replace('ʣ', 'dz').replace('ʥ', 'dʑ').replace('ʦ', 'ts').replace('ʨ', 'tɕ') + ' ', text)
173 | text = re.sub(r'\s+$', '', text)
174 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
175 | return text
176 |
--------------------------------------------------------------------------------
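
A hedged example of the tag-based cleaners in vits/text/cleaners.py above: text for each language is wrapped in `[XX]...[XX]` markers and converted with the matching `*_to_ipa` helper. It assumes the text-processing dependencies from requirements.txt are installed; the sentences are placeholders.

```python
from vits.text.cleaners import cjke_cleaners2

mixed = "[EN]How are you?[EN][ZH]我很好。[ZH]"
# Each tagged segment is converted to IPA, trailing whitespace is stripped, and a
# final '.' is appended if the result does not already end with punctuation.
print(cjke_cleaners2(mixed))
```
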
/modules/safe.py:
--------------------------------------------------------------------------------
1 | # This file is modified from stable-diffusion-webui
2 | import pickle
3 | import collections
4 | import sys
5 | import traceback
6 |
7 | import torch
8 | import numpy
9 | import _codecs
10 | import zipfile
11 | import re
12 |
13 | # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
14 | TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
15 |
16 |
17 | def encode(*args):
18 | out = _codecs.encode(*args)
19 | return out
20 |
21 |
22 | class RestrictedUnpickler(pickle.Unpickler):
23 | extra_handler = None
24 |
25 | def persistent_load(self, saved_id):
26 | assert saved_id[0] == 'storage'
27 | return TypedStorage()
28 |
29 | def find_class(self, module, name):
30 | if self.extra_handler is not None:
31 | res = self.extra_handler(module, name)
32 | if res is not None:
33 | return res
34 |
35 | if module == 'collections' and name == 'OrderedDict':
36 | return getattr(collections, name)
37 | if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
38 | return getattr(torch._utils, name)
39 | if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
40 | return getattr(torch, name)
41 | if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
42 | return getattr(torch.nn.modules.container, name)
43 | if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
44 | return getattr(numpy.core.multiarray, name)
45 | if module == 'numpy' and name in ['dtype', 'ndarray']:
46 | return getattr(numpy, name)
47 | if module == '_codecs' and name == 'encode':
48 | return encode
49 | if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
50 | import pytorch_lightning.callbacks
51 | return pytorch_lightning.callbacks.model_checkpoint
52 | if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
53 | import pytorch_lightning.callbacks.model_checkpoint
54 | return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
55 | if module == "__builtin__" and name == 'set':
56 | return set
57 |
58 | # Forbid everything else.
59 | raise Exception(f"global '{module}/{name}' is forbidden")
60 |
61 |
62 | # Regular expression that accepts 'dirname/version', 'dirname/data.pkl' and 'dirname/data/<number>'
63 | allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
64 | data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
65 |
66 |
67 | def check_zip_filenames(filename, names):
68 | for name in names:
69 | if allowed_zip_names_re.match(name):
70 | continue
71 |
72 | raise Exception(f"bad file inside {filename}: {name}")
73 |
74 |
75 | def check_pt(filename, extra_handler):
76 | try:
77 |
78 | # new pytorch format is a zip file
79 | with zipfile.ZipFile(filename) as z:
80 | check_zip_filenames(filename, z.namelist())
81 |
82 |             # find the filename of data.pkl in the zip file: '<dirname>/data.pkl'
83 | data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
84 | if len(data_pkl_filenames) == 0:
85 | raise Exception(f"data.pkl not found in {filename}")
86 | if len(data_pkl_filenames) > 1:
87 | raise Exception(f"Multiple data.pkl found in {filename}")
88 | with z.open(data_pkl_filenames[0]) as file:
89 | unpickler = RestrictedUnpickler(file)
90 | unpickler.extra_handler = extra_handler
91 | unpickler.load()
92 |
93 | except zipfile.BadZipfile:
94 |
95 |         # if it's not a zip file, it's the old PyTorch format, with five objects written to the pickle
96 | with open(filename, "rb") as file:
97 | unpickler = RestrictedUnpickler(file)
98 | unpickler.extra_handler = extra_handler
99 | for i in range(5):
100 | unpickler.load()
101 |
102 |
103 | def load(filename, *args, **kwargs):
104 | return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
105 |
106 |
107 | def load_with_extra(filename, extra_handler=None, *args, **kwargs):
108 | """
109 | this function is intended to be used by extensions that want to load models with
110 | some extra classes in them that the usual unpickler would find suspicious.
111 |
112 | Use the extra_handler argument to specify a function that takes module and field name as text,
113 | and returns that field's value:
114 |
115 | ```python
116 | def extra(module, name):
117 | if module == 'collections' and name == 'OrderedDict':
118 | return collections.OrderedDict
119 |
120 | return None
121 |
122 | safe.load_with_extra('model.pt', extra_handler=extra)
123 | ```
124 |
125 | The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
126 | definitely unsafe.
127 | """
128 | from modules.options import cmd_opts
129 | try:
130 | if not cmd_opts.disable_safe_unpickle:
131 | check_pt(filename, extra_handler)
132 |
133 | except pickle.UnpicklingError:
134 | print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
135 | print(traceback.format_exc(), file=sys.stderr)
136 | print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
137 | print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
138 | return None
139 |
140 | except Exception:
141 | print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
142 | print(traceback.format_exc(), file=sys.stderr)
143 | print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
144 | print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
145 | return None
146 |
147 | return unsafe_torch_load(filename, *args, **kwargs)
148 |
149 |
150 | class Extra:
151 | """
152 | A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
153 | (because it's not your code making the torch.load call). The intended use is like this:
154 |
155 | ```
156 | import torch
157 | from modules import safe
158 |
159 | def handler(module, name):
160 | if module == 'torch' and name in ['float64', 'float16']:
161 | return getattr(torch, name)
162 |
163 | return None
164 |
165 | with safe.Extra(handler):
166 | x = torch.load('model.pt')
167 | ```
168 | """
169 |
170 | def __init__(self, handler):
171 | self.handler = handler
172 |
173 | def __enter__(self):
174 | global global_extra_handler
175 |
176 | assert global_extra_handler is None, 'already inside an Extra() block'
177 | global_extra_handler = self.handler
178 |
179 | def __exit__(self, exc_type, exc_val, exc_tb):
180 | global global_extra_handler
181 |
182 | global_extra_handler = None
183 |
184 |
185 | unsafe_torch_load = torch.load
186 | torch.load = load
187 | global_extra_handler = None
188 |
--------------------------------------------------------------------------------
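
A short sketch of the side effect of modules/safe.py above (assuming it is imported with the webui's working directory as the import root): the last lines of the file monkey-patch `torch.load`, so later checkpoint loads are checked unless `--disable-safe-unpickle` is passed.

```python
import torch
from modules import safe  # importing this module swaps torch.load for the checked loader

assert torch.load is safe.load          # all torch.load calls now go through check_pt
original_load = safe.unsafe_torch_load  # the untouched torch.load stays reachable here
```
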
/vits/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import sys
4 | import argparse
5 | import logging
6 | import json
7 | import subprocess
8 | import numpy as np
9 | from scipy.io.wavfile import read
10 | import torch
11 |
12 | MATPLOTLIB_FLAG = False
13 |
14 | logging.basicConfig(stream=sys.stdout, level=logging.ERROR)
15 | logger = logging
16 |
17 |
18 | def load_checkpoint(checkpoint_path, model, optimizer=None):
19 | assert os.path.isfile(checkpoint_path)
20 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
21 | iteration = checkpoint_dict['iteration']
22 | learning_rate = checkpoint_dict['learning_rate']
23 | if optimizer is not None:
24 | optimizer.load_state_dict(checkpoint_dict['optimizer'])
25 | saved_state_dict = checkpoint_dict['model']
26 | if hasattr(model, 'module'):
27 | state_dict = model.module.state_dict()
28 | else:
29 | state_dict = model.state_dict()
30 | new_state_dict = {}
31 | for k, v in state_dict.items():
32 | try:
33 | new_state_dict[k] = saved_state_dict[k]
34 | except:
35 | logger.info("%s is not in the checkpoint" % k)
36 | new_state_dict[k] = v
37 | if hasattr(model, 'module'):
38 | model.module.load_state_dict(new_state_dict)
39 | else:
40 | model.load_state_dict(new_state_dict)
41 | logger.info("Loaded checkpoint '{}' (iteration {})".format(
42 | checkpoint_path, iteration))
43 | return model, optimizer, learning_rate, iteration
44 |
45 |
46 | def plot_spectrogram_to_numpy(spectrogram):
47 | global MATPLOTLIB_FLAG
48 | if not MATPLOTLIB_FLAG:
49 | import matplotlib
50 | matplotlib.use("Agg")
51 | MATPLOTLIB_FLAG = True
52 | mpl_logger = logging.getLogger('matplotlib')
53 | mpl_logger.setLevel(logging.WARNING)
54 | import matplotlib.pylab as plt
55 | import numpy as np
56 |
57 | fig, ax = plt.subplots(figsize=(10, 2))
58 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
59 | interpolation='none')
60 | plt.colorbar(im, ax=ax)
61 | plt.xlabel("Frames")
62 | plt.ylabel("Channels")
63 | plt.tight_layout()
64 |
65 | fig.canvas.draw()
66 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
67 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
68 | plt.close()
69 | return data
70 |
71 |
72 | def plot_alignment_to_numpy(alignment, info=None):
73 | global MATPLOTLIB_FLAG
74 | if not MATPLOTLIB_FLAG:
75 | import matplotlib
76 | matplotlib.use("Agg")
77 | MATPLOTLIB_FLAG = True
78 | mpl_logger = logging.getLogger('matplotlib')
79 | mpl_logger.setLevel(logging.WARNING)
80 | import matplotlib.pylab as plt
81 | import numpy as np
82 |
83 | fig, ax = plt.subplots(figsize=(6, 4))
84 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
85 | interpolation='none')
86 | fig.colorbar(im, ax=ax)
87 | xlabel = 'Decoder timestep'
88 | if info is not None:
89 | xlabel += '\n\n' + info
90 | plt.xlabel(xlabel)
91 | plt.ylabel('Encoder timestep')
92 | plt.tight_layout()
93 |
94 | fig.canvas.draw()
95 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
96 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
97 | plt.close()
98 | return data
99 |
100 |
101 | def load_wav_to_torch(full_path):
102 | sampling_rate, data = read(full_path)
103 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate
104 |
105 |
106 | def load_filepaths_and_text(filename, split="|"):
107 | with open(filename, encoding='utf-8') as f:
108 | filepaths_and_text = [line.strip().split(split) for line in f]
109 | return filepaths_and_text
110 |
111 |
112 | def get_hparams(init=True):
113 | parser = argparse.ArgumentParser()
114 | parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
115 | help='JSON file for configuration')
116 | parser.add_argument('-m', '--model', type=str, required=True,
117 | help='Model name')
118 |
119 | args = parser.parse_args()
120 | model_dir = os.path.join("./logs", args.model)
121 |
122 | if not os.path.exists(model_dir):
123 | os.makedirs(model_dir)
124 |
125 | config_path = args.config
126 | config_save_path = os.path.join(model_dir, "config.json")
127 | if init:
128 | with open(config_path, "r") as f:
129 | data = f.read()
130 | with open(config_save_path, "w") as f:
131 | f.write(data)
132 | else:
133 | with open(config_save_path, "r") as f:
134 | data = f.read()
135 | config = json.loads(data)
136 |
137 | hparams = HParams(**config)
138 | hparams.model_dir = model_dir
139 | return hparams
140 |
141 |
142 | def get_hparams_from_dir(model_dir):
143 | config_save_path = os.path.join(model_dir, "config.json")
144 | with open(config_save_path, "r") as f:
145 | data = f.read()
146 | config = json.loads(data)
147 |
148 | hparams = HParams(**config)
149 | hparams.model_dir = model_dir
150 | return hparams
151 |
152 |
153 | def get_hparams_from_file(config_path):
154 | with open(config_path, "r", encoding="utf-8") as f:
155 | data = f.read()
156 | config = json.loads(data)
157 |
158 | hparams = HParams(**config)
159 | return hparams
160 |
161 |
162 | def check_git_hash(model_dir):
163 | source_dir = os.path.dirname(os.path.realpath(__file__))
164 | if not os.path.exists(os.path.join(source_dir, ".git")):
165 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
166 | source_dir
167 | ))
168 | return
169 |
170 | cur_hash = subprocess.getoutput("git rev-parse HEAD")
171 |
172 | path = os.path.join(model_dir, "githash")
173 | if os.path.exists(path):
174 | saved_hash = open(path).read()
175 | if saved_hash != cur_hash:
176 | logger.warn("git hash values are different. {}(saved) != {}(current)".format(
177 | saved_hash[:8], cur_hash[:8]))
178 | else:
179 | open(path, "w").write(cur_hash)
180 |
181 |
182 | def get_logger(model_dir, filename="train.log"):
183 | global logger
184 | logger = logging.getLogger(os.path.basename(model_dir))
185 | logger.setLevel(logging.DEBUG)
186 |
187 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
188 | if not os.path.exists(model_dir):
189 | os.makedirs(model_dir)
190 | h = logging.FileHandler(os.path.join(model_dir, filename))
191 | h.setLevel(logging.DEBUG)
192 | h.setFormatter(formatter)
193 | logger.addHandler(h)
194 | return logger
195 |
196 |
197 | class HParams():
198 | def __init__(self, **kwargs):
199 | for k, v in kwargs.items():
200 | if type(v) == dict:
201 | v = HParams(**v)
202 | self[k] = v
203 |
204 | def keys(self):
205 | return self.__dict__.keys()
206 |
207 | def items(self):
208 | return self.__dict__.items()
209 |
210 | def values(self):
211 | return self.__dict__.values()
212 |
213 | def __len__(self):
214 | return len(self.__dict__)
215 |
216 | def __getitem__(self, key):
217 | return getattr(self, key)
218 |
219 | def __setitem__(self, key, value):
220 | return setattr(self, key, value)
221 |
222 | def __contains__(self, key):
223 | return key in self.__dict__
224 |
225 | def __repr__(self):
226 | return self.__dict__.__repr__()
227 |
--------------------------------------------------------------------------------
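
A minimal sketch of the `HParams` helper defined in vits/utils.py above, using a made-up config dict: nested dicts become nested `HParams`, readable both as attributes and dict-style.

```python
from vits.utils import HParams

config = {"data": {"sampling_rate": 22050, "add_blank": True},
          "train": {"segment_size": 8192}}
hps = HParams(**config)
print(hps.data.sampling_rate)                      # 22050, attribute access
print("train" in hps, hps["train"].segment_size)   # True 8192, dict-style access
```
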
/modules/process.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os.path
3 | import re
4 | import shutil
5 | import time
6 | from pathlib import Path
7 | from typing import Tuple, List
8 |
9 | import librosa
10 | import numpy as np
11 | import scipy.io.wavfile as wavfile
12 | import soundfile
13 | import tqdm
14 | from torch import no_grad, LongTensor
15 |
16 | import modules.sovits_model as sovits_model
17 | import modules.vits_model as vits_model
18 | from modules.devices import device, torch_gc
19 | from modules.options import cmd_opts
20 | from modules.sovits_model import Svc as SovitsSvc
21 | from modules.utils import windows_filename
22 | from modules.vits_model import VITSModel
23 | from repositories.sovits.inference import slicer
24 | from vits import commons
25 | from vits.text import text_to_sequence
26 |
27 |
28 | class Text2SpeechTask:
29 | origin: str
30 | speaker: str
31 | method: str
32 | pre_processed: List[Tuple[int, str]]
33 |
34 | def __init__(self, origin: str, speaker: str, method: str):
35 | self.origin = origin
36 | self.speaker = speaker
37 | self.method = method
38 | self.pre_processed = []
39 |
40 | def preprocess(self):
41 | model = vits_model.curr_vits_model
42 | if self.method == "Simple":
43 | speaker_id = model.speakers.index(self.speaker)
44 | self.pre_processed.append((speaker_id, self.origin))
45 | elif self.method == "Multi Speakers":
46 | match = re.findall(r"\[(.*)] (.*)", self.origin)
47 | for m in match:
48 | if m[0] not in model.speakers:
49 | err = f"Error: Unknown speaker {m[0]}, check your input."
50 | print(err)
51 | return err
52 | speaker_id = model.speakers.index(m[0])
53 | self.pre_processed.append((speaker_id, m[1]))
54 | elif self.method == "Batch Process":
55 | spl = self.origin.split("\n")
56 | speaker_id = model.speakers.index(self.speaker)
57 | for line in spl:
58 | self.pre_processed.append((speaker_id, line))
59 |
60 |
61 | class SovitsTask:
62 |     # Nothing seems to need preprocessing yet, so this is just a placeholder for now
63 | pass
64 |
65 |
66 | def text2speech(text: str, speaker: str, speed, method="Simple"):
67 | if text == "":
68 | return "Fail: You need to input text.", None
69 | model = vits_model.get_model()
70 | if not model:
71 | return "Fail: No vits model loaded. Please select a model to load first.", None
72 | task = Text2SpeechTask(origin=text, speaker=speaker, method=method)
73 | err = task.preprocess()
74 | if err:
75 | return err, None
76 | ti = int(time.time())
77 | save_path = ""
78 |     output_info = "Successfully saved to "
79 | outputs = []
80 | for t in task.pre_processed:
81 | sample_rate, data = process_vits(model=model,
82 | text=t[1], speaker_id=t[0], speed=speed)
83 | outputs.append(data)
84 | save_path = f"outputs/vits/{str(ti)}-{windows_filename(t[1])}.wav"
85 | wavfile.write(save_path, sample_rate, data)
86 | output_info += f"\n{save_path}"
87 | ti += 1
88 |
89 | torch_gc()
90 |
91 | if len(outputs) > 1:
92 | batch_file_path = f"outputs/vits-batch/{str(int(time.time()))}.wav"
93 | wavfile.write(batch_file_path, vits_model.curr_vits_model.hps.data.sampling_rate, np.concatenate(outputs))
94 | return f"{output_info}\n{batch_file_path}", batch_file_path
95 | return output_info, save_path
96 |
97 |
98 | def sovits_process(audio_path, speaker: str, vc_transform: int, slice_db: int):
99 | if not audio_path:
100 |         return "Fail: You need to provide an audio file.", None
101 | model = sovits_model.get_model()
102 | if not model:
103 | return "Fail: No so-vits model loaded. Please select a model to load first.", None
104 | ti = int(time.time())
105 | save_path = ""
106 |     output_info = "Successfully saved to "
107 | outputs = []
108 | try:
109 | for af in audio_path:
110 | data, sampling_rate = process_so_vits(svc_model=sovits_model.get_model(),
111 | sid=speaker,
112 | input_audio=af,
113 | vc_transform=vc_transform,
114 | slice_db=slice_db)
115 | outputs.extend(data)
116 | save_path = f"outputs/sovits/{str(ti)}.wav"
117 | soundfile.write(save_path, data, sampling_rate, format="wav")
118 | ti += 1
119 | output_info += f"\n{save_path}"
120 |
121 | if len(outputs) > 1:
122 | batch_file_path = f"outputs/sovits-batch/{str(int(time.time()))}.wav"
123 | soundfile.write(batch_file_path, outputs, sovits_model.get_model().target_sample)
124 | return f"{output_info}\n{batch_file_path}", batch_file_path
125 |
126 | finally:
127 | torch_gc()
128 |
129 | return output_info, save_path
130 |
131 |
132 | def text_processing(text, model: VITSModel):
133 | hps = model.hps
134 | _use_symbols = model.symbols
135 |     # Leftover hack here; clean this up later
136 | if hasattr(hps, "symbols_zh"):
137 | _use_symbols = hps.symbols_zh
138 | text_norm = text_to_sequence(text, _use_symbols, hps.data.text_cleaners)
139 | if hps.data.add_blank:
140 | text_norm = commons.intersperse(text_norm, 0)
141 | text_norm = LongTensor(text_norm)
142 | return text_norm
143 |
144 |
145 | def process_vits(model: VITSModel, text: str,
146 | speaker_id, speed,
147 | noise_scale=0.667,
148 |                  noise_scale_w=0.8) -> Tuple[int, np.ndarray]:
149 | stn_tst = text_processing(text, model)
150 | with no_grad():
151 | x_tst = stn_tst.unsqueeze(0).to(device)
152 | x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
153 | sid = LongTensor([speaker_id]).to(device)
154 | audio = model.model.infer(x_tst, x_tst_lengths, sid=sid,
155 | noise_scale=noise_scale, noise_scale_w=noise_scale_w,
156 | length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
157 | del stn_tst, x_tst, x_tst_lengths, sid
158 | return model.hps.data.sampling_rate, audio
159 |
160 |
161 | def process_so_vits(svc_model: SovitsSvc, sid, input_audio, vc_transform, slice_db):
162 | audio_path = input_audio.name
163 | wav_path = os.path.join("temp", str(int(time.time())) + ".wav")
164 | if Path(audio_path).suffix != '.wav':
165 | raw_audio, raw_sample_rate = librosa.load(audio_path, mono=True, sr=None)
166 | soundfile.write(wav_path, raw_audio, raw_sample_rate)
167 | else:
168 | shutil.copy(audio_path, wav_path)
169 | chunks = slicer.cut(wav_path, db_thresh=slice_db)
170 | audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks)
171 |
172 | audio = []
173 | for (slice_tag, data) in tqdm.tqdm(audio_data):
174 | if cmd_opts.debug:
175 | print(f'segment start, {round(len(data) / audio_sr, 3)}s')
176 | length = int(np.ceil(len(data) / audio_sr * svc_model.target_sample))
177 | raw_path = io.BytesIO()
178 | soundfile.write(raw_path, data, audio_sr, format="wav")
179 | raw_path.seek(0)
180 | if slice_tag:
181 | if cmd_opts.debug:
182 |                 print('skipping empty segment')
183 | _audio = np.zeros(length)
184 | else:
185 | out_audio, out_sr = svc_model.infer(sid, vc_transform, raw_path)
186 | _audio = out_audio.cpu().numpy()
187 | audio.extend(list(_audio))
188 | os.remove(wav_path)
189 | return audio, svc_model.target_sample
190 |
--------------------------------------------------------------------------------
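
A small illustration of the input formats modules/process.py above expects (the speaker names here are hypothetical): "Multi Speakers" parses `[speaker] text` pairs with the same regex used in `Text2SpeechTask.preprocess`, while "Batch Process" takes one plain line per utterance for a single selected speaker.

```python
import re

multi_speaker_input = "[Alice] Hello there.\n[Bob] Hi, Alice."
print(re.findall(r"\[(.*)] (.*)", multi_speaker_input))
# -> [('Alice', 'Hello there.'), ('Bob', 'Hi, Alice.')]
```
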
/vits/text/mandarin.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import re
4 | from pypinyin import lazy_pinyin, BOPOMOFO
5 | import jieba
6 | import cn2an
7 | import logging
8 |
9 | logging.getLogger('jieba').setLevel(logging.WARNING)
10 | jieba.initialize()
11 |
12 |
13 | # List of (Latin alphabet, bopomofo) pairs:
14 | _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
15 | ('a', 'ㄟˉ'),
16 | ('b', 'ㄅㄧˋ'),
17 | ('c', 'ㄙㄧˉ'),
18 | ('d', 'ㄉㄧˋ'),
19 | ('e', 'ㄧˋ'),
20 | ('f', 'ㄝˊㄈㄨˋ'),
21 | ('g', 'ㄐㄧˋ'),
22 | ('h', 'ㄝˇㄑㄩˋ'),
23 | ('i', 'ㄞˋ'),
24 | ('j', 'ㄐㄟˋ'),
25 | ('k', 'ㄎㄟˋ'),
26 | ('l', 'ㄝˊㄛˋ'),
27 | ('m', 'ㄝˊㄇㄨˋ'),
28 | ('n', 'ㄣˉ'),
29 | ('o', 'ㄡˉ'),
30 | ('p', 'ㄆㄧˉ'),
31 | ('q', 'ㄎㄧㄡˉ'),
32 | ('r', 'ㄚˋ'),
33 | ('s', 'ㄝˊㄙˋ'),
34 | ('t', 'ㄊㄧˋ'),
35 | ('u', 'ㄧㄡˉ'),
36 | ('v', 'ㄨㄧˉ'),
37 | ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
38 | ('x', 'ㄝˉㄎㄨˋㄙˋ'),
39 | ('y', 'ㄨㄞˋ'),
40 | ('z', 'ㄗㄟˋ')
41 | ]]
42 |
43 | # List of (bopomofo, romaji) pairs:
44 | _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [
45 | ('ㄅㄛ', 'p⁼wo'),
46 | ('ㄆㄛ', 'pʰwo'),
47 | ('ㄇㄛ', 'mwo'),
48 | ('ㄈㄛ', 'fwo'),
49 | ('ㄅ', 'p⁼'),
50 | ('ㄆ', 'pʰ'),
51 | ('ㄇ', 'm'),
52 | ('ㄈ', 'f'),
53 | ('ㄉ', 't⁼'),
54 | ('ㄊ', 'tʰ'),
55 | ('ㄋ', 'n'),
56 | ('ㄌ', 'l'),
57 | ('ㄍ', 'k⁼'),
58 | ('ㄎ', 'kʰ'),
59 | ('ㄏ', 'h'),
60 | ('ㄐ', 'ʧ⁼'),
61 | ('ㄑ', 'ʧʰ'),
62 | ('ㄒ', 'ʃ'),
63 | ('ㄓ', 'ʦ`⁼'),
64 | ('ㄔ', 'ʦ`ʰ'),
65 | ('ㄕ', 's`'),
66 | ('ㄖ', 'ɹ`'),
67 | ('ㄗ', 'ʦ⁼'),
68 | ('ㄘ', 'ʦʰ'),
69 | ('ㄙ', 's'),
70 | ('ㄚ', 'a'),
71 | ('ㄛ', 'o'),
72 | ('ㄜ', 'ə'),
73 | ('ㄝ', 'e'),
74 | ('ㄞ', 'ai'),
75 | ('ㄟ', 'ei'),
76 | ('ㄠ', 'au'),
77 | ('ㄡ', 'ou'),
78 | ('ㄧㄢ', 'yeNN'),
79 | ('ㄢ', 'aNN'),
80 | ('ㄧㄣ', 'iNN'),
81 | ('ㄣ', 'əNN'),
82 | ('ㄤ', 'aNg'),
83 | ('ㄧㄥ', 'iNg'),
84 | ('ㄨㄥ', 'uNg'),
85 | ('ㄩㄥ', 'yuNg'),
86 | ('ㄥ', 'əNg'),
87 | ('ㄦ', 'əɻ'),
88 | ('ㄧ', 'i'),
89 | ('ㄨ', 'u'),
90 | ('ㄩ', 'ɥ'),
91 | ('ˉ', '→'),
92 | ('ˊ', '↑'),
93 | ('ˇ', '↓↑'),
94 | ('ˋ', '↓'),
95 | ('˙', ''),
96 | (',', ','),
97 | ('。', '.'),
98 | ('!', '!'),
99 | ('?', '?'),
100 | ('—', '-')
101 | ]]
102 |
103 | # List of (romaji, ipa) pairs:
104 | _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
105 | ('ʃy', 'ʃ'),
106 | ('ʧʰy', 'ʧʰ'),
107 | ('ʧ⁼y', 'ʧ⁼'),
108 | ('NN', 'n'),
109 | ('Ng', 'ŋ'),
110 | ('y', 'j'),
111 | ('h', 'x')
112 | ]]
113 |
114 | # List of (bopomofo, ipa) pairs:
115 | _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
116 | ('ㄅㄛ', 'p⁼wo'),
117 | ('ㄆㄛ', 'pʰwo'),
118 | ('ㄇㄛ', 'mwo'),
119 | ('ㄈㄛ', 'fwo'),
120 | ('ㄅ', 'p⁼'),
121 | ('ㄆ', 'pʰ'),
122 | ('ㄇ', 'm'),
123 | ('ㄈ', 'f'),
124 | ('ㄉ', 't⁼'),
125 | ('ㄊ', 'tʰ'),
126 | ('ㄋ', 'n'),
127 | ('ㄌ', 'l'),
128 | ('ㄍ', 'k⁼'),
129 | ('ㄎ', 'kʰ'),
130 | ('ㄏ', 'x'),
131 | ('ㄐ', 'tʃ⁼'),
132 | ('ㄑ', 'tʃʰ'),
133 | ('ㄒ', 'ʃ'),
134 | ('ㄓ', 'ts`⁼'),
135 | ('ㄔ', 'ts`ʰ'),
136 | ('ㄕ', 's`'),
137 | ('ㄖ', 'ɹ`'),
138 | ('ㄗ', 'ts⁼'),
139 | ('ㄘ', 'tsʰ'),
140 | ('ㄙ', 's'),
141 | ('ㄚ', 'a'),
142 | ('ㄛ', 'o'),
143 | ('ㄜ', 'ə'),
144 | ('ㄝ', 'ɛ'),
145 | ('ㄞ', 'aɪ'),
146 | ('ㄟ', 'eɪ'),
147 | ('ㄠ', 'ɑʊ'),
148 | ('ㄡ', 'oʊ'),
149 | ('ㄧㄢ', 'jɛn'),
150 | ('ㄩㄢ', 'ɥæn'),
151 | ('ㄢ', 'an'),
152 | ('ㄧㄣ', 'in'),
153 | ('ㄩㄣ', 'ɥn'),
154 | ('ㄣ', 'ən'),
155 | ('ㄤ', 'ɑŋ'),
156 | ('ㄧㄥ', 'iŋ'),
157 | ('ㄨㄥ', 'ʊŋ'),
158 | ('ㄩㄥ', 'jʊŋ'),
159 | ('ㄥ', 'əŋ'),
160 | ('ㄦ', 'əɻ'),
161 | ('ㄧ', 'i'),
162 | ('ㄨ', 'u'),
163 | ('ㄩ', 'ɥ'),
164 | ('ˉ', '→'),
165 | ('ˊ', '↑'),
166 | ('ˇ', '↓↑'),
167 | ('ˋ', '↓'),
168 | ('˙', ''),
169 | (',', ','),
170 | ('。', '.'),
171 | ('!', '!'),
172 | ('?', '?'),
173 | ('—', '-')
174 | ]]
175 |
176 | # List of (bopomofo, ipa2) pairs:
177 | _bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
178 | ('ㄅㄛ', 'pwo'),
179 | ('ㄆㄛ', 'pʰwo'),
180 | ('ㄇㄛ', 'mwo'),
181 | ('ㄈㄛ', 'fwo'),
182 | ('ㄅ', 'p'),
183 | ('ㄆ', 'pʰ'),
184 | ('ㄇ', 'm'),
185 | ('ㄈ', 'f'),
186 | ('ㄉ', 't'),
187 | ('ㄊ', 'tʰ'),
188 | ('ㄋ', 'n'),
189 | ('ㄌ', 'l'),
190 | ('ㄍ', 'k'),
191 | ('ㄎ', 'kʰ'),
192 | ('ㄏ', 'h'),
193 | ('ㄐ', 'tɕ'),
194 | ('ㄑ', 'tɕʰ'),
195 | ('ㄒ', 'ɕ'),
196 | ('ㄓ', 'tʂ'),
197 | ('ㄔ', 'tʂʰ'),
198 | ('ㄕ', 'ʂ'),
199 | ('ㄖ', 'ɻ'),
200 | ('ㄗ', 'ts'),
201 | ('ㄘ', 'tsʰ'),
202 | ('ㄙ', 's'),
203 | ('ㄚ', 'a'),
204 | ('ㄛ', 'o'),
205 | ('ㄜ', 'ɤ'),
206 | ('ㄝ', 'ɛ'),
207 | ('ㄞ', 'aɪ'),
208 | ('ㄟ', 'eɪ'),
209 | ('ㄠ', 'ɑʊ'),
210 | ('ㄡ', 'oʊ'),
211 | ('ㄧㄢ', 'jɛn'),
212 | ('ㄩㄢ', 'yæn'),
213 | ('ㄢ', 'an'),
214 | ('ㄧㄣ', 'in'),
215 | ('ㄩㄣ', 'yn'),
216 | ('ㄣ', 'ən'),
217 | ('ㄤ', 'ɑŋ'),
218 | ('ㄧㄥ', 'iŋ'),
219 | ('ㄨㄥ', 'ʊŋ'),
220 | ('ㄩㄥ', 'jʊŋ'),
221 | ('ㄥ', 'ɤŋ'),
222 | ('ㄦ', 'əɻ'),
223 | ('ㄧ', 'i'),
224 | ('ㄨ', 'u'),
225 | ('ㄩ', 'y'),
226 | ('ˉ', '˥'),
227 | ('ˊ', '˧˥'),
228 | ('ˇ', '˨˩˦'),
229 | ('ˋ', '˥˩'),
230 | ('˙', ''),
231 | (',', ','),
232 | ('。', '.'),
233 | ('!', '!'),
234 | ('?', '?'),
235 | ('—', '-')
236 | ]]
237 |
238 |
239 | def number_to_chinese(text):
240 | numbers = re.findall(r'\d+(?:\.?\d+)?', text)
241 | for number in numbers:
242 | text = text.replace(number, cn2an.an2cn(number), 1)
243 | return text
244 |
245 |
246 | def chinese_to_bopomofo(text):
247 | text = text.replace('、', ',').replace(';', ',').replace(':', ',')
248 | words = jieba.lcut(text, cut_all=False)
249 | text = ''
250 | for word in words:
251 | bopomofos = lazy_pinyin(word, BOPOMOFO)
252 | if not re.search('[\u4e00-\u9fff]', word):
253 | text += word
254 | continue
255 | for i in range(len(bopomofos)):
256 | bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i])
257 | if text != '':
258 | text += ' '
259 | text += ''.join(bopomofos)
260 | return text
261 |
262 |
263 | def latin_to_bopomofo(text):
264 | for regex, replacement in _latin_to_bopomofo:
265 | text = re.sub(regex, replacement, text)
266 | return text
267 |
268 |
269 | def bopomofo_to_romaji(text):
270 | for regex, replacement in _bopomofo_to_romaji:
271 | text = re.sub(regex, replacement, text)
272 | return text
273 |
274 |
275 | def bopomofo_to_ipa(text):
276 | for regex, replacement in _bopomofo_to_ipa:
277 | text = re.sub(regex, replacement, text)
278 | return text
279 |
280 |
281 | def bopomofo_to_ipa2(text):
282 | for regex, replacement in _bopomofo_to_ipa2:
283 | text = re.sub(regex, replacement, text)
284 | return text
285 |
286 |
287 | def chinese_to_romaji(text):
288 | text = number_to_chinese(text)
289 | text = chinese_to_bopomofo(text)
290 | text = latin_to_bopomofo(text)
291 | text = bopomofo_to_romaji(text)
292 | text = re.sub('i([aoe])', r'y\1', text)
293 | text = re.sub('u([aoəe])', r'w\1', text)
294 | text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
295 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
296 | text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
297 | return text
298 |
299 |
300 | def chinese_to_lazy_ipa(text):
301 | text = chinese_to_romaji(text)
302 | for regex, replacement in _romaji_to_ipa:
303 | text = re.sub(regex, replacement, text)
304 | return text
305 |
306 |
307 | def chinese_to_ipa(text):
308 | text = number_to_chinese(text)
309 | text = chinese_to_bopomofo(text)
310 | text = latin_to_bopomofo(text)
311 | text = bopomofo_to_ipa(text)
312 | text = re.sub('i([aoe])', r'j\1', text)
313 | text = re.sub('u([aoəe])', r'w\1', text)
314 | text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
315 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
316 | text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
317 | return text
318 |
319 |
320 | def chinese_to_ipa2(text):
321 | text = number_to_chinese(text)
322 | text = chinese_to_bopomofo(text)
323 | text = latin_to_bopomofo(text)
324 | text = bopomofo_to_ipa2(text)
325 | text = re.sub(r'i([aoe])', r'j\1', text)
326 | text = re.sub(r'u([aoəe])', r'w\1', text)
327 | text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text)
328 | text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text)
329 | return text
330 |
--------------------------------------------------------------------------------
/vits/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import numpy as np
5 |
6 | DEFAULT_MIN_BIN_WIDTH = 1e-3
7 | DEFAULT_MIN_BIN_HEIGHT = 1e-3
8 | DEFAULT_MIN_DERIVATIVE = 1e-3
9 |
10 |
11 | def piecewise_rational_quadratic_transform(inputs,
12 | unnormalized_widths,
13 | unnormalized_heights,
14 | unnormalized_derivatives,
15 | inverse=False,
16 | tails=None,
17 | tail_bound=1.,
18 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
19 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
20 | min_derivative=DEFAULT_MIN_DERIVATIVE):
21 | if tails is None:
22 | spline_fn = rational_quadratic_spline
23 | spline_kwargs = {}
24 | else:
25 | spline_fn = unconstrained_rational_quadratic_spline
26 | spline_kwargs = {
27 | 'tails': tails,
28 | 'tail_bound': tail_bound
29 | }
30 |
31 | outputs, logabsdet = spline_fn(
32 | inputs=inputs,
33 | unnormalized_widths=unnormalized_widths,
34 | unnormalized_heights=unnormalized_heights,
35 | unnormalized_derivatives=unnormalized_derivatives,
36 | inverse=inverse,
37 | min_bin_width=min_bin_width,
38 | min_bin_height=min_bin_height,
39 | min_derivative=min_derivative,
40 | **spline_kwargs
41 | )
42 | return outputs, logabsdet
43 |
44 |
45 | def searchsorted(bin_locations, inputs, eps=1e-6):
46 | bin_locations[..., -1] += eps
47 | return torch.sum(
48 | inputs[..., None] >= bin_locations,
49 | dim=-1
50 | ) - 1
51 |
52 |
53 | def unconstrained_rational_quadratic_spline(inputs,
54 | unnormalized_widths,
55 | unnormalized_heights,
56 | unnormalized_derivatives,
57 | inverse=False,
58 | tails='linear',
59 | tail_bound=1.,
60 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
61 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
62 | min_derivative=DEFAULT_MIN_DERIVATIVE):
63 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
64 | outside_interval_mask = ~inside_interval_mask
65 |
66 | outputs = torch.zeros_like(inputs)
67 | logabsdet = torch.zeros_like(inputs)
68 |
69 | if tails == 'linear':
70 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
71 | constant = np.log(np.exp(1 - min_derivative) - 1)
72 | unnormalized_derivatives[..., 0] = constant
73 | unnormalized_derivatives[..., -1] = constant
74 |
75 | outputs[outside_interval_mask] = inputs[outside_interval_mask]
76 | logabsdet[outside_interval_mask] = 0
77 | else:
78 | raise RuntimeError('{} tails are not implemented.'.format(tails))
79 |
80 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
81 | inputs=inputs[inside_interval_mask],
82 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
83 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
84 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
85 | inverse=inverse,
86 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
87 | min_bin_width=min_bin_width,
88 | min_bin_height=min_bin_height,
89 | min_derivative=min_derivative
90 | )
91 |
92 | return outputs, logabsdet
93 |
94 |
95 | def rational_quadratic_spline(inputs,
96 | unnormalized_widths,
97 | unnormalized_heights,
98 | unnormalized_derivatives,
99 | inverse=False,
100 | left=0., right=1., bottom=0., top=1.,
101 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
102 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
103 | min_derivative=DEFAULT_MIN_DERIVATIVE):
104 | if torch.min(inputs) < left or torch.max(inputs) > right:
105 | raise ValueError('Input to a transform is not within its domain')
106 |
107 | num_bins = unnormalized_widths.shape[-1]
108 |
109 | if min_bin_width * num_bins > 1.0:
110 | raise ValueError('Minimal bin width too large for the number of bins')
111 | if min_bin_height * num_bins > 1.0:
112 | raise ValueError('Minimal bin height too large for the number of bins')
113 |
114 | widths = F.softmax(unnormalized_widths, dim=-1)
115 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
116 | cumwidths = torch.cumsum(widths, dim=-1)
117 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
118 | cumwidths = (right - left) * cumwidths + left
119 | cumwidths[..., 0] = left
120 | cumwidths[..., -1] = right
121 | widths = cumwidths[..., 1:] - cumwidths[..., :-1]
122 |
123 | derivatives = min_derivative + F.softplus(unnormalized_derivatives)
124 |
125 | heights = F.softmax(unnormalized_heights, dim=-1)
126 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
127 | cumheights = torch.cumsum(heights, dim=-1)
128 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
129 | cumheights = (top - bottom) * cumheights + bottom
130 | cumheights[..., 0] = bottom
131 | cumheights[..., -1] = top
132 | heights = cumheights[..., 1:] - cumheights[..., :-1]
133 |
134 | if inverse:
135 | bin_idx = searchsorted(cumheights, inputs)[..., None]
136 | else:
137 | bin_idx = searchsorted(cumwidths, inputs)[..., None]
138 |
139 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
140 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
141 |
142 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
143 | delta = heights / widths
144 | input_delta = delta.gather(-1, bin_idx)[..., 0]
145 |
146 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
147 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
148 |
149 | input_heights = heights.gather(-1, bin_idx)[..., 0]
150 |
151 | if inverse:
152 | a = (((inputs - input_cumheights) * (input_derivatives
153 | + input_derivatives_plus_one
154 | - 2 * input_delta)
155 | + input_heights * (input_delta - input_derivatives)))
156 | b = (input_heights * input_derivatives
157 | - (inputs - input_cumheights) * (input_derivatives
158 | + input_derivatives_plus_one
159 | - 2 * input_delta))
160 | c = - input_delta * (inputs - input_cumheights)
161 |
162 | discriminant = b.pow(2) - 4 * a * c
163 | assert (discriminant >= 0).all()
164 |
165 | root = (2 * c) / (-b - torch.sqrt(discriminant))
166 | outputs = root * input_bin_widths + input_cumwidths
167 |
168 | theta_one_minus_theta = root * (1 - root)
169 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
170 | * theta_one_minus_theta)
171 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
172 | + 2 * input_delta * theta_one_minus_theta
173 | + input_derivatives * (1 - root).pow(2))
174 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
175 |
176 | return outputs, -logabsdet
177 | else:
178 | theta = (inputs - input_cumwidths) / input_bin_widths
179 | theta_one_minus_theta = theta * (1 - theta)
180 |
181 | numerator = input_heights * (input_delta * theta.pow(2)
182 | + input_derivatives * theta_one_minus_theta)
183 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
184 | * theta_one_minus_theta)
185 | outputs = input_cumheights + numerator / denominator
186 |
187 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
188 | + 2 * input_delta * theta_one_minus_theta
189 | + input_derivatives * (1 - theta).pow(2))
190 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
191 |
192 | return outputs, logabsdet
193 |
--------------------------------------------------------------------------------
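
A hedged sanity check for vits/transforms.py above (random parameters, not a test from the repository): the spline with `'linear'` tails should invert back to the input, and the forward and inverse log-determinants should cancel.

```python
import torch
from vits.transforms import piecewise_rational_quadratic_transform

torch.manual_seed(0)
num_bins = 10
x = torch.linspace(-2.0, 2.0, 8).unsqueeze(0)   # (1, 8) inputs, well inside the tail bound
w = torch.randn(1, 8, num_bins)                 # unnormalized bin widths
h = torch.randn(1, 8, num_bins)                 # unnormalized bin heights
d = torch.randn(1, 8, num_bins - 1)             # unnormalized knot derivatives

y, logdet = piecewise_rational_quadratic_transform(
    x, w, h, d, tails='linear', tail_bound=5.0)
x_rec, logdet_inv = piecewise_rational_quadratic_transform(
    y, w, h, d, inverse=True, tails='linear', tail_bound=5.0)

print(torch.allclose(x, x_rec, atol=1e-4))              # True: forward then inverse recovers x
print(torch.allclose(logdet, -logdet_inv, atol=1e-4))   # True: log|det| terms cancel
```
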
/modules/sovits_model.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import time
4 | from pathlib import Path
5 | from typing import Dict
6 |
7 | import librosa
8 | import numpy as np
9 | import parselmouth
10 | import soundfile
11 | import torch
12 | import torchaudio
13 |
14 | import modules.utils as utils
15 | from modules.devices import device
16 | from modules.model import ModelInfo
17 | from modules.model import refresh_model_list
18 | from modules.options import cmd_opts
19 | from modules.utils import model_hash
20 | from repositories.sovits.hubert import hubert_model
21 | from repositories.sovits.models import SynthesizerTrn
22 |
23 | MODEL_PATH = os.path.join(os.path.join(os.getcwd(), "models"), "sovits")
24 |
25 |
26 | class Svc(object):
27 | def __init__(self, net_g_path, config_path, hubert_path="models/hubert/hubert-soft-0d54a1f4.pt"):
28 | self.net_g_path = net_g_path
29 | self.hubert_path = hubert_path
30 | self.dev = device
31 | self.net_g_ms = None
32 | self.hps_ms = utils.get_hparams_from_file(config_path)
33 | self.target_sample = self.hps_ms.data.sampling_rate
34 | self.hop_size = self.hps_ms.data.hop_length
35 | self.speakers = {}
36 | for spk, sid in self.hps_ms.spk.items():
37 | self.speakers[sid] = spk
38 | self.spk2id = self.hps_ms.spk
39 |
40 | def load_model(self):
41 | self.hubert_soft = hubert_model.hubert_soft(self.hubert_path)
42 | if self.dev != torch.device("cpu"):
43 | self.hubert_soft = self.hubert_soft.cuda()
44 | self.net_g_ms = SynthesizerTrn(
45 | self.hps_ms.data.filter_length // 2 + 1,
46 | self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
47 | **self.hps_ms.model)
48 | _ = load_checkpoint(self.net_g_path, self.net_g_ms, None)
49 | if "half" in self.net_g_path and self.dev != torch.device("cpu"):
50 | _ = self.net_g_ms.half().eval().to(self.dev)
51 | else:
52 | _ = self.net_g_ms.eval().to(self.dev)
53 |
54 | def get_units(self, source, sr):
55 | source = source.unsqueeze(0).to(self.dev)
56 | with torch.inference_mode():
57 | start = time.time()
58 | units = self.hubert_soft.units(source)
59 | use_time = time.time() - start
60 | if cmd_opts.debug:
61 | print("hubert use time: {}".format(use_time))
62 | return units
63 |
64 | def get_unit_pitch(self, in_path, tran):
65 | source, sr = torchaudio.load(in_path)
66 | source = torchaudio.functional.resample(source, sr, 16000)
67 | if len(source.shape) == 2 and source.shape[1] >= 2:
68 | source = torch.mean(source, dim=0).unsqueeze(0)
69 | soft = self.get_units(source, sr).squeeze(0).cpu().numpy()
70 | f0_coarse, f0 = get_f0(source.cpu().numpy()[0], soft.shape[0] * 2, tran)
71 | return soft, f0
72 |
73 | def infer(self, speaker_id, tran, raw_path):
74 | if type(speaker_id) == str:
75 | speaker_id = self.spk2id[speaker_id]
76 | sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0)
77 | soft, pitch = self.get_unit_pitch(raw_path, tran)
78 | f0 = torch.FloatTensor(clean_pitch(pitch)).unsqueeze(0).to(self.dev)
79 | if "half" in self.net_g_path and torch.cuda.is_available():
80 | stn_tst = torch.HalfTensor(soft)
81 | else:
82 | stn_tst = torch.FloatTensor(soft)
83 | with torch.no_grad():
84 | x_tst = stn_tst.unsqueeze(0).to(self.dev)
85 | start = time.time()
86 | x_tst = torch.repeat_interleave(x_tst, repeats=2, dim=1).transpose(1, 2)
87 | audio = self.net_g_ms.infer(x_tst, f0=f0, g=sid)[0, 0].data.float()
88 | use_time = time.time() - start
89 | if cmd_opts.debug:
90 | print("vits use time: {}".format(use_time))
91 | return audio, audio.shape[-1]
92 |
93 |
94 | class SovitsModel(Svc):
95 | def __init__(self, info: ModelInfo):
96 | super(SovitsModel, self).__init__(info.checkpoint_path, info.config_path)
97 | self.model_name = info.model_name
98 |
99 |
100 | sovits_model_list: Dict[str, ModelInfo] = {}
101 | curr_model: Optional[SovitsModel] = None
102 |
103 |
104 | def get_model() -> Svc:
105 | return curr_model
106 |
107 |
108 | def get_model_name():
109 | return curr_model.model_name if curr_model is not None else None
110 |
111 |
112 | def get_model_list():
113 |     return list(sovits_model_list.keys())
114 |
115 |
116 | def get_speakers():
117 | if curr_model is None:
118 | return ["None"]
119 | return [spk for sid, spk in curr_model.speakers.items() if spk != "None"]
120 |
121 |
122 | def refresh_list():
123 | sovits_model_list.clear()
124 | model_list = refresh_model_list(model_path=MODEL_PATH)
125 | for m in model_list:
126 | d = m["dir"]
127 | p = os.path.join(MODEL_PATH, m["dir"])
128 | pth_path = m["pth"]
129 | config_path = m["config"]
130 | sovits_model_list[d] = ModelInfo(
131 | model_name=d,
132 | model_folder=p,
133 | model_hash=model_hash(pth_path),
134 | checkpoint_path=pth_path,
135 | config_path=config_path
136 | )
137 |     if not sovits_model_list:
138 | print("No so-vits model found. Please put a model in models/sovits")
139 |
140 |
141 | # def init_model():
142 | # global curr_sovits_model
143 | # info = next(iter(sovits_model_list.values()))
144 | # # load_model(info.model_name)
145 | # curr_sovits_model = SovitsModel(info)
146 |
147 |
148 | def load_model(model_name: str):
149 | global curr_model, sovits_model_list
150 | if curr_model and curr_model.model_name == model_name:
151 | return
152 | info = sovits_model_list[model_name]
153 | print(f"Loading so-vits weights [{info.model_hash}] from {info.checkpoint_path}...")
154 | m = SovitsModel(info)
155 | m.load_model()
156 | curr_model = m
157 | print("so-vits model loaded.")
158 |
159 |
160 | def load_checkpoint(checkpoint_path, model, optimizer=None):
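    # Lenient checkpoint loader: parameters found in the checkpoint are copied in,
    # while parameters missing from it keep the model's current values instead of
    # raising an error.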
161 | assert os.path.isfile(checkpoint_path)
162 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
163 | iteration = checkpoint_dict['iteration']
164 | learning_rate = checkpoint_dict['learning_rate']
165 | if iteration is None:
166 | iteration = 1
167 | if learning_rate is None:
168 | learning_rate = 0.0002
169 | if optimizer is not None and checkpoint_dict['optimizer'] is not None:
170 | optimizer.load_state_dict(checkpoint_dict['optimizer'])
171 | saved_state_dict = checkpoint_dict['model']
172 | if hasattr(model, 'module'):
173 | state_dict = model.module.state_dict()
174 | else:
175 | state_dict = model.state_dict()
176 | new_state_dict = {}
177 | for k, v in state_dict.items():
178 | try:
179 | new_state_dict[k] = saved_state_dict[k]
180 |             except KeyError:
181 | # logger.info("%s is not in the checkpoint" % k)
182 | new_state_dict[k] = v
183 | if hasattr(model, 'module'):
184 | model.module.load_state_dict(new_state_dict)
185 | else:
186 | model.load_state_dict(new_state_dict)
187 | print(f"Loaded checkpoint '{checkpoint_path}' (iteration {iteration})")
188 | return model, optimizer, learning_rate, iteration
189 |
190 |
191 | def format_wav(audio_path):
192 | if Path(audio_path).suffix == '.wav':
193 | return
194 | raw_audio, raw_sample_rate = librosa.load(audio_path, mono=True, sr=None)
195 | soundfile.write(Path(audio_path).with_suffix(".wav"), raw_audio, raw_sample_rate)
196 |
197 |
198 | def get_end_file(dir_path, end):
199 | file_lists = []
200 | for root, dirs, files in os.walk(dir_path):
201 | files = [f for f in files if f[0] != '.']
202 | dirs[:] = [d for d in dirs if d[0] != '.']
203 | for f_file in files:
204 | if f_file.endswith(end):
205 | file_lists.append(os.path.join(root, f_file).replace("\\", "/"))
206 | return file_lists
207 |
208 |
209 | def get_md5(content):
210 | return hashlib.new("md5", content).hexdigest()
211 |
212 |
213 | def resize2d_f0(x, target_len):
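    # Resample an f0 contour to `target_len` frames with linear interpolation.
    # Frames below 0.001 are treated as unvoiced (NaN) and come out as 0.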
214 | source = np.array(x)
215 | source[source < 0.001] = np.nan
216 | target = np.interp(np.arange(0, len(source) * target_len, len(source)) / target_len, np.arange(0, len(source)),
217 | source)
218 | res = np.nan_to_num(target)
219 | return res
220 |
221 |
222 | def get_f0(x, p_len, f0_up_key=0):
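    # Extract f0 with Praat's autocorrelation pitch tracker (via parselmouth) at a
    # 10 ms hop, pad/trim it to p_len frames, shift it by f0_up_key semitones, and
    # quantize it onto 255 mel-spaced bins (bin 1 = unvoiced).  Returns both the
    # coarse bins and the continuous f0 in Hz.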
223 | time_step = 160 / 16000 * 1000
224 | f0_min = 50
225 | f0_max = 1100
226 | f0_mel_min = 1127 * np.log(1 + f0_min / 700)
227 | f0_mel_max = 1127 * np.log(1 + f0_max / 700)
228 |
229 | f0 = parselmouth.Sound(x, 16000).to_pitch_ac(
230 | time_step=time_step / 1000, voicing_threshold=0.6,
231 | pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
232 | if len(f0) > p_len:
233 | f0 = f0[:p_len]
234 | pad_size = (p_len - len(f0) + 1) // 2
235 | if (pad_size > 0 or p_len - len(f0) - pad_size > 0):
236 | f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode='constant')
237 |
238 | f0 *= pow(2, f0_up_key / 12)
239 | f0_mel = 1127 * np.log(1 + f0 / 700)
240 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1
241 | f0_mel[f0_mel <= 1] = 1
242 | f0_mel[f0_mel > 255] = 255
243 |     f0_coarse = np.rint(f0_mel).astype(int)
244 | return f0_coarse, f0
245 |
246 |
247 | def clean_pitch(input_pitch):
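    # If more than 90% of frames sit in the "unvoiced" bin (value 1), flatten the
    # whole contour to 1 so the model effectively ignores pitch for this clip.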
248 | num_nan = np.sum(input_pitch == 1)
249 | if num_nan / len(input_pitch) > 0.9:
250 | input_pitch[input_pitch != 1] = 1
251 | return input_pitch
252 |
253 |
254 | def plt_pitch(input_pitch):
255 | input_pitch = input_pitch.astype(float)
256 | input_pitch[input_pitch == 1] = np.nan
257 | return input_pitch
258 |
259 |
260 | def f0_to_pitch(ff):
261 | f0_pitch = 69 + 12 * np.log2(ff / 440)
262 | return f0_pitch
263 |
264 |
265 | def fill_a_to_b(a, b):
266 | if len(a) < len(b):
267 | for _ in range(0, len(b) - len(a)):
268 | a.append(a[0])
269 |
--------------------------------------------------------------------------------
/modules/ui.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import gradio as gr
4 |
5 | import modules.sovits_model as sovits_model
6 | import modules.vits_model as vits_model
7 | from modules.localization import gen_localization_js
8 | from modules.options import opts
9 | from modules.process import text2speech, sovits_process
10 | from modules.utils import open_folder
11 | from modules.vits_model import get_model_list, refresh_list
12 |
13 | refresh_symbol = "\U0001f504" # 🔄
14 | folder_symbol = '\U0001f4c2' # 📂
15 |
16 | _gradio_template_response_orig = gr.routes.templates.TemplateResponse
17 | script_path = "scripts"
18 |
19 |
20 | class ToolButton(gr.Button, gr.components.FormComponent):
21 | def __init__(self, **kwargs):
22 | super().__init__(variant="tool", **kwargs)
23 |
24 | def get_block_name(self):
25 | return "button"
26 |
27 |
28 | def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
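    # Build a small refresh ToolButton whose click handler calls `refresh_method`
    # and then pushes the refreshed kwargs (e.g. new dropdown choices) into the
    # target component via gr.update().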
29 | def refresh():
30 | refresh_method()
31 | args = refreshed_args() if callable(refreshed_args) else refreshed_args
32 |
33 | for k, v in args.items():
34 | setattr(refresh_component, k, v)
35 |
36 | return gr.update(**(args or {}))
37 |
38 | refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
39 | refresh_button.click(
40 | fn=refresh,
41 | inputs=[],
42 | outputs=[refresh_component]
43 | )
44 | return refresh_button
45 |
46 |
47 | def create_setting_component(key):
48 | def fun():
49 | return opts.data[key] if key in opts.data else opts.data_labels[key].default
50 |
51 | info = opts.data_labels[key]
52 | t = type(info.default)
53 |
54 | args = info.component_args() if callable(info.component_args) else info.component_args
55 |
56 | if info.component is not None:
57 | comp = info.component
58 | elif t == str:
59 | comp = gr.Textbox
60 | elif t == int:
61 | comp = gr.Number
62 | elif t == bool:
63 | comp = gr.Checkbox
64 | else:
65 | raise Exception(f'bad options item type: {str(t)} for key {key}')
66 |
67 | elem_id = "setting_" + key
68 |
69 | if info.refresh is not None:
70 | with gr.Row():
71 | res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
72 | create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
73 | else:
74 | res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
75 |
76 | return res
77 |
78 |
79 | def change_vits_model(model_name):
80 | vits_model.load_model(model_name)
81 | speakers = vits_model.get_speakers()
82 | return gr.update(choices=speakers, value=speakers[0])
83 |
84 |
85 | def change_sovits_model(model_name):
86 | sovits_model.load_model(model_name)
87 | speakers = sovits_model.get_speakers()
88 | return gr.update(choices=speakers, value=speakers[0])
89 |
90 |
91 | def create_ui():
92 | css = "style.css"
93 | component_dict = {}
94 | reload_javascript()
95 |
96 | vits_model_list = vits_model.get_model_list()
97 | vits_speakers = vits_model.get_speakers()
98 |
99 | sovits_model_list = sovits_model.get_model_list()
100 | sovits_speakers = sovits_model.get_speakers()
101 |
102 | with gr.Blocks(analytics_enabled=False) as txt2img_interface:
103 | with gr.Row(elem_id="toprow"):
104 | with gr.Column(scale=6):
105 | with gr.Row():
106 | with gr.Column(scale=80):
107 | with gr.Row():
108 | input_text = gr.Textbox(label="Text", show_label=False, lines=4,
109 | placeholder="Text (press Ctrl+Enter or Alt+Enter to generate)")
110 | with gr.Column(scale=1):
111 | with gr.Row():
112 | vits_submit_btn = gr.Button("Generate", elem_id=f"vits_generate", variant="primary")
113 |
114 | with gr.Row().style(equal_height=False):
115 | with gr.Column(variant="panel", elem_id="vits_settings"):
116 | with gr.Row():
117 | vits_model_picker = gr.Dropdown(label="VITS Checkpoint", choices=vits_model_list,
118 | value=vits_model.get_model_name())
119 |                         create_refresh_button(vits_model_picker, refresh_method=vits_model.refresh_list,
120 | refreshed_args=lambda: {"choices": vits_model.get_model_list()},
121 | elem_id="vits_model_refresh")
122 | with gr.Row():
123 | process_method = gr.Radio(label="Process Method",
124 | choices=["Simple", "Batch Process", "Multi Speakers"],
125 | value="Simple")
126 |
127 | with gr.Row():
128 | speaker_index = gr.Dropdown(label="Speakers",
129 | choices=vits_speakers, value=vits_speakers[0])
130 |
131 | speed = gr.Slider(value=1, minimum=0.5, maximum=2, step=0.1,
132 | elem_id=f"vits_speed",
133 | label="Speed")
134 |
135 | with gr.Column(variant="panel", elem_id="vits_output"):
136 | tts_output1 = gr.Textbox(label="Output Message")
137 | tts_output2 = gr.Audio(label="Output Audio", elem_id=f"vits_audio")
138 |
139 | with gr.Column():
140 | with gr.Row(elem_id=f"functional_buttons"):
141 | open_folder_button = gr.Button(f"{folder_symbol} Open Folder", elem_id=f'open_vits_folder')
142 | save_button = gr.Button('Save', elem_id=f'save')
143 |
144 | open_folder_button.click(fn=lambda: open_folder("outputs/vits"))
145 |
146 | vits_model_picker.change(
147 | fn=change_vits_model,
148 | inputs=[vits_model_picker],
149 | outputs=[speaker_index]
150 | )
151 |
152 | vits_submit_btn.click(
153 | fn=text2speech,
154 | inputs=[
155 | input_text,
156 | speaker_index,
157 | speed,
158 | process_method
159 | ],
160 | outputs=[
161 | tts_output1,
162 | tts_output2
163 | ]
164 | )
165 |
166 | with gr.Blocks(analytics_enabled=False) as sovits_interface:
167 | with gr.Row():
168 | with gr.Column(scale=6, elem_id="sovits_audio_panel"):
169 | sovits_audio_input = gr.File(label="Upload Audio File", elem_id=f"sovits_input_audio", file_count="multiple")
170 | with gr.Column(scale=1):
171 | with gr.Row():
172 | sovits_submit_btn = gr.Button("Generate", elem_id=f"sovits_generate", variant="primary")
173 |
174 | with gr.Row().style(equal_height=False):
175 | with gr.Column(variant="panel", elem_id="sovits_settings"):
176 | with gr.Row():
177 | sovits_model_picker = gr.Dropdown(label="SO-VITS Checkpoint", choices=sovits_model_list,
178 | value=sovits_model.get_model_name())
179 |                         create_refresh_button(sovits_model_picker, refresh_method=sovits_model.refresh_list,
180 | refreshed_args=lambda: {"choices": sovits_model.get_model_list()},
181 | elem_id="sovits_model_refresh")
182 |
183 | with gr.Row():
184 | sovits_speaker_index = gr.Dropdown(label="Speakers",
185 | choices=sovits_speakers, value=sovits_speakers[0])
186 |
187 | with gr.Row():
188 | vc_transform = gr.Slider(value=0, minimum=-20, maximum=20, step=1,
189 | elem_id=f"vc_transform",
190 | label="VC Transform")
191 | slice_db = gr.Slider(value=-40, minimum=-100, maximum=0, step=5,
192 | elem_id=f"slice_db",
193 | label="Slice db")
194 | with gr.Column(variant="panel", elem_id="sovits_output"):
195 | sovits_output1 = gr.Textbox(label="Output Message")
196 | sovits_output2 = gr.Audio(label="Output Audio", elem_id=f"sovits_output_audio")
197 |
198 | sovits_submit_btn.click(
199 | fn=sovits_process,
200 | inputs=[sovits_audio_input, sovits_speaker_index, vc_transform, slice_db],
201 | outputs=[sovits_output1, sovits_output2]
202 | )
203 |
204 | sovits_model_picker.change(
205 | fn=change_sovits_model,
206 | inputs=[sovits_model_picker],
207 | outputs=[sovits_speaker_index]
208 | )
209 |
210 | with gr.Blocks(analytics_enabled=False) as settings_interface:
211 | settings_component = []
212 |
213 | def run_settings(*args):
214 | changed = []
215 |
216 | for key, value, comp in zip(opts.data_labels.keys(), args, settings_component):
217 | assert opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
218 |
219 | for key, value, comp in zip(opts.data_labels.keys(), args, settings_component):
220 | if opts.set(key, value):
221 | changed.append(key)
222 |
223 | try:
224 | opts.save()
225 | except RuntimeError:
226 | return f'{len(changed)} settings changed without save: {", ".join(changed)}.'
227 | return f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
228 |
229 | with gr.Row():
230 | with gr.Column(scale=6):
231 | settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
232 | with gr.Column():
233 | restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
234 |
235 | settings_result = gr.HTML(elem_id="settings_result")
236 |
237 | for i, (k, item) in enumerate(opts.data_labels.items()):
238 | component = create_setting_component(k)
239 | component_dict[k] = component
240 | settings_component.append(component)
241 |
242 | settings_submit.click(
243 | fn=run_settings,
244 | inputs=settings_component,
245 | outputs=[settings_result],
246 | )
247 |
248 | interfaces = [
249 | (txt2img_interface, "VITS", "vits"),
250 | (sovits_interface, "SO-VITS", "sovits"),
251 | (settings_interface, "Settings", "settings")
252 | ]
253 |
254 | with gr.Blocks(css=css, analytics_enabled=False, title="VITS") as demo:
255 | with gr.Tabs(elem_id="tabs") as tabs:
256 | for interface, label, ifid in interfaces:
257 | with gr.TabItem(label, id=ifid, elem_id="tab_" + ifid):
258 | interface.render()
259 |
260 | component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
261 |
262 | # def get_settings_values():
263 | # return [getattr(opts, key) for key in component_keys]
264 | #
265 | # demo.load(
266 | # fn=get_settings_values,
267 | # inputs=[],
268 | # outputs=[component_dict[k] for k in component_keys],
269 | # )
270 |
271 | return demo
272 |
273 |
274 | def reload_javascript():
275 | scripts_list = [os.path.join(script_path, i) for i in os.listdir(script_path) if i.endswith(".js")]
276 | with open("script.js", "r", encoding="utf8") as jsfile:
277 |         javascript = f'<script>{jsfile.read()}</script>'
278 |
279 | for path in scripts_list:
280 | with open(path, "r", encoding="utf8") as jsfile:
281 |             javascript += f"\n<script>{jsfile.read()}</script>"
282 |
283 | javascript += gen_localization_js(opts.localization)
284 |
285 | # todo: theme
286 | # if cmd_opts.theme is not None:
287 | # javascript += f"\n\n"
288 |
289 | def template_response(*args, **kwargs):
290 | res = _gradio_template_response_orig(*args, **kwargs)
291 | res.body = res.body.replace(
292 |             b'</head>', f'{javascript}</head>'.encode("utf8"))
293 | res.init_headers()
294 | return res
295 |
296 | gr.routes.templates.TemplateResponse = template_response
297 |
--------------------------------------------------------------------------------
/style.css:
--------------------------------------------------------------------------------
1 | .container {
2 | max-width: 100%;
3 | }
4 |
5 | [id$=_generate] {
6 | min-height: 6.5em;
7 | margin-top: 0.5em;
8 | margin-right: 0.25em;
9 | }
10 |
11 | /*#sovits_audio_panel {*/
12 | /* min-height: 5.5em;*/
13 | /*}*/
14 |
15 | #sh {
16 | min-width: 2em;
17 | min-height: 2em;
18 | max-width: 2em;
19 | max-height: 2em;
20 | flex-grow: 0;
21 | padding-left: 0.25em;
22 | padding-right: 0.25em;
23 | margin: 0.1em 0;
24 | opacity: 0%;
25 | cursor: default;
26 | }
27 |
28 | .output-html p {
29 | margin: 0 0.5em;
30 | }
31 |
32 | .row > *,
33 | .row > .gr-form > * {
34 | min-width: min(120px, 100%);
35 | flex: 1 1 0%;
36 | }
37 |
38 | .performance {
39 | font-size: 0.85em;
40 | color: #444;
41 | }
42 |
43 | .performance p {
44 | display: inline-block;
45 | }
46 |
47 | .performance .time {
48 | margin-right: 0;
49 | }
50 |
51 | .performance .vram {
52 | }
53 |
54 | .justify-center.overflow-x-scroll {
55 | justify-content: left;
56 | }
57 |
58 | .justify-center.overflow-x-scroll button:first-of-type {
59 | margin-left: auto;
60 | }
61 |
62 | .justify-center.overflow-x-scroll button:last-of-type {
63 | margin-right: auto;
64 | }
65 |
66 | [id$=_random_seed], [id$=_random_subseed], [id$=_reuse_seed], [id$=_reuse_subseed], #open_folder {
67 | min-width: 2.3em;
68 | height: 2.5em;
69 | flex-grow: 0;
70 | padding-left: 0.25em;
71 | padding-right: 0.25em;
72 | }
73 |
74 | #hidden_element {
75 | display: none;
76 | }
77 |
78 | [id$=_seed_row], [id$=_subseed_row] {
79 | gap: 0.5rem;
80 | padding: 0.6em;
81 | }
82 |
83 | [id$=_subseed_show_box] {
84 | min-width: auto;
85 | flex-grow: 0;
86 | }
87 |
88 | [id$=_subseed_show_box] > div {
89 | border: 0;
90 | height: 100%;
91 | }
92 |
93 | [id$=_subseed_show] {
94 | min-width: auto;
95 | flex-grow: 0;
96 | padding: 0;
97 | }
98 |
99 | [id$=_subseed_show] label {
100 | height: 100%;
101 | }
102 |
103 | #roll_col {
104 | min-width: unset !important;
105 | flex-grow: 0 !important;
106 | padding: 0.4em 0;
107 | }
108 |
109 | #roll_col > button {
110 | min-width: 2em;
111 | min-height: 2em;
112 | max-width: 2em;
113 | max-height: 2em;
114 | flex-grow: 0;
115 | padding-left: 0.25em;
116 | padding-right: 0.25em;
117 | margin: 0.1em 0;
118 | }
119 |
120 | .gr-form {
121 | background: transparent;
122 | }
123 |
124 | .my-4 {
125 | margin-top: 0;
126 | margin-bottom: 0;
127 | }
128 |
129 | #toprow div {
130 | border: none;
131 | gap: 0;
132 | background: transparent;
133 | }
134 |
135 | #resize_mode {
136 | flex: 1.5;
137 | }
138 |
139 | button {
140 | align-self: stretch !important;
141 | }
142 |
143 | .overflow-hidden, .gr-panel {
144 | overflow: visible !important;
145 | }
146 |
147 | #x_type, #y_type {
148 | max-width: 10em;
149 | }
150 |
151 |
152 | fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span {
153 | position: absolute;
154 | top: -0.7em;
155 | line-height: 1.2em;
156 | padding: 0;
157 | margin: 0 0.5em;
158 |
159 | background-color: white;
160 | box-shadow: 6px 0 6px 0px white, -6px 0 6px 0px white;
161 |
162 | z-index: 300;
163 | }
164 |
165 | .dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span {
166 | background-color: rgb(31, 41, 55);
167 | box-shadow: 6px 0 6px 0px rgb(31, 41, 55), -6px 0 6px 0px rgb(31, 41, 55);
168 | }
169 |
170 | #txt2img_column_batch, #img2img_column_batch {
171 | min-width: min(13.5em, 100%) !important;
172 | }
173 |
174 | #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span {
175 | position: relative;
176 | border: none;
177 | margin-right: 8em;
178 | }
179 |
180 | #settings .gr-panel div.flex-col div.justify-between div {
181 | position: relative;
182 | z-index: 200;
183 | }
184 |
185 | #settings {
186 | display: block;
187 | }
188 |
189 | #settings > div {
190 | border: none;
191 | margin-left: 10em;
192 | }
193 |
194 | #settings > div.flex-wrap {
195 | float: left;
196 | display: block;
197 | margin-left: 0;
198 | width: 10em;
199 | }
200 |
201 | #settings > div.flex-wrap button {
202 | display: block;
203 | border: none;
204 | text-align: left;
205 | }
206 |
207 | #settings_result {
208 | height: 1.4em;
209 | margin: 0 1.2em;
210 | }
211 |
212 | input[type="range"] {
213 | margin: 0.5em 0 -0.3em 0;
214 | }
215 |
216 | #mask_bug_info {
217 | text-align: center;
218 | display: block;
219 | margin-top: -0.75em;
220 | margin-bottom: -0.75em;
221 | }
222 |
223 |
224 | .transition.opacity-20 {
225 | opacity: 1 !important;
226 | }
227 |
228 |
229 | .min-h-\[4rem\] {
230 | min-height: unset !important;
231 | }
232 |
233 | .progressDiv {
234 | width: 100%;
235 | height: 20px;
236 | background: #b4c0cc;
237 | border-radius: 8px;
238 | }
239 |
240 | .dark .progressDiv {
241 | background: #424c5b;
242 | }
243 |
244 | .progressDiv .progress {
245 | width: 0%;
246 | height: 20px;
247 | background: #0060df;
248 | color: white;
249 | font-weight: bold;
250 | line-height: 20px;
251 | padding: 0 8px 0 0;
252 | text-align: right;
253 | border-radius: 8px;
254 | }
255 |
256 | #lightboxModal {
257 | display: none;
258 | position: fixed;
259 | z-index: 1001;
260 | padding-top: 100px;
261 | left: 0;
262 | top: 0;
263 | width: 100%;
264 | height: 100%;
265 | overflow: auto;
266 | background-color: rgba(20, 20, 20, 0.95);
267 | user-select: none;
268 | -webkit-user-select: none;
269 | }
270 |
271 | .modalControls {
272 | display: grid;
273 | grid-template-columns: 32px 32px 32px 1fr 32px;
274 | grid-template-areas: "zoom tile save space close";
275 | position: absolute;
276 | top: 0;
277 | left: 0;
278 | right: 0;
279 | padding: 16px;
280 | gap: 16px;
281 | background-color: rgba(0, 0, 0, 0.2);
282 | }
283 |
284 | .modalClose {
285 | grid-area: close;
286 | }
287 |
288 | .modalZoom {
289 | grid-area: zoom;
290 | }
291 |
292 | .modalSave {
293 | grid-area: save;
294 | }
295 |
296 | .modalTileImage {
297 | grid-area: tile;
298 | }
299 |
300 | .modalClose,
301 | .modalZoom,
302 | .modalTileImage {
303 | color: white;
304 | font-size: 35px;
305 | font-weight: bold;
306 | cursor: pointer;
307 | }
308 |
309 | .modalSave {
310 | color: white;
311 | font-size: 28px;
312 | margin-top: 8px;
313 | font-weight: bold;
314 | cursor: pointer;
315 | }
316 |
317 | .modalClose:hover,
318 | .modalClose:focus,
319 | .modalSave:hover,
320 | .modalSave:focus,
321 | .modalZoom:hover,
322 | .modalZoom:focus {
323 | color: #999;
324 | text-decoration: none;
325 | cursor: pointer;
326 | }
327 |
328 | #modalImage {
329 | display: block;
330 | margin-left: auto;
331 | margin-right: auto;
332 | margin-top: auto;
333 | width: auto;
334 | }
335 |
336 | .modalImageFullscreen {
337 | object-fit: contain;
338 | height: 90%;
339 | }
340 |
341 | .modalPrev,
342 | .modalNext {
343 | cursor: pointer;
344 | position: absolute;
345 | top: 50%;
346 | width: auto;
347 | padding: 16px;
348 | margin-top: -50px;
349 | color: white;
350 | font-weight: bold;
351 | font-size: 20px;
352 | transition: 0.6s ease;
353 | border-radius: 0 3px 3px 0;
354 | user-select: none;
355 | -webkit-user-select: none;
356 | }
357 |
358 | .modalNext {
359 | right: 0;
360 | border-radius: 3px 0 0 3px;
361 | }
362 |
363 | .modalPrev:hover,
364 | .modalNext:hover {
365 | background-color: rgba(0, 0, 0, 0.8);
366 | }
367 |
368 | #imageARPreview {
369 | position: absolute;
370 | top: 0px;
371 | left: 0px;
372 | border: 2px solid red;
373 | background: rgba(255, 0, 0, 0.3);
374 | z-index: 900;
375 | pointer-events: none;
376 |     display: none;
377 | }
378 |
379 | .red {
380 | color: red;
381 | }
382 |
383 | .gallery-item {
384 | --tw-bg-opacity: 0 !important;
385 | }
386 |
387 | #context-menu {
388 | z-index: 9999;
389 | position: absolute;
390 | display: block;
391 | padding: 0px 0;
392 | border: 2px solid #a55000;
393 | border-radius: 8px;
394 | box-shadow: 1px 1px 2px #CE6400;
395 | width: 200px;
396 | }
397 |
398 | .context-menu-items {
399 | list-style: none;
400 | margin: 0;
401 | padding: 0;
402 | }
403 |
404 | .context-menu-items a {
405 | display: block;
406 | padding: 5px;
407 | cursor: pointer;
408 | }
409 |
410 | .context-menu-items a:hover {
411 | background: #a55000;
412 | }
413 |
414 | #quicksettings {
415 | gap: 0.4em;
416 | }
417 |
418 | #quicksettings > div {
419 | border: none;
420 | background: none;
421 | flex: unset;
422 | gap: 0.5em;
423 | }
424 |
425 | #quicksettings > div > div {
426 | max-width: 32em;
427 | min-width: 24em;
428 | padding: 0;
429 | }
430 |
431 | canvas[key="mask"] {
432 | z-index: 12 !important;
433 | filter: invert();
434 | mix-blend-mode: multiply;
435 | pointer-events: none;
436 | }
437 |
438 |
439 | /* gradio 3.4.1 stuff for editable scrollbar values */
440 | .gr-box > div > div > input.gr-text-input {
441 | position: absolute;
442 | right: 0.5em;
443 | top: -0.6em;
444 | z-index: 400;
445 | width: 8em;
446 | }
447 |
448 | #quicksettings .gr-box > div > div > input.gr-text-input {
449 | top: -1.12em;
450 | }
451 |
452 | .row.gr-compact {
453 | overflow: visible;
454 | }
455 |
456 | .gr-form {
457 | background-color: white;
458 | }
459 |
460 | .dark .gr-form {
461 | background-color: rgb(31 41 55 / var(--tw-bg-opacity));
462 | }
463 |
464 | .gr-button-tool {
465 | max-width: 2.5em;
466 | min-width: 2.5em !important;
467 | height: 2.4em;
468 | margin: 0.55em 0;
469 | }
470 |
471 | #quicksettings .gr-button-tool {
472 | margin: 0;
473 | }
474 |
475 |
476 | #img2img_settings > div.gr-form, #txt2img_settings > div.gr-form {
477 | padding-top: 0.9em;
478 | }
479 |
480 | #img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form, #train_tabs div.gr-form .gr-form {
481 | border: none;
482 | padding-bottom: 0.5em;
483 | }
484 |
485 | footer {
486 | display: none !important;
487 | }
488 |
489 | #footer {
490 | text-align: center;
491 | }
492 |
493 | #footer div {
494 | display: inline-block;
495 | }
496 |
497 | #footer .versions {
498 | font-size: 85%;
499 | opacity: 0.85;
500 | }
501 |
502 | /* The following handles localization for right-to-left (RTL) languages like Arabic.
503 | The rtl media type will only be activated by the logic in scripts/localization.js.
504 | If you change anything above, you need to make sure it is RTL compliant by just running
505 | your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/.
506 | Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/
507 | @media rtl {
508 | /* this part was added manually */
509 | :host {
510 | direction: rtl;
511 | }
512 |
513 | select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress {
514 | direction: ltr;
515 | }
516 |
517 | #script_list > label > select,
518 | #x_type > label > select,
519 | #y_type > label > select {
520 | direction: rtl;
521 | }
522 |
523 | .gr-radio, .gr-checkbox {
524 | margin-left: 0.25em;
525 | }
526 |
527 | /* automatically generated with few manual modifications */
528 | .performance .time {
529 | margin-right: unset;
530 | margin-left: 0;
531 | }
532 |
533 | .justify-center.overflow-x-scroll {
534 | justify-content: right;
535 | }
536 |
537 | .justify-center.overflow-x-scroll button:first-of-type {
538 | margin-left: unset;
539 | margin-right: auto;
540 | }
541 |
542 | .justify-center.overflow-x-scroll button:last-of-type {
543 | margin-right: unset;
544 | margin-left: auto;
545 | }
546 |
547 | #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span {
548 | margin-right: unset;
549 | margin-left: 8em;
550 | }
551 |
552 | #txt2img_progressbar, #img2img_progressbar, #ti_progressbar {
553 | right: unset;
554 | left: 0;
555 | }
556 |
557 | .progressDiv .progress {
558 | padding: 0 0 0 8px;
559 | text-align: left;
560 | }
561 |
562 | #lightboxModal {
563 | left: unset;
564 | right: 0;
565 | }
566 |
567 | .modalPrev, .modalNext {
568 | border-radius: 3px 0 0 3px;
569 | }
570 |
571 | .modalNext {
572 | right: unset;
573 | left: 0;
574 | border-radius: 0 3px 3px 0;
575 | }
576 |
577 | #imageARPreview {
578 | left: unset;
579 | right: 0px;
580 | }
581 |
582 | #txt2img_skip, #img2img_skip {
583 | right: unset;
584 | left: 0px;
585 | }
586 |
587 | #context-menu {
588 | box-shadow: -1px 1px 2px #CE6400;
589 | }
590 |
591 | .gr-box > div > div > input.gr-text-input {
592 | right: unset;
593 | left: 0.5em;
594 | }
595 | }
--------------------------------------------------------------------------------
/vits/attentions.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 |
6 | from . import commons
7 | from .modules import LayerNorm
8 |
9 |
10 | class Encoder(nn.Module):
11 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4,
12 | **kwargs):
13 | super().__init__()
14 | self.hidden_channels = hidden_channels
15 | self.filter_channels = filter_channels
16 | self.n_heads = n_heads
17 | self.n_layers = n_layers
18 | self.kernel_size = kernel_size
19 | self.p_dropout = p_dropout
20 | self.window_size = window_size
21 |
22 | self.drop = nn.Dropout(p_dropout)
23 | self.attn_layers = nn.ModuleList()
24 | self.norm_layers_1 = nn.ModuleList()
25 | self.ffn_layers = nn.ModuleList()
26 | self.norm_layers_2 = nn.ModuleList()
27 | for i in range(self.n_layers):
28 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout,
29 | window_size=window_size))
30 | self.norm_layers_1.append(LayerNorm(hidden_channels))
31 | self.ffn_layers.append(
32 | FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
33 | self.norm_layers_2.append(LayerNorm(hidden_channels))
34 |
35 | def forward(self, x, x_mask):
36 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
37 | x = x * x_mask
38 | for i in range(self.n_layers):
39 | y = self.attn_layers[i](x, x, attn_mask)
40 | y = self.drop(y)
41 | x = self.norm_layers_1[i](x + y)
42 |
43 | y = self.ffn_layers[i](x, x_mask)
44 | y = self.drop(y)
45 | x = self.norm_layers_2[i](x + y)
46 | x = x * x_mask
47 | return x
48 |
49 |
50 | class Decoder(nn.Module):
51 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
52 | proximal_bias=False, proximal_init=True, **kwargs):
53 | super().__init__()
54 | self.hidden_channels = hidden_channels
55 | self.filter_channels = filter_channels
56 | self.n_heads = n_heads
57 | self.n_layers = n_layers
58 | self.kernel_size = kernel_size
59 | self.p_dropout = p_dropout
60 | self.proximal_bias = proximal_bias
61 | self.proximal_init = proximal_init
62 |
63 | self.drop = nn.Dropout(p_dropout)
64 | self.self_attn_layers = nn.ModuleList()
65 | self.norm_layers_0 = nn.ModuleList()
66 | self.encdec_attn_layers = nn.ModuleList()
67 | self.norm_layers_1 = nn.ModuleList()
68 | self.ffn_layers = nn.ModuleList()
69 | self.norm_layers_2 = nn.ModuleList()
70 | for i in range(self.n_layers):
71 | self.self_attn_layers.append(
72 | MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout,
73 | proximal_bias=proximal_bias, proximal_init=proximal_init))
74 | self.norm_layers_0.append(LayerNorm(hidden_channels))
75 | self.encdec_attn_layers.append(
76 | MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
77 | self.norm_layers_1.append(LayerNorm(hidden_channels))
78 | self.ffn_layers.append(
79 | FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
80 | self.norm_layers_2.append(LayerNorm(hidden_channels))
81 |
82 | def forward(self, x, x_mask, h, h_mask):
83 | """
84 | x: decoder input
85 | h: encoder output
86 | """
87 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
88 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
89 | x = x * x_mask
90 | for i in range(self.n_layers):
91 | y = self.self_attn_layers[i](x, x, self_attn_mask)
92 | y = self.drop(y)
93 | x = self.norm_layers_0[i](x + y)
94 |
95 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
96 | y = self.drop(y)
97 | x = self.norm_layers_1[i](x + y)
98 |
99 | y = self.ffn_layers[i](x, x_mask)
100 | y = self.drop(y)
101 | x = self.norm_layers_2[i](x + y)
102 | x = x * x_mask
103 | return x
104 |
105 |
106 | class MultiHeadAttention(nn.Module):
107 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True,
108 | block_length=None, proximal_bias=False, proximal_init=False):
109 | super().__init__()
110 | assert channels % n_heads == 0
111 |
112 | self.channels = channels
113 | self.out_channels = out_channels
114 | self.n_heads = n_heads
115 | self.p_dropout = p_dropout
116 | self.window_size = window_size
117 | self.heads_share = heads_share
118 | self.block_length = block_length
119 | self.proximal_bias = proximal_bias
120 | self.proximal_init = proximal_init
121 | self.attn = None
122 |
123 | self.k_channels = channels // n_heads
124 | self.conv_q = nn.Conv1d(channels, channels, 1)
125 | self.conv_k = nn.Conv1d(channels, channels, 1)
126 | self.conv_v = nn.Conv1d(channels, channels, 1)
127 | self.conv_o = nn.Conv1d(channels, out_channels, 1)
128 | self.drop = nn.Dropout(p_dropout)
129 |
130 | if window_size is not None:
131 | n_heads_rel = 1 if heads_share else n_heads
132 | rel_stddev = self.k_channels ** -0.5
133 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
134 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
135 |
136 | nn.init.xavier_uniform_(self.conv_q.weight)
137 | nn.init.xavier_uniform_(self.conv_k.weight)
138 | nn.init.xavier_uniform_(self.conv_v.weight)
139 | if proximal_init:
140 | with torch.no_grad():
141 | self.conv_k.weight.copy_(self.conv_q.weight)
142 | self.conv_k.bias.copy_(self.conv_q.bias)
143 |
144 | def forward(self, x, c, attn_mask=None):
145 | q = self.conv_q(x)
146 | k = self.conv_k(c)
147 | v = self.conv_v(c)
148 |
149 | x, self.attn = self.attention(q, k, v, mask=attn_mask)
150 |
151 | x = self.conv_o(x)
152 | return x
153 |
154 | def attention(self, query, key, value, mask=None):
155 | # reshape [b, d, t] -> [b, n_h, t, d_k]
156 | b, d, t_s, t_t = (*key.size(), query.size(2))
157 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
158 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
159 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
160 |
161 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
162 | if self.window_size is not None:
163 | assert t_s == t_t, "Relative attention is only available for self-attention."
164 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
165 | rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
166 | scores_local = self._relative_position_to_absolute_position(rel_logits)
167 | scores = scores + scores_local
168 | if self.proximal_bias:
169 | assert t_s == t_t, "Proximal bias is only available for self-attention."
170 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
171 | if mask is not None:
172 | scores = scores.masked_fill(mask == 0, -1e4)
173 | if self.block_length is not None:
174 | assert t_s == t_t, "Local attention is only available for self-attention."
175 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
176 | scores = scores.masked_fill(block_mask == 0, -1e4)
177 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
178 | p_attn = self.drop(p_attn)
179 | output = torch.matmul(p_attn, value)
180 | if self.window_size is not None:
181 | relative_weights = self._absolute_position_to_relative_position(p_attn)
182 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
183 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
184 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
185 | return output, p_attn
186 |
187 | def _matmul_with_relative_values(self, x, y):
188 | """
189 | x: [b, h, l, m]
190 | y: [h or 1, m, d]
191 | ret: [b, h, l, d]
192 | """
193 | ret = torch.matmul(x, y.unsqueeze(0))
194 | return ret
195 |
196 | def _matmul_with_relative_keys(self, x, y):
197 | """
198 | x: [b, h, l, d]
199 | y: [h or 1, m, d]
200 | ret: [b, h, l, m]
201 | """
202 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
203 | return ret
204 |
205 | def _get_relative_embeddings(self, relative_embeddings, length):
206 | max_relative_position = 2 * self.window_size + 1
207 | # Pad first before slice to avoid using cond ops.
208 | pad_length = max(length - (self.window_size + 1), 0)
209 | slice_start_position = max((self.window_size + 1) - length, 0)
210 | slice_end_position = slice_start_position + 2 * length - 1
211 | if pad_length > 0:
212 | padded_relative_embeddings = F.pad(
213 | relative_embeddings,
214 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
215 | else:
216 | padded_relative_embeddings = relative_embeddings
217 | used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
218 | return used_relative_embeddings
219 |
220 | def _relative_position_to_absolute_position(self, x):
221 | """
222 | x: [b, h, l, 2*l-1]
223 | ret: [b, h, l, l]
224 | """
225 | batch, heads, length, _ = x.size()
226 | # Concat columns of pad to shift from relative to absolute indexing.
227 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
228 |
229 | # Concat extra elements so to add up to shape (len+1, 2*len-1).
230 | x_flat = x.view([batch, heads, length * 2 * length])
231 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
232 |
233 | # Reshape and slice out the padded elements.
234 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
235 | return x_final
236 |
237 | def _absolute_position_to_relative_position(self, x):
238 | """
239 | x: [b, h, l, l]
240 | ret: [b, h, l, 2*l-1]
241 | """
242 | batch, heads, length, _ = x.size()
243 |         # pad along column
244 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
245 | x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
246 | # add 0's in the beginning that will skew the elements after reshape
247 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
248 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
249 | return x_final
250 |
251 | def _attention_bias_proximal(self, length):
252 | """Bias for self-attention to encourage attention to close positions.
253 | Args:
254 | length: an integer scalar.
255 | Returns:
256 | a Tensor with shape [1, 1, length, length]
257 | """
258 | r = torch.arange(length, dtype=torch.float32)
259 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
260 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
261 |
262 |
263 | class FFN(nn.Module):
264 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None,
265 | causal=False):
266 | super().__init__()
267 | self.in_channels = in_channels
268 | self.out_channels = out_channels
269 | self.filter_channels = filter_channels
270 | self.kernel_size = kernel_size
271 | self.p_dropout = p_dropout
272 | self.activation = activation
273 | self.causal = causal
274 |
275 | if causal:
276 | self.padding = self._causal_padding
277 | else:
278 | self.padding = self._same_padding
279 |
280 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
281 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
282 | self.drop = nn.Dropout(p_dropout)
283 |
284 | def forward(self, x, x_mask):
285 | x = self.conv_1(self.padding(x * x_mask))
286 | if self.activation == "gelu":
287 | x = x * torch.sigmoid(1.702 * x)
288 | else:
289 | x = torch.relu(x)
290 | x = self.drop(x)
291 | x = self.conv_2(self.padding(x * x_mask))
292 | return x * x_mask
293 |
294 | def _causal_padding(self, x):
295 | if self.kernel_size == 1:
296 | return x
297 | pad_l = self.kernel_size - 1
298 | pad_r = 0
299 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
300 | x = F.pad(x, commons.convert_pad_shape(padding))
301 | return x
302 |
303 | def _same_padding(self, x):
304 | if self.kernel_size == 1:
305 | return x
306 | pad_l = (self.kernel_size - 1) // 2
307 | pad_r = self.kernel_size // 2
308 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
309 | x = F.pad(x, commons.convert_pad_shape(padding))
310 | return x
311 |
--------------------------------------------------------------------------------
/vits/modules.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import numpy as np
4 | import scipy
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 |
9 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10 | from torch.nn.utils import weight_norm, remove_weight_norm
11 |
12 | from . import commons
13 | from .commons import init_weights, get_padding
14 | from .transforms import piecewise_rational_quadratic_transform
15 |
16 | LRELU_SLOPE = 0.1
17 |
18 |
19 | class LayerNorm(nn.Module):
20 | def __init__(self, channels, eps=1e-5):
21 | super().__init__()
22 | self.channels = channels
23 | self.eps = eps
24 |
25 | self.gamma = nn.Parameter(torch.ones(channels))
26 | self.beta = nn.Parameter(torch.zeros(channels))
27 |
28 | def forward(self, x):
29 | x = x.transpose(1, -1)
30 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
31 | return x.transpose(1, -1)
32 |
33 |
34 | class ConvReluNorm(nn.Module):
35 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
36 | super().__init__()
37 | self.in_channels = in_channels
38 | self.hidden_channels = hidden_channels
39 | self.out_channels = out_channels
40 | self.kernel_size = kernel_size
41 | self.n_layers = n_layers
42 | self.p_dropout = p_dropout
43 |         assert n_layers > 1, "Number of layers should be larger than 1."
44 |
45 | self.conv_layers = nn.ModuleList()
46 | self.norm_layers = nn.ModuleList()
47 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
48 | self.norm_layers.append(LayerNorm(hidden_channels))
49 | self.relu_drop = nn.Sequential(
50 | nn.ReLU(),
51 | nn.Dropout(p_dropout))
52 | for _ in range(n_layers - 1):
53 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
54 | self.norm_layers.append(LayerNorm(hidden_channels))
55 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
56 | self.proj.weight.data.zero_()
57 | self.proj.bias.data.zero_()
58 |
59 | def forward(self, x, x_mask):
60 | x_org = x
61 | for i in range(self.n_layers):
62 | x = self.conv_layers[i](x * x_mask)
63 | x = self.norm_layers[i](x)
64 | x = self.relu_drop(x)
65 | x = x_org + self.proj(x)
66 | return x * x_mask
67 |
68 |
69 | class DDSConv(nn.Module):
70 | """
71 | Dilated and Depth-Separable Convolution
72 | """
73 |
74 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
75 | super().__init__()
76 | self.channels = channels
77 | self.kernel_size = kernel_size
78 | self.n_layers = n_layers
79 | self.p_dropout = p_dropout
80 |
81 | self.drop = nn.Dropout(p_dropout)
82 | self.convs_sep = nn.ModuleList()
83 | self.convs_1x1 = nn.ModuleList()
84 | self.norms_1 = nn.ModuleList()
85 | self.norms_2 = nn.ModuleList()
86 | for i in range(n_layers):
87 | dilation = kernel_size ** i
88 | padding = (kernel_size * dilation - dilation) // 2
89 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
90 | groups=channels, dilation=dilation, padding=padding
91 | ))
92 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
93 | self.norms_1.append(LayerNorm(channels))
94 | self.norms_2.append(LayerNorm(channels))
95 |
96 | def forward(self, x, x_mask, g=None):
97 | if g is not None:
98 | x = x + g
99 | for i in range(self.n_layers):
100 | y = self.convs_sep[i](x * x_mask)
101 | y = self.norms_1[i](y)
102 | y = F.gelu(y)
103 | y = self.convs_1x1[i](y)
104 | y = self.norms_2[i](y)
105 | y = F.gelu(y)
106 | y = self.drop(y)
107 | x = x + y
108 | return x * x_mask
109 |
110 |
111 | class WN(torch.nn.Module):
112 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
113 | super(WN, self).__init__()
114 | assert (kernel_size % 2 == 1)
115 | self.hidden_channels = hidden_channels
116 |         self.kernel_size = kernel_size
117 | self.dilation_rate = dilation_rate
118 | self.n_layers = n_layers
119 | self.gin_channels = gin_channels
120 | self.p_dropout = p_dropout
121 |
122 | self.in_layers = torch.nn.ModuleList()
123 | self.res_skip_layers = torch.nn.ModuleList()
124 | self.drop = nn.Dropout(p_dropout)
125 |
126 | if gin_channels != 0:
127 | cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
128 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
129 |
130 | for i in range(n_layers):
131 | dilation = dilation_rate ** i
132 | padding = int((kernel_size * dilation - dilation) / 2)
133 | in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size,
134 | dilation=dilation, padding=padding)
135 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
136 | self.in_layers.append(in_layer)
137 |
138 | # last one is not necessary
139 | if i < n_layers - 1:
140 | res_skip_channels = 2 * hidden_channels
141 | else:
142 | res_skip_channels = hidden_channels
143 |
144 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
145 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
146 | self.res_skip_layers.append(res_skip_layer)
147 |
148 | def forward(self, x, x_mask, g=None, **kwargs):
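        # WaveNet-style stack: each layer applies a dilated convolution, a gated
        # tanh/sigmoid activation (optionally conditioned on g), then splits the
        # result into a residual path and a skip connection that is accumulated
        # into the output.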
149 | output = torch.zeros_like(x)
150 | n_channels_tensor = torch.IntTensor([self.hidden_channels])
151 |
152 | if g is not None:
153 | g = self.cond_layer(g)
154 |
155 | for i in range(self.n_layers):
156 | x_in = self.in_layers[i](x)
157 | if g is not None:
158 | cond_offset = i * 2 * self.hidden_channels
159 | g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
160 | else:
161 | g_l = torch.zeros_like(x_in)
162 |
163 | acts = commons.fused_add_tanh_sigmoid_multiply(
164 | x_in,
165 | g_l,
166 | n_channels_tensor)
167 | acts = self.drop(acts)
168 |
169 | res_skip_acts = self.res_skip_layers[i](acts)
170 | if i < self.n_layers - 1:
171 | res_acts = res_skip_acts[:, :self.hidden_channels, :]
172 | x = (x + res_acts) * x_mask
173 | output = output + res_skip_acts[:, self.hidden_channels:, :]
174 | else:
175 | output = output + res_skip_acts
176 | return output * x_mask
177 |
178 | def remove_weight_norm(self):
179 | if self.gin_channels != 0:
180 | torch.nn.utils.remove_weight_norm(self.cond_layer)
181 | for l in self.in_layers:
182 | torch.nn.utils.remove_weight_norm(l)
183 | for l in self.res_skip_layers:
184 | torch.nn.utils.remove_weight_norm(l)
185 |
186 |
187 | class ResBlock1(torch.nn.Module):
188 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
189 | super(ResBlock1, self).__init__()
190 | self.convs1 = nn.ModuleList([
191 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
192 | padding=get_padding(kernel_size, dilation[0]))),
193 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
194 | padding=get_padding(kernel_size, dilation[1]))),
195 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
196 | padding=get_padding(kernel_size, dilation[2])))
197 | ])
198 | self.convs1.apply(init_weights)
199 |
200 | self.convs2 = nn.ModuleList([
201 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
202 | padding=get_padding(kernel_size, 1))),
203 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
204 | padding=get_padding(kernel_size, 1))),
205 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
206 | padding=get_padding(kernel_size, 1)))
207 | ])
208 | self.convs2.apply(init_weights)
209 |
210 | def forward(self, x, x_mask=None):
211 | for c1, c2 in zip(self.convs1, self.convs2):
212 | xt = F.leaky_relu(x, LRELU_SLOPE)
213 | if x_mask is not None:
214 | xt = xt * x_mask
215 | xt = c1(xt)
216 | xt = F.leaky_relu(xt, LRELU_SLOPE)
217 | if x_mask is not None:
218 | xt = xt * x_mask
219 | xt = c2(xt)
220 | x = xt + x
221 | if x_mask is not None:
222 | x = x * x_mask
223 | return x
224 |
225 | def remove_weight_norm(self):
226 | for l in self.convs1:
227 | remove_weight_norm(l)
228 | for l in self.convs2:
229 | remove_weight_norm(l)
230 |
231 |
232 | class ResBlock2(torch.nn.Module):
233 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
234 | super(ResBlock2, self).__init__()
235 | self.convs = nn.ModuleList([
236 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
237 | padding=get_padding(kernel_size, dilation[0]))),
238 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
239 | padding=get_padding(kernel_size, dilation[1])))
240 | ])
241 | self.convs.apply(init_weights)
242 |
243 | def forward(self, x, x_mask=None):
244 | for c in self.convs:
245 | xt = F.leaky_relu(x, LRELU_SLOPE)
246 | if x_mask is not None:
247 | xt = xt * x_mask
248 | xt = c(xt)
249 | x = xt + x
250 | if x_mask is not None:
251 | x = x * x_mask
252 | return x
253 |
254 | def remove_weight_norm(self):
255 | for l in self.convs:
256 | remove_weight_norm(l)
257 |
258 |
259 | class Log(nn.Module):
260 | def forward(self, x, x_mask, reverse=False, **kwargs):
261 | if not reverse:
262 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
263 | logdet = torch.sum(-y, [1, 2])
264 | return y, logdet
265 | else:
266 | x = torch.exp(x) * x_mask
267 | return x
268 |
269 |
270 | class Flip(nn.Module):
271 | def forward(self, x, *args, reverse=False, **kwargs):
272 | x = torch.flip(x, [1])
273 | if not reverse:
274 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
275 | return x, logdet
276 | else:
277 | return x
278 |
279 |
280 | class ElementwiseAffine(nn.Module):
281 | def __init__(self, channels):
282 | super().__init__()
283 | self.channels = channels
284 | self.m = nn.Parameter(torch.zeros(channels, 1))
285 | self.logs = nn.Parameter(torch.zeros(channels, 1))
286 |
287 | def forward(self, x, x_mask, reverse=False, **kwargs):
288 | if not reverse:
289 | y = self.m + torch.exp(self.logs) * x
290 | y = y * x_mask
291 | logdet = torch.sum(self.logs * x_mask, [1, 2])
292 | return y, logdet
293 | else:
294 | x = (x - self.m) * torch.exp(-self.logs) * x_mask
295 | return x
296 |
297 |
298 | class ResidualCouplingLayer(nn.Module):
299 | def __init__(self,
300 | channels,
301 | hidden_channels,
302 | kernel_size,
303 | dilation_rate,
304 | n_layers,
305 | p_dropout=0,
306 | gin_channels=0,
307 | mean_only=False):
308 | assert channels % 2 == 0, "channels should be divisible by 2"
309 | super().__init__()
310 | self.channels = channels
311 | self.hidden_channels = hidden_channels
312 | self.kernel_size = kernel_size
313 | self.dilation_rate = dilation_rate
314 | self.n_layers = n_layers
315 | self.half_channels = channels // 2
316 | self.mean_only = mean_only
317 |
318 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
319 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout,
320 | gin_channels=gin_channels)
321 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
322 | self.post.weight.data.zero_()
323 | self.post.bias.data.zero_()
324 |
325 | def forward(self, x, x_mask, g=None, reverse=False):
326 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
327 | h = self.pre(x0) * x_mask
328 | h = self.enc(h, x_mask, g=g)
329 | stats = self.post(h) * x_mask
330 | if not self.mean_only:
331 | m, logs = torch.split(stats, [self.half_channels] * 2, 1)
332 | else:
333 | m = stats
334 | logs = torch.zeros_like(m)
335 |
336 | if not reverse:
337 | x1 = m + x1 * torch.exp(logs) * x_mask
338 | x = torch.cat([x0, x1], 1)
339 | logdet = torch.sum(logs, [1, 2])
340 | return x, logdet
341 | else:
342 | x1 = (x1 - m) * torch.exp(-logs) * x_mask
343 | x = torch.cat([x0, x1], 1)
344 | return x
345 |
346 |
347 | class ConvFlow(nn.Module):
348 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
349 | super().__init__()
350 | self.in_channels = in_channels
351 | self.filter_channels = filter_channels
352 | self.kernel_size = kernel_size
353 | self.n_layers = n_layers
354 | self.num_bins = num_bins
355 | self.tail_bound = tail_bound
356 | self.half_channels = in_channels // 2
357 |
358 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
359 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
360 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
361 | self.proj.weight.data.zero_()
362 | self.proj.bias.data.zero_()
363 |
364 | def forward(self, x, x_mask, g=None, reverse=False):
365 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
366 | h = self.pre(x0)
367 | h = self.convs(h, x_mask, g=g)
368 | h = self.proj(h) * x_mask
369 |
370 | b, c, t = x0.shape
371 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
372 |
373 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
374 | unnormalized_heights = h[..., self.num_bins:2 * self.num_bins] / math.sqrt(self.filter_channels)
375 | unnormalized_derivatives = h[..., 2 * self.num_bins:]
376 |
377 | x1, logabsdet = piecewise_rational_quadratic_transform(x1,
378 | unnormalized_widths,
379 | unnormalized_heights,
380 | unnormalized_derivatives,
381 | inverse=reverse,
382 | tails='linear',
383 | tail_bound=self.tail_bound
384 | )
385 |
386 | x = torch.cat([x0, x1], 1) * x_mask
387 | logdet = torch.sum(logabsdet * x_mask, [1, 2])
388 | if not reverse:
389 | return x, logdet
390 | else:
391 | return x
392 |
--------------------------------------------------------------------------------
/vits/models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 |
6 | from . import commons
7 | from . import modules
8 | from . import attentions
9 | from . import monotonic_align
10 |
11 | from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13 | from .commons import init_weights, get_padding
14 |
15 |
16 | class StochasticDurationPredictor(nn.Module):
17 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
18 | super().__init__()
19 |         filter_channels = in_channels  # this needs to be removed in a future version.
20 | self.in_channels = in_channels
21 | self.filter_channels = filter_channels
22 | self.kernel_size = kernel_size
23 | self.p_dropout = p_dropout
24 | self.n_flows = n_flows
25 | self.gin_channels = gin_channels
26 |
27 | self.log_flow = modules.Log()
28 | self.flows = nn.ModuleList()
29 | self.flows.append(modules.ElementwiseAffine(2))
30 | for i in range(n_flows):
31 | self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
32 | self.flows.append(modules.Flip())
33 |
34 | self.post_pre = nn.Conv1d(1, filter_channels, 1)
35 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
36 | self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
37 | self.post_flows = nn.ModuleList()
38 | self.post_flows.append(modules.ElementwiseAffine(2))
39 | for i in range(4):
40 | self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
41 | self.post_flows.append(modules.Flip())
42 |
43 | self.pre = nn.Conv1d(in_channels, filter_channels, 1)
44 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
45 | self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
46 | if gin_channels != 0:
47 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
48 |
49 | def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
50 | x = torch.detach(x)
51 | x = self.pre(x)
52 | if g is not None:
53 | g = torch.detach(g)
54 | x = x + self.cond(g)
55 | x = self.convs(x, x_mask)
56 | x = self.proj(x) * x_mask
57 |
58 | if not reverse:
59 | flows = self.flows
60 | assert w is not None
61 |
62 | logdet_tot_q = 0
63 | h_w = self.post_pre(w)
64 | h_w = self.post_convs(h_w, x_mask)
65 | h_w = self.post_proj(h_w) * x_mask
66 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
67 | z_q = e_q
68 | for flow in self.post_flows:
69 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
70 | logdet_tot_q += logdet_q
71 | z_u, z1 = torch.split(z_q, [1, 1], 1)
72 | u = torch.sigmoid(z_u) * x_mask
73 | z0 = (w - u) * x_mask
74 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
75 | logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q
76 |
77 | logdet_tot = 0
78 | z0, logdet = self.log_flow(z0, x_mask)
79 | logdet_tot += logdet
80 | z = torch.cat([z0, z1], 1)
81 | for flow in flows:
82 | z, logdet = flow(z, x_mask, g=x, reverse=reverse)
83 | logdet_tot = logdet_tot + logdet
84 | nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
85 | return nll + logq # [b]
86 | else:
87 | flows = list(reversed(self.flows))
88 | flows = flows[:-2] + [flows[-1]]  # drop one unused flow from the reversed list
89 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
90 | for flow in flows:
91 | z = flow(z, x_mask, g=x, reverse=reverse)
92 | z0, z1 = torch.split(z, [1, 1], 1)
93 | logw = z0
94 | return logw
95 |
96 |
97 | class DurationPredictor(nn.Module):
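     | # Deterministic duration predictor: two conv + LayerNorm blocks that regress
     | # log-durations directly (used when use_sdp is False).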
98 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
99 | super().__init__()
100 |
101 | self.in_channels = in_channels
102 | self.filter_channels = filter_channels
103 | self.kernel_size = kernel_size
104 | self.p_dropout = p_dropout
105 | self.gin_channels = gin_channels
106 |
107 | self.drop = nn.Dropout(p_dropout)
108 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
109 | self.norm_1 = modules.LayerNorm(filter_channels)
110 | self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
111 | self.norm_2 = modules.LayerNorm(filter_channels)
112 | self.proj = nn.Conv1d(filter_channels, 1, 1)
113 |
114 | if gin_channels != 0:
115 | self.cond = nn.Conv1d(gin_channels, in_channels, 1)
116 |
117 | def forward(self, x, x_mask, g=None):
118 | x = torch.detach(x)
119 | if g is not None:
120 | g = torch.detach(g)
121 | x = x + self.cond(g)
122 | x = self.conv_1(x * x_mask)
123 | x = torch.relu(x)
124 | x = self.norm_1(x)
125 | x = self.drop(x)
126 | x = self.conv_2(x * x_mask)
127 | x = torch.relu(x)
128 | x = self.norm_2(x)
129 | x = self.drop(x)
130 | x = self.proj(x * x_mask)
131 | return x * x_mask
132 |
133 |
134 | class TextEncoder(nn.Module):
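     | # Text encoder: embeds phoneme IDs (optionally adding a projected emotion
     | # embedding), runs the attention encoder, and projects to the prior mean m
     | # and log-std logs.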
135 | def __init__(self,
136 | n_vocab,
137 | out_channels,
138 | hidden_channels,
139 | filter_channels,
140 | n_heads,
141 | n_layers,
142 | kernel_size,
143 | p_dropout,
144 | emotion_embedding):
145 | super().__init__()
146 | self.n_vocab = n_vocab
147 | self.out_channels = out_channels
148 | self.hidden_channels = hidden_channels
149 | self.filter_channels = filter_channels
150 | self.n_heads = n_heads
151 | self.n_layers = n_layers
152 | self.kernel_size = kernel_size
153 | self.p_dropout = p_dropout
154 | self.emotion_embedding = emotion_embedding
155 |
156 | if self.n_vocab != 0:
157 | self.emb = nn.Embedding(n_vocab, hidden_channels)
158 | if emotion_embedding:
159 | self.emo_proj = nn.Linear(1024, hidden_channels)
160 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
161 |
162 | self.encoder = attentions.Encoder(
163 | hidden_channels,
164 | filter_channels,
165 | n_heads,
166 | n_layers,
167 | kernel_size,
168 | p_dropout)
169 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
170 |
171 | def forward(self, x, x_lengths, emotion_embedding=None):
172 | if self.n_vocab != 0:
173 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
174 | if emotion_embedding is not None:
175 | x = x + self.emo_proj(emotion_embedding.unsqueeze(1))
176 | x = torch.transpose(x, 1, -1) # [b, h, t]
177 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
178 |
179 | x = self.encoder(x * x_mask, x_mask)
180 | stats = self.proj(x) * x_mask
181 |
182 | m, logs = torch.split(stats, self.out_channels, dim=1)
183 | return x, m, logs, x_mask
184 |
185 |
186 | class ResidualCouplingBlock(nn.Module):
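     | # Normalizing flow built from residual coupling layers interleaved with Flip
     | # operations; invertible, so it maps the posterior latent to the prior space
     | # during training and back again at inference (reverse=True).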
187 | def __init__(self,
188 | channels,
189 | hidden_channels,
190 | kernel_size,
191 | dilation_rate,
192 | n_layers,
193 | n_flows=4,
194 | gin_channels=0):
195 | super().__init__()
196 | self.channels = channels
197 | self.hidden_channels = hidden_channels
198 | self.kernel_size = kernel_size
199 | self.dilation_rate = dilation_rate
200 | self.n_layers = n_layers
201 | self.n_flows = n_flows
202 | self.gin_channels = gin_channels
203 |
204 | self.flows = nn.ModuleList()
205 | for i in range(n_flows):
206 | self.flows.append(
207 | modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
208 | gin_channels=gin_channels, mean_only=True))
209 | self.flows.append(modules.Flip())
210 |
211 | def forward(self, x, x_mask, g=None, reverse=False):
212 | if not reverse:
213 | for flow in self.flows:
214 | x, _ = flow(x, x_mask, g=g, reverse=reverse)
215 | else:
216 | for flow in reversed(self.flows):
217 | x = flow(x, x_mask, g=g, reverse=reverse)
218 | return x
219 |
220 |
221 | class PosteriorEncoder(nn.Module):
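     | # Posterior encoder: WaveNet-like (WN) stack over the input spectrogram that
     | # predicts a Gaussian posterior and returns a sampled latent z.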
222 | def __init__(self,
223 | in_channels,
224 | out_channels,
225 | hidden_channels,
226 | kernel_size,
227 | dilation_rate,
228 | n_layers,
229 | gin_channels=0):
230 | super().__init__()
231 | self.in_channels = in_channels
232 | self.out_channels = out_channels
233 | self.hidden_channels = hidden_channels
234 | self.kernel_size = kernel_size
235 | self.dilation_rate = dilation_rate
236 | self.n_layers = n_layers
237 | self.gin_channels = gin_channels
238 |
239 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
240 | self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
241 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
242 |
243 | def forward(self, x, x_lengths, g=None):
244 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
245 | x = self.pre(x) * x_mask
246 | x = self.enc(x, x_mask, g=g)
247 | stats = self.proj(x) * x_mask
248 | m, logs = torch.split(stats, self.out_channels, dim=1)
249 | z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
250 | return z, m, logs, x_mask
251 |
252 |
253 | class Generator(torch.nn.Module):
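     | # HiFi-GAN style decoder: transposed-convolution upsampling interleaved with
     | # multi-kernel residual blocks, ending in a tanh waveform output.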
254 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
255 | upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
256 | super(Generator, self).__init__()
257 | self.num_kernels = len(resblock_kernel_sizes)
258 | self.num_upsamples = len(upsample_rates)
259 | self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
260 | resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
261 |
262 | self.ups = nn.ModuleList()
263 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
264 | self.ups.append(weight_norm(
265 | ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)),
266 | k, u, padding=(k - u) // 2)))
267 |
268 | self.resblocks = nn.ModuleList()
269 | for i in range(len(self.ups)):
270 | ch = upsample_initial_channel // (2 ** (i + 1))
271 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
272 | self.resblocks.append(resblock(ch, k, d))
273 |
274 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
275 | self.ups.apply(init_weights)
276 |
277 | if gin_channels != 0:
278 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
279 |
280 | def forward(self, x, g=None):
281 | x = self.conv_pre(x)
282 | if g is not None:
283 | x = x + self.cond(g)
284 |
285 | for i in range(self.num_upsamples):
286 | x = F.leaky_relu(x, modules.LRELU_SLOPE)
287 | x = self.ups[i](x)
288 | xs = None
289 | for j in range(self.num_kernels):
290 | if xs is None:
291 | xs = self.resblocks[i * self.num_kernels + j](x)
292 | else:
293 | xs += self.resblocks[i * self.num_kernels + j](x)
294 | x = xs / self.num_kernels
295 | x = F.leaky_relu(x)
296 | x = self.conv_post(x)
297 | x = torch.tanh(x)
298 |
299 | return x
300 |
301 | def remove_weight_norm(self):
302 | print('Removing weight norm...')
303 | for l in self.ups:
304 | remove_weight_norm(l)
305 | for l in self.resblocks:
306 | l.remove_weight_norm()
307 |
308 |
309 | class DiscriminatorP(torch.nn.Module):
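     | # Period discriminator: folds the waveform into 2-D frames of width `period`
     | # and applies strided 2-D convolutions; returns logits and intermediate
     | # feature maps (used only during training).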
310 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
311 | super(DiscriminatorP, self).__init__()
312 | self.period = period
313 | self.use_spectral_norm = use_spectral_norm
314 | norm_f = spectral_norm if use_spectral_norm else weight_norm
315 | self.convs = nn.ModuleList([
316 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
317 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
318 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
319 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
320 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
321 | ])
322 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
323 |
324 | def forward(self, x):
325 | fmap = []
326 |
327 | # 1d to 2d
328 | b, c, t = x.shape
329 | if t % self.period != 0: # pad first
330 | n_pad = self.period - (t % self.period)
331 | x = F.pad(x, (0, n_pad), "reflect")
332 | t = t + n_pad
333 | x = x.view(b, c, t // self.period, self.period)
334 |
335 | for l in self.convs:
336 | x = l(x)
337 | x = F.leaky_relu(x, modules.LRELU_SLOPE)
338 | fmap.append(x)
339 | x = self.conv_post(x)
340 | fmap.append(x)
341 | x = torch.flatten(x, 1, -1)
342 |
343 | return x, fmap
344 |
345 |
346 | class DiscriminatorS(torch.nn.Module):
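     | # Scale discriminator: 1-D convolutional discriminator over the raw waveform.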
347 | def __init__(self, use_spectral_norm=False):
348 | super(DiscriminatorS, self).__init__()
349 | norm_f = spectral_norm if use_spectral_norm else weight_norm
350 | self.convs = nn.ModuleList([
351 | norm_f(Conv1d(1, 16, 15, 1, padding=7)),
352 | norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
353 | norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
354 | norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
355 | norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
356 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
357 | ])
358 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
359 |
360 | def forward(self, x):
361 | fmap = []
362 |
363 | for l in self.convs:
364 | x = l(x)
365 | x = F.leaky_relu(x, modules.LRELU_SLOPE)
366 | fmap.append(x)
367 | x = self.conv_post(x)
368 | fmap.append(x)
369 | x = torch.flatten(x, 1, -1)
370 |
371 | return x, fmap
372 |
373 |
374 | class MultiPeriodDiscriminator(torch.nn.Module):
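     | # Combines one DiscriminatorS with DiscriminatorP instances for periods
     | # 2, 3, 5, 7 and 11; forward() scores real (y) and generated (y_hat) audio
     | # and returns feature maps (used for the feature-matching loss in training).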
375 | def __init__(self, use_spectral_norm=False):
376 | super(MultiPeriodDiscriminator, self).__init__()
377 | periods = [2, 3, 5, 7, 11]
378 |
379 | discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
380 | discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
381 | self.discriminators = nn.ModuleList(discs)
382 |
383 | def forward(self, y, y_hat):
384 | y_d_rs = []
385 | y_d_gs = []
386 | fmap_rs = []
387 | fmap_gs = []
388 | for i, d in enumerate(self.discriminators):
389 | y_d_r, fmap_r = d(y)
390 | y_d_g, fmap_g = d(y_hat)
391 | y_d_rs.append(y_d_r)
392 | y_d_gs.append(y_d_g)
393 | fmap_rs.append(fmap_r)
394 | fmap_gs.append(fmap_g)
395 |
396 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
397 |
398 |
399 | class SynthesizerTrn(nn.Module):
400 | """
401 | Synthesizer for Training
402 | """
403 |
404 | def __init__(self,
405 | n_vocab,
406 | spec_channels,
407 | segment_size,
408 | inter_channels,
409 | hidden_channels,
410 | filter_channels,
411 | n_heads,
412 | n_layers,
413 | kernel_size,
414 | p_dropout,
415 | resblock,
416 | resblock_kernel_sizes,
417 | resblock_dilation_sizes,
418 | upsample_rates,
419 | upsample_initial_channel,
420 | upsample_kernel_sizes,
421 | n_speakers=0,
422 | gin_channels=0,
423 | use_sdp=True,
424 | emotion_embedding=False,
425 | **kwargs):
426 |
427 | super().__init__()
428 | self.n_vocab = n_vocab
429 | self.spec_channels = spec_channels
430 | self.inter_channels = inter_channels
431 | self.hidden_channels = hidden_channels
432 | self.filter_channels = filter_channels
433 | self.n_heads = n_heads
434 | self.n_layers = n_layers
435 | self.kernel_size = kernel_size
436 | self.p_dropout = p_dropout
437 | self.resblock = resblock
438 | self.resblock_kernel_sizes = resblock_kernel_sizes
439 | self.resblock_dilation_sizes = resblock_dilation_sizes
440 | self.upsample_rates = upsample_rates
441 | self.upsample_initial_channel = upsample_initial_channel
442 | self.upsample_kernel_sizes = upsample_kernel_sizes
443 | self.segment_size = segment_size
444 | self.n_speakers = n_speakers
445 | self.gin_channels = gin_channels
446 |
447 | self.use_sdp = use_sdp
448 |
449 | self.enc_p = TextEncoder(n_vocab,
450 | inter_channels,
451 | hidden_channels,
452 | filter_channels,
453 | n_heads,
454 | n_layers,
455 | kernel_size,
456 | p_dropout,
457 | emotion_embedding)
458 | self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
459 | upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
460 | self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
461 | gin_channels=gin_channels)
462 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
463 |
464 | if use_sdp:
465 | self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
466 | else:
467 | self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
468 |
469 | if n_speakers > 1:
470 | self.emb_g = nn.Embedding(n_speakers, gin_channels)
471 |
472 | def forward(self, x, x_lengths, y, y_lengths, sid=None, emotion_embedding=None):
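     | # Training path: encode text and spectrogram, align them with monotonic
     | # alignment search, compute the duration loss, and decode a random latent
     | # slice of length segment_size.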
473 |
474 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, emotion_embedding)
475 | if self.n_speakers > 1:
476 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
477 | else:
478 | g = None
479 |
480 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
481 | z_p = self.flow(z, y_mask, g=g)
482 |
483 | with torch.no_grad():
484 | # negative cross-entropy
485 | s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
486 | neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s]
487 | neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2),
488 | s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
489 | neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
490 | neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s]
491 | neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
492 |
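     | # Monotonic Alignment Search: choose the monotonic text-to-frame alignment
     | # that maximizes the prior likelihood computed above (no gradients flow here).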
493 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
494 | attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
495 |
496 | w = attn.sum(2)
497 | if self.use_sdp:
498 | l_length = self.dp(x, x_mask, w, g=g)
499 | l_length = l_length / torch.sum(x_mask)
500 | else:
501 | logw_ = torch.log(w + 1e-6) * x_mask
502 | logw = self.dp(x, x_mask, g=g)
503 | l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask) # for averaging
504 |
505 | # expand prior
506 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
507 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
508 |
509 | z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
510 | o = self.dec(z_slice, g=g)
511 | return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
512 |
513 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None,
514 | emotion_embedding=None):
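     | # Inference path. noise_scale scales sampling noise for the prior,
     | # noise_scale_w scales it for the stochastic duration predictor, and
     | # length_scale stretches predicted durations (larger values give slower speech).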
515 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, emotion_embedding)
516 | if self.n_speakers > 1:
517 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
518 | else:
519 | g = None
520 |
521 | if self.use_sdp:
522 | logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
523 | else:
524 | logw = self.dp(x, x_mask, g=g)
525 | w = torch.exp(logw) * x_mask * length_scale
526 | w_ceil = torch.ceil(w)
527 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
528 | y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
529 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
530 | attn = commons.generate_path(w_ceil, attn_mask)
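     | # w_ceil holds integer frame counts per input token; generate_path expands
     | # them into a hard monotonic alignment used to broadcast the prior statistics.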
531 |
532 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
533 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
534 | # [b, t', t], [b, t, d] -> [b, d, t']
535 |
536 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
537 | z = self.flow(z_p, y_mask, g=g, reverse=True)
538 | o = self.dec((z * y_mask)[:, :, :max_len], g=g)
539 | return o, attn, y_mask, (z, z_p, m_p, logs_p)
540 |
541 | def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
542 | assert self.n_speakers > 1, "n_speakers has to be larger than 1."
543 | g_src = self.emb_g(sid_src).unsqueeze(-1)
544 | g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
545 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
546 | z_p = self.flow(z, y_mask, g=g_src)
547 | z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
548 | o_hat = self.dec(z_hat * y_mask, g=g_tgt)
549 | return o_hat, y_mask, (z, z_p, z_hat)
550 |
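     | # Usage sketch (illustrative only): the hyperparameter values below are typical
     | # VITS settings assumed for the example, not read from this repository's configs.
     | #
     | #   net_g = SynthesizerTrn(
     | #       n_vocab=178, spec_channels=513, segment_size=32,
     | #       inter_channels=192, hidden_channels=192, filter_channels=768,
     | #       n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.1,
     | #       resblock='1', resblock_kernel_sizes=[3, 7, 11],
     | #       resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
     | #       upsample_rates=[8, 8, 2, 2], upsample_initial_channel=512,
     | #       upsample_kernel_sizes=[16, 16, 4, 4])
     | #   net_g.eval()
     | #   x = torch.LongTensor([[1, 5, 9, 3]])          # phoneme IDs, shape [1, T]
     | #   x_lengths = torch.LongTensor([x.size(1)])
     | #   with torch.no_grad():
     | #       audio = net_g.infer(x, x_lengths, noise_scale=0.667,
     | #                           noise_scale_w=0.8, length_scale=1.0)[0][0, 0]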
--------------------------------------------------------------------------------