├── monotonic_align ├── monotonic_align │ └── temp ├── setup.py ├── __init__.py └── core.pyx ├── asset └── overall.png ├── filelists ├── val.list └── val.list.cleaned ├── text ├── symbols.py ├── __init__.py ├── frontend │ ├── normalizer │ │ ├── acronyms.py │ │ ├── abbrrviation.py │ │ ├── __init__.py │ │ ├── width.py │ │ ├── normalizer.py │ │ └── numbers.py │ ├── zh_normalization │ │ ├── __init__.py │ │ ├── quantifier.py │ │ ├── phonecode.py │ │ ├── constants.py │ │ ├── chronology.py │ │ ├── text_normlization.py │ │ └── num.py │ ├── __init__.py │ ├── punctuation.py │ ├── vocab.py │ ├── generate_lexicon.py │ ├── arpabet.py │ └── zh_frontend.py ├── LICENSE ├── paddle_zh_frontend.py ├── japanese.py ├── cjenglish.py ├── korean.py ├── english.py ├── cleaners.py └── mandarin.py ├── requirements.txt ├── Dockerfile ├── LICENSE ├── preprocess.py ├── losses.py ├── .gitignore ├── configs └── config_cjke.yaml ├── README.md ├── mel_processing.py ├── analysis.py ├── pqmf.py ├── yin.py ├── commons.py ├── app.py ├── transforms.py ├── utils.py ├── data_utils.py └── modules.py /monotonic_align/monotonic_align/temp: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asset/overall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innnky/pits/HEAD/asset/overall.png -------------------------------------------------------------------------------- /filelists/val.list: -------------------------------------------------------------------------------- 1 | nen/nen218_001.wav|nen|[JA]はい。2人の仲が相変わらずいいのはなによりなんですけど……[JA] 2 | nen/nen303_014.wav|nen|[JA]はい、どうやら保科君の心の穴が広がってしまった可能性がありそうです[JA] 3 | nen/nen310_010.wav|nen|[JA]私だって別に、欠片の回収だけを目的にしているわけではありませんよ[JA] 4 | -------------------------------------------------------------------------------- /monotonic_align/setup.py: 
def cleaned_text_to_sequence(cleaned_text):
    '''Converts already-cleaned text to the sequence of IDs of its symbols.
      Args:
        cleaned_text: string whose characters are looked up one at a time
          in the module-level ``_symbol_to_id`` table
      Returns:
        List of integers, one ID per character of ``cleaned_text``
      Raises:
        KeyError: if a character is not a key of ``_symbol_to_id``
    '''
    sequence = []
    for symbol in cleaned_text:
        sequence.append(_symbol_to_id[symbol])
    return sequence
def maximum_path(neg_cent, mask):
    """ Run the Cython ``maximum_path_c`` kernel and return its path as a tensor.
    neg_cent: [b, t_t, t_s]
    mask: [b, t_t, t_s]

    The result has the shape of ``neg_cent`` and is moved back to the device
    and dtype that ``neg_cent`` had on entry.
    """
    def _as_array(tensor, np_dtype):
        # The kernel works on host numpy buffers, so copy off the device first.
        return tensor.data.cpu().numpy().astype(np_dtype)

    out_device, out_dtype = neg_cent.device, neg_cent.dtype

    values = _as_array(neg_cent, np.float32)
    # Per-item extents: the mask summed along one axis, read at index 0 of the other.
    extent_dim1 = _as_array(mask.sum(1)[:, 0], np.int32)
    extent_dim2 = _as_array(mask.sum(2)[:, 0], np.int32)

    # Output buffer that the kernel fills in place.
    path = np.zeros(values.shape, dtype=np.int32)
    maximum_path_c(path, values, extent_dim1, extent_dim2)

    return torch.from_numpy(path).to(device=out_device, dtype=out_dtype)
# Training image for PITS, based on NVIDIA's PyTorch container.
FROM nvcr.io/nvidia/pytorch:21.06-py3
RUN apt-get update && apt-get upgrade -y
RUN mkdir /root/pits
COPY . /root/pits
# One layer that:
#   1. removes the tensorboard packages shipped with the base image so the
#      version pinned in requirements.txt is the one that gets installed,
#   2. installs the Python requirements,
#   3. installs a few console tools and cleans the apt caches,
#   4. builds the monotonic_align Cython extension in place.
# apt-get is used instead of apt because apt's CLI is not meant for scripts,
# and autoremove gets -y so a non-interactive build cannot abort on the
# confirmation prompt.
RUN cd /root/pits && \
    python3 -m pip uninstall -y \
    tensorboard \
    tensorboard-plugin-dlprof \
    nvidia-tensorboard \
    nvidia-tensorboard-plugin-dlprof \
    jupyter-tensorboard \
    && \
    python3 -m pip --no-cache-dir install -r requirements.txt \
    && \
    apt-get update && \
    apt-get install -y \
    tmux \
    htop \
    ncdu && \
    apt-get clean && \
    apt-get autoremove -y && \
    rm -rf /var/lib/apt/lists/* /tmp/* && \
    cd /root/pits/monotonic_align && \
    python3 setup.py build_ext --inplace
WORKDIR /root/pits
# TensorBoard's default port.
EXPOSE 6006
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from .generate_lexicon import * 15 | from .normalizer import * 16 | # from .phonectic import * 17 | from .punctuation import * 18 | from .tone_sandhi import * 19 | from .vocab import * 20 | from .zh_normalization import * 21 | 22 | -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /text/frontend/punctuation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
__all__ = ["get_punctuations"]

# Punctuation tokens returned for lang == "en".
EN_PUNCT = [" ", "-", "...", ",", ".", "?", "!"]

# Punctuation tokens returned for lang == "cn".
CN_PUNCT = [
    "、",
    ",",
    ";",
    ":",
    "。",
    "?",
    "!",
]


def get_punctuations(lang):
    """Return the punctuation list for a language code.

    Args:
        lang: "en" or "cn".
    Returns:
        The module-level list for that language (the list object itself,
        not a copy).
    Raises:
        ValueError: for any other language code.
    """
    for code, table in (("en", EN_PUNCT), ("cn", CN_PUNCT)):
        if lang == code:
            return table
    raise ValueError(f"language {lang} Not supported")
if __name__ == '__main__':
    # Command-line options: suffix of the output file, which "|"-separated
    # column holds the raw text, and which filelists to process.
    parser = argparse.ArgumentParser()
    parser.add_argument("--out_extension", default="cleaned")
    parser.add_argument("--text_index", default=2, type=int)
    parser.add_argument("--filelists", nargs="+", default=["filelists/train.list", "filelists/val.list"])
    args = parser.parse_args()

    for filelist in args.filelists:
        print("START:", filelist)
        rows = load_filepaths_and_text(filelist)

        for row in tqdm.tqdm(rows):
            # Replace the raw text with its cleaned form and append the
            # language sequence as one extra space-separated column.
            cleaned_text, lang_seq = text.cleaners._clean_text(row[args.text_index])
            row[args.text_index] = cleaned_text
            row.append(" ".join(str(lang) for lang in lang_seq))

        # Written next to the input as "<filelist>.<out_extension>".
        new_filelist = filelist + "." + args.out_extension
        with open(new_filelist, "w", encoding="utf-8") as f:
            f.writelines("|".join(row) + "\n" for row in rows)
def replace_temperature(match) -> str:
    """Verbalize a temperature expression matched by ``RE_TEMPERATURE``.

    ``RE_TEMPERATURE`` is ``(-?)(\\d+(\\.\\d+)?)(°C|℃|度|摄氏度)``, so the groups are:
    1 = optional minus sign, 2 = the number, 3 = the optional fractional part
    nested inside the number, 4 = the unit.

    Args:
        match (re.Match): a match produced by ``RE_TEMPERATURE``
    Returns:
        str: "零下" when a minus sign was matched, then the number rendered by
        ``num2str``, then "摄氏度" if that exact unit was matched and "度"
        otherwise.
    """
    sign = match.group(1)
    temperature = match.group(2)
    # The unit is group 4. Group 3 is the fractional part (or None), so
    # reading it here meant "摄氏度" was never recognised.
    unit = match.group(4)
    sign: str = "零下" if sign else ""
    temperature: str = num2str(temperature)
    unit: str = "摄氏度" if unit == "摄氏度" else "度"
    result = f"{sign}{temperature}{unit}"
    return result
# Batch entry point: runs maximum_path_each once per index of the first
# dimension of `paths`, with the loop written as a cython.parallel prange.
#   paths  - int buffer, indexed [i] to get the 2-D buffer handed to maximum_path_each
#   values - float buffer, indexed the same way as `paths`
#   t_ys   - per-item value passed as maximum_path_each's t_y
#   t_xs   - per-item value passed as maximum_path_each's t_x
# Bounds checking and negative-index wraparound are switched off by the
# decorators below.
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
  # Number of items = size of the leading dimension.
  cdef int b = paths.shape[0]
  cdef int i
  for i in prange(b, nogil=True):
    maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
def full2half_width(ustr):
    """Convert full-width characters in ``ustr`` to their half-width forms.

    U+3000 (full-width space) becomes an ASCII space, and code points in
    U+FF01..U+FF5E are shifted down by 0xFEE0. Everything else is kept.
    """
    def _narrow(ch):
        code = ord(ch)
        if code == 0x3000:  # full-width space -> half-width space
            code = 32
        elif 0xFF01 <= code <= 0xFF5E:
            code -= 0xfee0
        return chr(code)

    return ''.join(_narrow(ch) for ch in ustr)


def half2full_width(ustr):
    """Convert half-width characters in ``ustr`` to their full-width forms.

    An ASCII space becomes U+3000, and code points in 0x21..0x7E are shifted
    up by 0xFEE0. Everything else is kept.
    """
    def _widen(ch):
        code = ord(ch)
        if code == 32:  # half-width space -> full-width space
            code = 0x3000
        elif 0x21 <= code <= 0x7E:
            code += 0xfee0
        return chr(code)  # back to a one-character string

    return ''.join(_widen(ch) for ch in ustr)
def feature_loss(fmap_r, fmap_g):
    """Sum of mean absolute differences between paired feature maps, times 2.

    ``fmap_r`` and ``fmap_g`` are iterated in lockstep as nested sequences of
    tensors; the ``fmap_r`` side is detached before the difference is taken.
    """
    loss = 0
    for real_maps, gen_maps in zip(fmap_r, fmap_g):
        for real, gen in zip(real_maps, gen_maps):
            diff = real.float().detach() - gen.float()
            loss += diff.abs().mean()

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares style discriminator loss.

    Per output pair: mean((1 - real)^2) + mean(generated^2).
    Returns the summed loss plus the per-pair real and generated terms as
    Python floats.
    """
    total = 0
    r_losses, g_losses = [], []
    for real_out, gen_out in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - real_out.float()) ** 2)
        g_loss = torch.mean(gen_out.float() ** 2)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())
        total += (r_loss + g_loss)

    return total, r_losses, g_losses


def generator_loss(disc_outputs):
    """Least-squares style generator loss.

    Per output: mean((1 - output)^2). Returns the summed loss and the list of
    per-output terms (as tensors).
    """
    gen_losses = [torch.mean((1 - out.float()) ** 2) for out in disc_outputs]
    loss = 0
    for term in gen_losses:
        loss += term

    return loss, gen_losses


def kl_loss(z_p, logs, m_p, logs_p, z_mask):
    """
    Masked average of
    ``logs_p - logs - 0.5 + 0.5 * (z_p - m_p)^2 * exp(-2 * logs_p)``.

    z_p, logs: [b, h, t_t]
    m_p, logs_p: [b, h, t_t]
    """
    z_p, logs, m_p, logs_p, z_mask = (
        t.float() for t in (z_p, logs, m_p, logs_p, z_mask))

    kl = logs_p - logs - 0.5
    kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
    # Sum over the masked positions, then divide by the mask total.
    return torch.sum(kl * z_mask) / torch.sum(z_mask)


class ReverseLayerF(Function):
    """Autograd function: identity view forward, negated and scaled gradient backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scale factor for the backward pass.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = torch.neg(grad_output) * ctx.alpha
        # No gradient is returned for ``alpha``.
        return reversed_grad, None
def replace_phone(match) -> str:
    """Verbalize a matched phone number through ``phone2str`` with ``mobile=False``.

    With ``mobile=False``, ``phone2str`` splits the text on '-' and verbalizes
    each part with ``verbalize_digit(..., alt_one=True)``.

    Args:
        match (re.Match): the whole matched text (group 0) is verbalized
    Returns:
        str
    """
    return phone2str(match.group(0), mobile=False)


def replace_mobile(match) -> str:
    """Verbalize a matched number through ``phone2str`` with its default ``mobile`` argument.

    NOTE(review): the default of ``mobile`` is not visible here; presumably
    it is True (the mobile-number branch) - confirm against ``phone2str``.

    Args:
        match (re.Match): the whole matched text (group 0) is verbalized
    Returns:
        str
    """
    return phone2str(match.group(0))
# Full-width <-> half-width conversion tables.
# A full-width character is its ASCII counterpart shifted by 65248 (0xFEE0).

# ASCII letters: full-width -> half-width (52 entries)
F2H_ASCII_LETTERS = {
    chr(ord(char) + 65248): char
    for char in string.ascii_letters
}

# ASCII letters: half-width -> full-width
H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()}

# Digits: full-width -> half-width (10 entries)
F2H_DIGITS = {chr(ord(char) + 65248): char for char in string.digits}
# Digits: half-width -> full-width
H2F_DIGITS = {value: key for key, value in F2H_DIGITS.items()}

# Punctuation: full-width -> half-width (32 entries)
F2H_PUNCTUATIONS = {chr(ord(char) + 65248): char for char in string.punctuation}
# Punctuation: half-width -> full-width
H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()}

# Space (1 entry): U+3000 is the full-width space.
F2H_SPACE = {'\u3000': ' '}
H2F_SPACE = {' ': '\u3000'}

# Matches runs of characters that are NOT "Chinese characters with pinyin";
# usable for extracting non-standard words (NSW).
# NOTE(review): two range starts in the pattern (2A703, 2F80A) differ from the
# block starts named in the trailing comments (2A700, 2F800) - confirm which
# is intended before changing either.
if SUPPORT_UCS4:
    RE_NSW = re.compile(r'(?:[^'
                        r'\u3007'  # 〇 (ideographic number zero)
                        r'\u3400-\u4dbf'  # CJK Extension A: [3400-4DBF]
                        r'\u4e00-\u9fff'  # CJK Unified (basic): [4E00-9FFF]
                        r'\uf900-\ufaff'  # CJK Compatibility: [F900-FAFF]
                        r'\U00020000-\U0002A6DF'  # CJK Extension B: [20000-2A6DF]
                        r'\U0002A703-\U0002B73F'  # CJK Extension C: [2A700-2B73F]
                        r'\U0002B740-\U0002B81D'  # CJK Extension D: [2B740-2B81D]
                        r'\U0002F80A-\U0002FA1F'  # CJK Compatibility Supplement: [2F800-2FA1F]
                        r'])+')
else:
    # Fallback without the ranges above U+FFFF.
    RE_NSW = re.compile(  # pragma: no cover
        r'(?:[^'
        r'\u3007'  # 〇 (ideographic number zero)
        r'\u3400-\u4dbf'  # CJK Extension A: [3400-4DBF]
        r'\u4e00-\u9fff'  # CJK Unified (basic): [4E00-9FFF]
        r'\uf900-\ufaff'  # CJK Compatibility: [F900-FAFF]
        r'])+')
17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | #mac 132 | *DS_Store 133 | -------------------------------------------------------------------------------- /configs/config_cjke.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | log_interval: 50 # step unit 3 | eval_interval: 400 # step unit 4 | save_interval: 200 # epoch unit: 50 for baseline / 500 for fine-tuning 5 | seed: 1234 6 | epochs: 7000 7 | learning_rate: 2e-4 8 | betas: [0.8, 0.99] 9 | eps: 1e-9 10 | batch_size: 16 11 | fp16_run: True #False 12 | lr_decay: 0.999875 13 | segment_size: 8192 14 | c_mel: 45 15 | c_kl: 1.0 16 | c_vq: 1. 17 | c_commit: 0.2 18 | c_yin: 45. 
19 | log_path: "logs" 20 | n_sample: 3 21 | alpha: 200 22 | keep_ckpts: 3 23 | 24 | data: 25 | data_path: "dataset" 26 | training_files: "filelists/train.list.cleaned" 27 | validation_files: "filelists/val.list.cleaned" 28 | languages: "en_US" 29 | text_cleaners: ["english_cleaners"] 30 | sampling_rate: 22050 31 | filter_length: 1024 32 | hop_length: 256 33 | win_length: 1024 34 | n_mel_channels: 80 35 | mel_fmin: 0.0 36 | mel_fmax: null 37 | add_blank: True 38 | speakers: ["nen", "p226", "p227", "p228", "p229", "p230", "p231", "p232", "p233", "p234", "p236", "p237", "p238", "p239", "p240", "p241", "p243", "p244", "p245", "p246", "p247", "p248", "p249", "p250", "p251", "p252", "p253", "p254", "p255", "p256", "p257", "p258", "p259", "p260", "p261", "p262", "p263", "p264", "p265", "p266", "p267", "p268", "p269", "p270", "p271", "p272", "p273", "p274", "p275", "p276", "p277", "p278", "p279", "p281", "p282", "p283", "p284", "p285", "p286", "p287", "p288", "p292", "p293", "p294", "p295", "p297", "p298", "p299", "p300", "p301", "p302", "p303", "p304", "p305", "p306", "p307", "p308", "p310", "p311", "p312", "p313", "p314", "p316", "p317", "p318", "p323", "p326", "p329", "p330", "p333", "p334", "p335", "p336", "p339", "p340", "p341", "p343", "p345", "p347", "p351", "p360", "p361", "p362", "p363", "p364", "p374", "p376", "s5"] 39 | persistent_workers: True 40 | midi_start: -5 41 | midi_end: 75 42 | midis: 80 43 | ying_window: 2048 44 | ying_hop: 256 45 | tau_max: 2048 46 | octave_range: 24 47 | 48 | model: 49 | inter_channels: 192 50 | hidden_channels: 192 51 | filter_channels: 768 52 | n_heads: 2 53 | n_layers: 6 54 | kernel_size: 3 55 | p_dropout: 0.1 56 | resblock: "1" 57 | resblock_kernel_sizes: [3,7,11] 58 | resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] 59 | upsample_rates: [8,8,2,2] 60 | upsample_initial_channel: 512 61 | upsample_kernel_sizes: [16,16,4,4] 62 | n_layers_q: 3 63 | use_spectral_norm: False 64 | gin_channels: 256 65 | codebook_size: 320 66 | 
yin_channels: 80 67 | yin_start: 15 # scope start bin in nansy = 1.5/8 68 | yin_scope: 50 # scope ratio in nansy = 5/8 69 | yin_shift_range: 15 # same as default start index of yingram 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PITS 2 | **PITS: Variational Pitch Inference without Fundamental Frequency for End-to-End Pitch-controllable TTS** 3 | 4 | **Abstract**: Previous pitch-controllable text-to-speech (TTS) models rely on directly modeling fundamental frequency, leading to low variance in synthesized speech. To address this issue, we propose PITS, an end-to-end pitch-controllable TTS model that utilizes variational inference to model pitch. Based on VITS, PITS incorporates the Yingram encoder, the Yingram decoder, and adversarial training of pitch-shifted synthesis to achieve pitch-controllability. Experiments demonstrate that PITS generates high-quality speech that is indistinguishable from ground truth speech and has high pitch-controllability without quality degradation. Code and audio samples will be available at https://github.com/anonymous-pits/pits. 5 | 6 | **Training code is uploaded.** 7 | 8 | **Demo and Checkpoint are uploaded at** [Hugging Face Space](https://huggingface.co/spaces/anonymous-pits/pits)🤗 9 | 10 | Audio samples are uploaded at [github.io](https://anonymous-pits.github.io/pits/). 11 | 12 | For the pitch-shifted Inference, we unify to use the notation in scope-shift, s, instead of pitch-shift. 13 | 14 | Preprint version contains some errors! Please wait for the update! 15 | 16 | ![overall](asset/overall.png) 17 | 18 | README IS WIP... 
19 | ## Preprocess 20 | Prepare the filelists `train.list` and `val.list`. 21 | 22 | Supported languages: Chinese [ZH], Japanese [JA], English [EN], Korean [KO]. 23 | ``` 24 | python preprocess.py 25 | ``` 26 | 27 | ## Config 28 | + You need to modify the **speakers** list in configs/config_cjke.yaml. 29 | + You can also modify **keep_ckpts** and **log_path**. 30 | + **data_path** is the root path of your data. 31 | 32 | ## Training 33 | Download the pretrained checkpoint: 34 | ``` 35 | wget https://huggingface.co/spaces/anonymous-pits/pits/resolve/main/logs/pits_vctk_AD_3000.pth 36 | ``` 37 | Fine-tune the pretrained checkpoint: 38 | ``` 39 | python train.py -c configs/config_cjke.yaml -m cjke -t pits_vctk_AD_3000.pth 40 | ``` 41 | Train from scratch: 42 | ``` 43 | python train.py -c configs/config_cjke.yaml -m cjke 44 | ``` 45 | Resume from a previous training checkpoint: 46 | ``` 47 | python train.py -c configs/config_cjke.yaml -m cjke -r logs/cjke/cjke_3000.pth 48 | ``` 49 | 50 | ## References 51 | - Official VITS Implementation: https://github.com/jaywalnut310/vits 52 | - NANSY Implementation from dhchoi99: https://github.com/dhchoi99/NANSY 53 | - Official Avocodo Implementation: https://github.com/ncsoft/avocodo 54 | - Official PhaseAug Implementation: https://github.com/mindslab-ai/phaseaug 55 | - Tacotron Implementation from keithito: https://github.com/keithito/tacotron 56 | - CSTR VCTK Corpus (version 0.92): https://datashare.ed.ac.uk/handle/10283/3443 57 | - G2P for demo, g2p\_en from Kyubyong: https://github.com/Kyubyong/g2p 58 | -------------------------------------------------------------------------------- /text/frontend/normalizer/numbers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def _expand_dollars(m):
    """Turn a matched dollar amount into "N dollar(s), M cent(s)" wording.

    Args:
        m: regex match whose group 1 holds the amount text (the part after
            the ``$`` sign), e.g. ``"3.07"``.

    Returns:
        str: the amount with explicit units. Zero-valued parts are omitted;
        an amount that is zero overall becomes ``"zero dollars"``.
    """
    amount = m.group(1)
    pieces = amount.split('.')
    if len(pieces) > 2:
        # More than one decimal point: not a plain amount, leave digits as-is.
        return amount + ' dollars'

    whole = int(pieces[0]) if pieces[0] else 0
    fraction = int(pieces[1]) if len(pieces) > 1 and pieces[1] else 0

    spoken = []
    if whole:
        spoken.append('%s %s' % (whole, 'dollar' if whole == 1 else 'dollars'))
    if fraction:
        spoken.append('%s %s' % (fraction, 'cent' if fraction == 1 else 'cents'))
    if not spoken:
        return 'zero dollars'
    return ', '.join(spoken)
class Vocab(object):
    """Two-way lookup table between symbols and integer indices.

    Non-empty special symbols (padding, unknown, start, end) are indexed
    first, in that order; ordinary symbols follow in first-seen order.

    Args:
        symbols (Iterable[str]): Common symbols.
        padding_symbol (str, optional): Symbol for pad. Defaults to "".
        unk_symbol (str, optional): Symbol for unknown. Defaults to "".
        start_symbol (str, optional): Symbol for start. Defaults to "".
        end_symbol (str, optional): Symbol for end. Defaults to "".
    """

    def __init__(self,
                 symbols: Iterable[str],
                 padding_symbol="",
                 unk_symbol="",
                 start_symbol="",
                 end_symbol=""):
        self.padding_symbol = padding_symbol
        self.unk_symbol = unk_symbol
        self.start_symbol = start_symbol
        self.end_symbol = end_symbol

        # Empty strings mean "this special symbol is not used".
        self.special_symbols = OrderedDict()
        for special in (padding_symbol, unk_symbol, start_symbol, end_symbol):
            if special:
                self.special_symbols[special] = len(self.special_symbols)

        # symbol -> index; specials keep the lowest indices.
        self.stoi = OrderedDict(self.special_symbols)
        for symbol in symbols:
            if symbol not in self.stoi:
                self.stoi[symbol] = len(self.stoi)

        # index -> symbol
        self.itos = {index: symbol for symbol, index in self.stoi.items()}

    def _index_or_missing(self, symbol):
        """Return the index of ``symbol``, or -1 when it is not in the vocab."""
        return self.stoi.get(symbol, -1)

    def __len__(self):
        return len(self.stoi)

    @property
    def num_specials(self):
        """The number of special symbols."""
        return len(self.special_symbols)

    # special tokens
    @property
    def padding_index(self):
        """The index of the padding symbol (-1 if absent)."""
        return self._index_or_missing(self.padding_symbol)

    @property
    def unk_index(self):
        """The index of the unknown symbol (-1 if absent)."""
        return self._index_or_missing(self.unk_symbol)

    @property
    def start_index(self):
        """The index of the start symbol (-1 if absent)."""
        return self._index_or_missing(self.start_symbol)

    @property
    def end_index(self):
        """The index of the end symbol (-1 if absent)."""
        return self._index_or_missing(self.end_symbol)

    def __repr__(self):
        return "Vocab(size: {},\nstoi:\n{})".format(len(self), self.stoi)

    def __str__(self):
        return self.__repr__()

    def lookup(self, symbol):
        """Return the index of ``symbol``; raises KeyError if it is unknown."""
        return self.stoi[symbol]

    def reverse(self, index):
        """Return the symbol stored at ``index``; raises KeyError if unused."""
        return self.itos[index]

    def add_symbol(self, symbol):
        """Append ``symbol`` to the vocab; a symbol already present is ignored."""
        if symbol not in self.stoi:
            new_index = len(self.stoi)
            self.stoi[symbol] = new_index
            self.itos[new_index] = symbol

    def add_symbols(self, symbols):
        """Append every symbol of ``symbols`` that is not yet in the vocab."""
        for symbol in symbols:
            self.add_symbol(symbol)
def _verbalize_clock(hour: str, minute: str, second) -> str:
    """Verbalize a single clock reading (hour, minute, optional second)."""
    result = f"{num2str(hour)}点"
    if minute.lstrip('0'):
        if int(minute) == 30:
            # ":30" is read as "half past".
            result += "半"
        else:
            result += f"{_time_num2str(minute)}分"
    # `second` is None when the optional ":SS" part did not match.
    if second and second.lstrip('0'):
        result += f"{_time_num2str(second)}秒"
    return result


def replace_time(match) -> str:
    """Verbalize a matched clock time or time range.

    Works for both time patterns: a single time (hour, minute and an
    optional second in groups 1, 2 and 4) and a range, whose second time
    is held in groups 6, 7 and 9.

    Args:
        match (re.Match): match produced by RE_TIME or RE_TIME_RANGE.
    Returns:
        str: the spoken form; for a range the two times are joined by "至".
    """
    # The single-time pattern has 4 capture groups, the range pattern 9.
    is_range = len(match.groups()) > 5

    result = _verbalize_clock(match.group(1), match.group(2), match.group(4))
    if is_range:
        # The end time is verbalized from its own minute field; previously
        # the start minute decided whether the end time was read as "半".
        result += "至"
        result += _verbalize_clock(match.group(6), match.group(7),
                                   match.group(9))
    return result
def replace_date(match) -> str:
    """Verbalize a matched "<year>年[<month>月][<day>日|号]" date.

    Args:
        match (re.Match): match whose group 1 is the year digits, group 3
            the month, group 5 the day and group 9 the day suffix.
    Returns:
        str: year read digit by digit, month and day read as cardinals;
        parts that did not match are left out.
    """
    year, month, day = match.group(1), match.group(3), match.group(5)
    pieces = []
    if year:
        pieces.append(f"{verbalize_digit(year)}年")
    if month:
        pieces.append(f"{verbalize_cardinal(month)}月")
    if day:
        # Keep whichever day suffix the text used.
        pieces.append(f"{verbalize_cardinal(day)}{match.group(9)}")
    return "".join(pieces)
# Load the lexicon: one "<key>\t<space-separated phones>" entry per line.
# NOTE(review): keys are presumably pinyin syllables, since the reversed
# table is used to turn phone sequences back into pinyin - confirm.
# The file is opened in a `with` block so the handle is closed after reading.
with open("text/zh_dict.dict") as _zh_dict_file:
    zh_dict = [i.strip() for i in _zh_dict_file]
zh_dict = {i.split("\t")[0]: i.split("\t")[1] for i in zh_dict}

# reversed_zh_dict: phone string -> key; all_zh_phones: every single phone
# that occurs in any entry.
reversed_zh_dict = {}
all_zh_phones = set()
for k, v in zh_dict.items():
    reversed_zh_dict[v] = k
    all_zh_phones.update(v.split(" "))
def pu_symbol_replace(data):
    """Map punctuation in ``data`` onto the reduced punctuation set.

    The pairs are applied one after another in the order listed, so a later
    pair also sees the output of earlier ones.

    Args:
        data (str): input text.
    Returns:
        str: text with every listed source mark replaced by its target.
    """
    punctuation_pairs = (
        ('!', '!'),
        ('?', '?'),
        ("…", "…"),
        (",", ","),
        ("。", "."),
        ('、', ","),
        ("...", "…"),
    )
    for source, target in punctuation_pairs:
        data = data.replace(source, target)
    return data
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    """Compute the linear magnitude spectrogram of a batch of waveforms.

    Args:
        y: waveform tensor; it is unsqueezed to three dims for reflect
            padding, so a (batch, samples) layout is expected.
        n_fft: FFT size.
        sampling_rate: unused here; kept for signature compatibility.
        hop_size: hop length between frames, in samples.
        win_size: Hann window length, in samples.
        center: forwarded to ``torch.stft``.

    Returns:
        Tensor of magnitudes ``sqrt(re^2 + im^2 + 1e-6)`` with the
        one-sided frequency axis before the frame axis.
    """
    # Values outside [-1, 1] are only reported, not clipped.
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global hann_window
    # Cache one window per (window size, dtype, device) combination.
    dtype_device = str(y.dtype) + '_' + str(y.device)
    wnsize_dtype_device = str(win_size) + '_' + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

    # Manual reflect padding replaces stft's own centering.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
    y = y.squeeze(1)
    with autocast(enabled=False):
        # The STFT is always run in float32, even under mixed precision.
        y = y.float()
        spec = torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[wnsize_dtype_device],
            center=center,
            pad_mode='reflect',
            normalized=False,
            onesided=True,
            # Recent PyTorch requires this argument for real input; the
            # result is converted back to the (..., 2) real layout below.
            return_complex=True
        )
        spec = torch.view_as_real(spec)

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec
mel_basis, hann_window 90 | dtype_device = str(y.dtype) + '_' + str(y.device) 91 | fmax_dtype_device = str(fmax) + '_' + dtype_device 92 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 93 | if fmax_dtype_device not in mel_basis: 94 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 95 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) 96 | if wnsize_dtype_device not in hann_window: 97 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 98 | 99 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 100 | y = y.squeeze(1) 101 | with autocast(enabled=False): 102 | y=y.float() 103 | spec = torch.stft( 104 | y, 105 | n_fft, 106 | hop_length=hop_size, 107 | win_length=win_size, 108 | window=hann_window[wnsize_dtype_device], 109 | center=center, 110 | pad_mode='reflect', 111 | normalized=False, 112 | onesided=True 113 | ) 114 | 115 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 116 | 117 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 118 | spec = spectral_normalize_torch(spec) 119 | 120 | return spec 121 | -------------------------------------------------------------------------------- /text/frontend/zh_normalization/text_normlization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import re 15 | from typing import List 16 | 17 | from .char_convert import tranditional_to_simplified 18 | from .chronology import RE_DATE 19 | from .chronology import RE_DATE2 20 | from .chronology import RE_TIME 21 | from .chronology import RE_TIME_RANGE 22 | from .chronology import replace_date 23 | from .chronology import replace_date2 24 | from .chronology import replace_time 25 | from .constants import F2H_ASCII_LETTERS 26 | from .constants import F2H_DIGITS 27 | from .constants import F2H_SPACE 28 | from .num import RE_DECIMAL_NUM 29 | from .num import RE_DEFAULT_NUM 30 | from .num import RE_FRAC 31 | from .num import RE_INTEGER 32 | from .num import RE_NUMBER 33 | from .num import RE_PERCENTAGE 34 | from .num import RE_POSITIVE_QUANTIFIERS 35 | from .num import RE_RANGE 36 | from .num import replace_default_num 37 | from .num import replace_frac 38 | from .num import replace_negative_num 39 | from .num import replace_number 40 | from .num import replace_percentage 41 | from .num import replace_positive_quantifier 42 | from .num import replace_range 43 | from .phonecode import RE_MOBILE_PHONE 44 | from .phonecode import RE_NATIONAL_UNIFORM_NUMBER 45 | from .phonecode import RE_TELEPHONE 46 | from .phonecode import replace_mobile 47 | from .phonecode import replace_phone 48 | from .quantifier import RE_TEMPERATURE 49 | from .quantifier import replace_temperature 50 | 51 | 52 | class TextNormalizer(): 53 | def __init__(self): 54 | self.SENTENCE_SPLITOR = re.compile(r'([:、,;。?!,;?!….][”’]?)') 55 | 56 | def _split(self, text: str, lang="zh") -> List[str]: 57 | """Split long text into sentences with sentence-splitting punctuations. 58 | Args: 59 | text (str): The input text. 60 | Returns: 61 | List[str]: Sentences. 
62 | """ 63 | # Only for pure Chinese here 64 | if lang == "zh": 65 | text = text.replace(" ", "") 66 | # 过滤掉特殊字符 67 | text = re.sub(r'[《》【】<=>{}()()&@“”^_|\\]', '', text) 68 | text = self.SENTENCE_SPLITOR.sub(r'\1\n', text) 69 | text = text.strip() 70 | sentences = [sentence.strip() for sentence in re.split(r'\n+', text)] 71 | return sentences 72 | 73 | def _post_replace(self, sentence: str) -> str: 74 | sentence = sentence.replace('/', '每') 75 | sentence = sentence.replace('~', '至') 76 | 77 | return sentence 78 | 79 | def normalize_sentence(self, sentence: str) -> str: 80 | # basic character conversions 81 | sentence = tranditional_to_simplified(sentence) 82 | sentence = sentence.translate(F2H_ASCII_LETTERS).translate( 83 | F2H_DIGITS).translate(F2H_SPACE) 84 | 85 | # number related NSW verbalization 86 | sentence = RE_DATE.sub(replace_date, sentence) 87 | sentence = RE_DATE2.sub(replace_date2, sentence) 88 | 89 | # range first 90 | sentence = RE_TIME_RANGE.sub(replace_time, sentence) 91 | sentence = RE_TIME.sub(replace_time, sentence) 92 | 93 | sentence = RE_TEMPERATURE.sub(replace_temperature, sentence) 94 | sentence = RE_FRAC.sub(replace_frac, sentence) 95 | sentence = RE_PERCENTAGE.sub(replace_percentage, sentence) 96 | sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence) 97 | 98 | sentence = RE_TELEPHONE.sub(replace_phone, sentence) 99 | sentence = RE_NATIONAL_UNIFORM_NUMBER.sub(replace_phone, sentence) 100 | 101 | sentence = RE_RANGE.sub(replace_range, sentence) 102 | sentence = RE_INTEGER.sub(replace_negative_num, sentence) 103 | sentence = RE_DECIMAL_NUM.sub(replace_number, sentence) 104 | sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier, 105 | sentence) 106 | sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence) 107 | sentence = RE_NUMBER.sub(replace_number, sentence) 108 | sentence = self._post_replace(sentence) 109 | 110 | return sentence 111 | 112 | def normalize(self, text: str) -> List[str]: 113 | sentences = 
class Pitch(torch.nn.Module):
    """Yingram extractor built from torch operations.

    Args:
        sr: sampling rate of the input audio.
        w_step: hop between analysis frames, in samples.
        W: analysis window length, in samples.
        tau_max: largest lag evaluated by the difference function.
        midi_start: first midi number of the yingram (inclusive).
        midi_end: last midi number of the yingram (exclusive).
        octave_range: number of midi steps per octave.
        yin_scope: optional number of yingram bins kept by ``crop_scope``.
            When omitted, ``yin_scope`` must be assigned on the instance
            before ``crop_scope`` is used.
    """

    def __init__(
            self,
            sr=22050,
            w_step=256,
            W=2048,
            tau_max=2048,
            midi_start=5,
            midi_end=85,
            octave_range=12,
            yin_scope=None):
        super(Pitch, self).__init__()
        self.sr = sr
        self.w_step = w_step
        self.W = W
        self.tau_max = tau_max
        # Slices the waveform into overlapping frames of W samples.
        self.unfold = torch.nn.Unfold((1, self.W),
                                      1,
                                      0,
                                      stride=(1, self.w_step))
        midis = list(range(midi_start, midi_end))
        self.len_midis = len(midis)
        # Fractional lag of every midi bin, plus its integer neighbours used
        # for linear interpolation.
        c_ms = torch.tensor([self.midi_to_lag(m, octave_range) for m in midis])
        self.register_buffer('c_ms', c_ms)
        self.register_buffer('c_ms_ceil', torch.ceil(self.c_ms).long())
        self.register_buffer('c_ms_floor', torch.floor(self.c_ms).long())
        # Only set when given, so an externally assigned value is not clobbered.
        if yin_scope is not None:
            self.yin_scope = yin_scope

    def midi_to_lag(self, m: int, octave_range: float = 12):
        """converts midi-to-lag, eq. (4)

        Args:
            m: midi
            octave_range: number of midi steps per octave

        Returns:
            lag: time lag(tau, c(m)) in samples at ``self.sr``, eq. (4)

        """
        f = 440 * math.pow(2, (m - 69) / octave_range)
        lag = self.sr / f
        return lag

    def yingram_from_cmndf(self, cmndfs: torch.Tensor) -> torch.Tensor:
        """ yingram calculator from cMNDFs(cumulative Mean Normalized Difference Functions)

        Args:
            cmndfs: torch.Tensor
                calculated cumulative mean normalized difference function,
                indexed by lag along the second dimension

        Returns:
            y:
                cmndf linearly interpolated at the fractional lag of every
                midi bin

        """
        floor_vals = cmndfs[:, self.c_ms_floor]
        ceil_vals = cmndfs[:, self.c_ms_ceil]
        frac = (self.c_ms - self.c_ms_floor).unsqueeze(0)
        # ceil - floor is 1 for a fractional lag, so no division is needed;
        # for an exactly integer lag (ceil == floor) this yields the sample
        # itself instead of the NaN produced by dividing by zero.
        y = floor_vals + (ceil_vals - floor_vals) * frac
        return y

    def yingram(self, x: torch.Tensor):
        """calculates yingram from raw audio (multi segment)

        Args:
            x: raw audio, torch.Tensor of shape (B, T)

        Returns:
            yingram: torch.Tensor of shape (B, number of midi bins, frames)

        """
        B, T = x.shape

        frames = self.unfold(x.view(B, 1, 1, T))
        frames = frames.permute(0, 2,
                                1).contiguous().view(-1,
                                                     self.W)  #[B* frames, W]
        # If not using gpu, or torch not compatible, implemented numpy batch function is still fine
        dfs = differenceFunctionTorch(frames, frames.shape[-1], self.tau_max)
        cmndfs = cumulativeMeanNormalizedDifferenceFunctionTorch(
            dfs, self.tau_max)
        yingram = self.yingram_from_cmndf(cmndfs)  #[B*frames,F]
        yingram = yingram.view(B, -1, self.len_midis).permute(0, 2,
                                                              1)  # [B,F,T]
        return yingram

    def crop_scope(self, x, yin_start,
                   scope_shift):  # x: tensor [B,C,T] #scope_shift: tensor [B]
        """Crop ``yin_scope`` bins per batch item, starting at a shifted bin.

        Args:
            x: tensor [B, C, T]
            yin_start: first bin of the unshifted scope
            scope_shift: tensor [B] of per-item bin offsets

        Returns:
            tensor [B, yin_scope, T]

        Raises:
            AttributeError: if ``yin_scope`` was never configured.
        """
        yin_scope = getattr(self, 'yin_scope', None)
        if yin_scope is None:
            raise AttributeError(
                "Pitch.yin_scope is not set; pass yin_scope to Pitch(...) "
                "or assign it before calling crop_scope")
        return torch.stack([
            x[i, yin_start + scope_shift[i]:yin_start + yin_scope +
              scope_shift[i], :] for i in range(x.shape[0])
        ],
                           dim=0)
# ---- /pqmf.py ----
# -*- coding: utf-8 -*-

# Copyright 2020 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)

"""Pseudo QMF modules.

Copied from https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/layers/pqmf.py
"""

import numpy as np
import torch
import torch.nn.functional as F

# `scipy.signal.kaiser` was only an alias and is gone from current SciPy;
# the window function lives in `scipy.signal.windows`.
from scipy.signal.windows import kaiser


def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
    """Design the prototype low-pass filter for PQMF.

    This method is based on `A Kaiser window approach for the design of prototype
    filters of cosine modulated filterbanks`_.

    Args:
        taps (int): The number of filter taps (must be even).
        cutoff_ratio (float): Cut-off frequency ratio, strictly between 0 and 1.
        beta (float): Beta coefficient for kaiser window.

    Returns:
        ndarray: Impulse response of prototype filter (taps + 1,).

    Raises:
        AssertionError: If `taps` is odd or `cutoff_ratio` is outside (0, 1).

    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
        https://ieeexplore.ieee.org/abstract/document/681427
    """
    # check the arguments are valid
    assert taps % 2 == 0, "The number of taps must be even number."
    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."

    # ideal low-pass (sinc) response centred on the middle tap
    omega_c = np.pi * cutoff_ratio
    n = np.arange(taps + 1) - 0.5 * taps
    with np.errstate(invalid="ignore"):
        h_i = np.sin(omega_c * n) / (np.pi * n)
    h_i[taps // 2] = np.cos(0) * cutoff_ratio  # fix nan due to indeterminate form (0 / 0)

    # apply kaiser window
    w = kaiser(taps + 1, beta)
    return h_i * w


class PQMF(torch.nn.Module):
    """PQMF module.

    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.

    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
        https://ieeexplore.ieee.org/document/258122
    """

    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
        """Initialize PQMF module.

        The cutoff_ratio and beta parameters are optimized for #subbands = 4.
        See discussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.

        Args:
            subbands (int): The number of subbands.
            taps (int): The number of filter taps.
            cutoff_ratio (float): Cut-off frequency ratio.
            beta (float): Beta coefficient for kaiser window.
        """
        super(PQMF, self).__init__()

        # build analysis & synthesis filter coefficients by cosine-modulating the prototype
        h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
        h_analysis = np.zeros((subbands, len(h_proto)))
        h_synthesis = np.zeros((subbands, len(h_proto)))
        n = np.arange(taps + 1) - (taps / 2)
        for k in range(subbands):
            modulation = (2 * k + 1) * (np.pi / (2 * subbands)) * n
            phase = (-1) ** k * np.pi / 4
            # analysis and synthesis filters differ only in the sign of the phase term
            h_analysis[k] = 2 * h_proto * np.cos(modulation + phase)
            h_synthesis[k] = 2 * h_proto * np.cos(modulation - phase)

        # convert to float32 tensors shaped as conv1d weights
        analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
        synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)

        # register coefficients as buffers
        self.register_buffer("analysis_filter", analysis_filter)
        self.register_buffer("synthesis_filter", synthesis_filter)

        # filter for downsampling & upsampling: picks the first sample of every stride
        updown_filter = torch.zeros((subbands, subbands, subbands)).float()
        for k in range(subbands):
            updown_filter[k, k, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)
        self.subbands = subbands

        # keep padding info
        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def analysis(self, x):
        """Analysis with PQMF.

        Args:
            x (Tensor): Input tensor (B, 1, T).

        Returns:
            Tensor: Output tensor (B, subbands, T // subbands).
        """
        x = F.conv1d(self.pad_fn(x), self.analysis_filter)
        return F.conv1d(x, self.updown_filter, stride=self.subbands)

    def synthesis(self, x):
        """Synthesis with PQMF.

        Args:
            x (Tensor): Input tensor (B, subbands, T // subbands).

        Returns:
            Tensor: Output tensor (B, 1, T).
        """
        # NOTE(kan-bayashi): Power will be decreased so here multiply by # subbands.
        #   Not sure this is the correct way, it is better to check again.
        # TODO(kan-bayashi): Understand the reconstruction procedure
        x = F.conv_transpose1d(
            x, self.updown_filter * self.subbands, stride=self.subbands
        )
        return F.conv1d(self.pad_fn(x), self.synthesis_filter)
# ---- /text/japanese.py ----
import re
from unidecode import unidecode
import pyopenjtalk


# Matches a single "Japanese" character (kana, kanji, full-width or ASCII
# alphanumerics), i.e. anything that is not treated as a punctuation mark:
_japanese_characters = re.compile(
    r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')

# The complement of the class above: non-Japanese characters or punctuation marks.
_japanese_marks = re.compile(
    r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')

# (symbol, Japanese reading) pairs for marks:
_symbols_to_japanese = [(re.compile(symbol), reading) for symbol, reading in (
    ('%', 'パーセント'),
)]

# (romaji, ipa) pairs, applied in this order:
_romaji_to_ipa = [(re.compile(romaji), ipa) for romaji, ipa in (
    ('ts', 'ʦ'),
    ('u', 'ɯ'),
    ('j', 'ʥ'),
    ('y', 'j'),
    ('ni', 'n^i'),
    ('nj', 'n^'),
    ('hi', 'çi'),
    ('hj', 'ç'),
    ('f', 'ɸ'),
    ('I', 'i*'),
    ('U', 'ɯ*'),
    ('r', 'ɾ'),
)]

# (romaji, ipa2) pairs, applied in this order:
_romaji_to_ipa2 = [(re.compile(romaji), ipa) for romaji, ipa in (
    ('u', 'ɯ'),
    ('ʧ', 'tʃ'),
    ('j', 'dʑ'),
    ('y', 'j'),
    ('ni', 'n^i'),
    ('nj', 'n^'),
    ('hi', 'çi'),
    ('hj', 'ç'),
    ('f', 'ɸ'),
    ('I', 'i*'),
    ('U', 'ɯ*'),
    ('r', 'ɾ'),
)]

# (consonant context, sokuon realisation) pairs:
_real_sokuon = [(re.compile(context), realised) for context, realised in (
    (r'Q([↑↓]*[kg])', r'k#\1'),
    (r'Q([↑↓]*[tdjʧ])', r't#\1'),
    (r'Q([↑↓]*[sʃ])', r's\1'),
    (r'Q([↑↓]*[pb])', r'p#\1'),
)]
# (consonant context, hatsuon realisation) pairs:
_real_hatsuon = [(re.compile(context), realised) for context, realised in (
    (r'N([↑↓]*[pbm])', r'm\1'),
    (r'N([↑↓]*[ʧʥj])', r'n^\1'),
    (r'N([↑↓]*[tdn])', r'n\1'),
    (r'N([↑↓]*[kg])', r'ŋ\1'),
)]


def symbols_to_japanese(text):
    """Replace every symbol listed in `_symbols_to_japanese` by its reading."""
    for pattern, reading in _symbols_to_japanese:
        text = pattern.sub(reading, text)
    return text


def japanese_to_romaji_with_accent(text):
    '''Convert Japanese text to romaji annotated with pitch-accent marks.

    Runs of Japanese characters are phonemised from the pyopenjtalk full-context
    labels; '↓' / '↑' mark falling / rising accent and a space marks an accent
    phrase boundary. Marks between the runs are transliterated with unidecode.

    Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html
    '''
    silences = ('sil', 'pau')
    text = symbols_to_japanese(text)
    chunks = _japanese_marks.split(text)
    marks = _japanese_marks.findall(text)

    romaji = ''
    for index, chunk in enumerate(chunks):
        if _japanese_characters.match(chunk):
            if romaji:
                romaji += ' '
            labels = pyopenjtalk.extract_fullcontext(chunk)
            for pos, label in enumerate(labels):
                phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
                if phoneme in silences:
                    continue
                romaji += phoneme.replace('ch', 'ʧ').replace('sh',
                                                             'ʃ').replace('cl', 'Q')
                # accent fields of the current label
                a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
                a2 = int(re.search(r"\+(\d+)\+", label).group(1))
                a3 = int(re.search(r"\+(\d+)/", label).group(1))
                # position field of the following label (-1 when it is a silence)
                following = labels[pos + 1]
                if re.search(r'\-([^\+]*)\+', following).group(1) in silences:
                    a2_next = -1
                else:
                    a2_next = int(re.search(r"\+(\d+)\+", following).group(1))

                if a3 == 1 and a2_next == 1:
                    romaji += ' '  # accent phrase boundary
                elif a1 == 0 and a2_next == a2 + 1:
                    romaji += '↓'  # falling
                elif a2 == 1 and a2_next == 2:
                    romaji += '↑'  # rising
        if index < len(marks):
            romaji += unidecode(marks[index]).replace(' ', '')
    return romaji
def _stretch_long_vowel(match):
    """Turn a run of one repeated vowel into that vowel plus one 'ː' per repeat."""
    run = match.group(0)
    return run[0] + 'ː' * (len(run) - 1)


def get_real_sokuon(text):
    """Realise the sokuon placeholder 'Q' according to the following consonant."""
    for pattern, realised in _real_sokuon:
        text = pattern.sub(realised, text)
    return text


def get_real_hatsuon(text):
    """Realise the moraic nasal placeholder 'N' according to the following consonant."""
    for pattern, realised in _real_hatsuon:
        text = pattern.sub(realised, text)
    return text


def japanese_to_ipa(text):
    """Japanese text -> IPA (first symbol set), with long vowels written as 'ː'."""
    text = japanese_to_romaji_with_accent(text).replace('...', '…')
    text = re.sub(r'([aiueo])\1+', _stretch_long_vowel, text)
    text = get_real_hatsuon(get_real_sokuon(text))
    for pattern, ipa in _romaji_to_ipa:
        text = pattern.sub(ipa, text)
    return text


def japanese_to_ipa2(text):
    """Japanese text -> IPA (second symbol set); long vowels are left as repeats."""
    text = japanese_to_romaji_with_accent(text).replace('...', '…')
    text = get_real_hatsuon(get_real_sokuon(text))
    for pattern, ipa in _romaji_to_ipa2:
        text = pattern.sub(ipa, text)
    return text


def japanese_to_ipa3(text):
    """Japanese text -> narrower IPA derived from `japanese_to_ipa2`.

    Rewrites the ASCII placeholders with their diacritics, marks long vowels
    and aspirates word-initial voiceless stops / affricates.
    """
    text = japanese_to_ipa2(text)
    text = text.replace('n^', 'ȵ')
    text = text.replace('ʃ', 'ɕ')
    text = text.replace('*', '\u0325')
    text = text.replace('#', '\u031a')
    text = re.sub(r'([aiɯeo])\1+', _stretch_long_vowel, text)
    text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
    return text
# ---- /text/frontend/generate_lexicon.py ----
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Design principles: https://zhuanlan.zhihu.com/p/349600439
"""Generate lexicon and symbols for Mandarin Chinese phonology.

The lexicon is used for Montreal Force Aligner.
Note that syllables are used as words in this lexicon, since syllables rather
than words are used in transcriptions produced by `reorganize_baker.py`.
We make this choice to better leverage other software for Chinese text to
pinyin tools like pypinyin. This is the convention for G2P in Chinese.
"""
import re
from collections import OrderedDict

INITIALS = [
    'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'zh', 'ch', 'sh',
    'r', 'z', 'c', 's', 'j', 'q', 'x'
]

FINALS = [
    'a', 'ai', 'ao', 'an', 'ang', 'e', 'er', 'ei', 'en', 'eng', 'o', 'ou',
    'ong', 'ii', 'iii', 'i', 'ia', 'iao', 'ian', 'iang', 'ie', 'io', 'iou',
    'iong', 'in', 'ing', 'u', 'ua', 'uai', 'uan', 'uang', 'uei', 'uo', 'uen',
    'ueng', 'v', 've', 'van', 'vn'
]

SPECIALS = ['sil', 'sp']


def rule(C, V, R, T):
    """Spell one pinyin syllable from initial, final, erhua marker and tone.

    Orthographic rules of pinyin are applied (special cases for y, w, ui, un, iu).

    In this phoneme system 'ü' is always written 'v' and is turned into 'u' in
    the spelled syllable where pinyin does so. The three kinds of 'i' are kept
    apart as 'i', 'ii' (after z, c, s) and 'iii' (after zh, ch, sh, r). Erhua
    can be added to every final except those already ending in 'r'.

    Args:
        C: initial, or '' for a zero-initial syllable.
        V: final, written in the phoneme system above.
        R: erhua marker, '' or 'r'.
        T: tone suffix, '' or a digit string.

    Returns:
        The spelled syllable, or None when the combination is not a possible
        syllable and must be filtered out.
    """
    apical = V in ('ii', 'iii')
    # finals that begin with the i- or ü- glide
    front_glide = (not apical) and V[0] in ('i', 'v')

    # 'ii' only combines with z, c, s
    if V == 'ii' and C not in ('z', 'c', 's'):
        return None
    # 'iii' only combines with zh, ch, sh, r
    if V == 'iii' and C not in ('zh', 'ch', 'sh', 'r'):
        return None
    # i-/ü-finals never follow f, g, k, h, zh, ch, sh, r, z, c, s
    if front_glide and C in ('f', 'g', 'k', 'h', 'zh', 'ch', 'sh', 'r', 'z',
                             'c', 's'):
        return None
    # ü-finals: 'v' and 've' go with j, q, x, n, l (or no initial),
    # the others only with j, q, x (or no initial)
    if V.startswith('v'):
        allowed = ('j', 'q', 'x', 'n', 'l', '') if V in ('v', 've') else (
            'j', 'q', 'x', '')
        if C not in allowed:
            return None
    # j, q, x only combine with i-/ü-finals
    if C in ('j', 'q', 'x') and not front_glide:
        return None
    # b, p, m, f take no u-/ü-final except plain 'u', and never 'ong'
    if C in ('b', 'p', 'm', 'f') and ((V[0] in ('u', 'v') and V != 'u')
                                      or V == 'ong'):
        return None
    # ua, uai, uang never follow d, t, n, l, r, z, c, s
    if V in ('ua', 'uai', 'uang') and C in ('d', 't', 'n', 'l', 'r', 'z', 'c',
                                            's'):
        return None
    # 'sh' + 'ong' does not exist
    if V == 'ong' and C == 'sh':
        return None
    # 'o' does not follow these initials
    if V == 'o' and C in ('d', 't', 'n', 'g', 'k', 'h', 'zh', 'ch', 'sh', 'r',
                          'z', 'c', 's'):
        return None
    # 'ueng' only exists as the ad-hoc syllable 'weng'; elsewhere it is 'ong'
    if V == 'ueng' and C != '':
        return None
    # non-erhua 'er' only exists on its own
    if V == 'er' and C != '':
        return None

    if C == '':
        # zero initial: spell the leading glide as y / w / yu
        if V in ('i', 'in', 'ing'):
            C = 'y'
        elif V == 'u':
            C = 'w'
        elif V.startswith('i') and not apical:
            C, V = 'y', V[1:]
        elif V.startswith('u'):
            C, V = 'w', V[1:]
        elif V.startswith('v'):
            C, V = 'yu', V[1:]
    else:
        # after j, q, x the ü is written u
        if C in ('j', 'q', 'x') and V.startswith('v'):
            V = V.replace('v', 'u')
        # contracted spellings after an initial
        V = {'iou': 'iu', 'uei': 'ui', 'uen': 'un'}.get(V, V)

    syllable = C + V

    # a syllable already ending in 'r' cannot take erhua again
    if R == 'r' and syllable.endswith('r'):
        return None

    # 'ii' and 'iii' are both spelled 'i'
    syllable = re.sub(r'i+', 'i', syllable)
    return syllable + R + T


def generate_lexicon(with_tone=False, with_erhua=False):
    """Generate lexicon for Mandarin Chinese.

    Returns:
        OrderedDict mapping every spellable syllable to the string
        '<initial> <final><erhua><tone>' of its phonemes.
    """
    erhua_marks = ['', 'r'] if with_erhua else ['']
    tones = ['1', '2', '3', '4', '5'] if with_tone else ['']

    syllables = OrderedDict()
    for C in [''] + INITIALS:
        for V in FINALS:
            for R in erhua_marks:
                for T in tones:
                    spelled = rule(C, V, R, T)
                    if spelled:
                        syllables[spelled] = f'{C} {V}{R}{T}'
    return syllables
# ---- /text/cjenglish.py ----
""" from https://github.com/keithito/tacotron

Cleaners are transformations that run over the input text at both training and eval time.

Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
  1. "english_cleaners" for English text
  2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
     the Unidecode library (https://pypi.python.org/pypi/Unidecode)
  3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
     the symbols in symbols.py to match your data).
"""

import re
import inflect
from unidecode import unidecode
import eng_to_ipa as ipa

_inflect = inflect.engine()

# Patterns used by the number normaliser:
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')

# (regular expression, replacement) pairs for abbreviations; each abbreviation
# is matched case-insensitively at a word boundary and must end in a period:
_abbreviations = [
    (re.compile('\\b%s\\.' % abbreviation, re.IGNORECASE), expansion)
    for abbreviation, expansion in (
        ('mrs', 'misess'),
        ('mr', 'mister'),
        ('dr', 'doctor'),
        ('st', 'saint'),
        ('co', 'company'),
        ('jr', 'junior'),
        ('maj', 'major'),
        ('gen', 'general'),
        ('drs', 'doctors'),
        ('rev', 'reverend'),
        ('lt', 'lieutenant'),
        ('hon', 'honorable'),
        ('sgt', 'sergeant'),
        ('capt', 'captain'),
        ('esq', 'esquire'),
        ('ltd', 'limited'),
        ('col', 'colonel'),
        ('ft', 'fort'),
    )
]

# (ipa, lazy ipa) pairs:
_lazy_ipa = [(re.compile(source), target) for source, target in (
    ('r', 'ɹ'),
    ('æ', 'e'),
    ('ɑ', 'a'),
    ('ɔ', 'o'),
    ('ð', 'z'),
    ('θ', 's'),
    ('ɛ', 'e'),
    ('ɪ', 'i'),
    ('ʊ', 'u'),
    ('ʒ', 'ʥ'),
    ('ʤ', 'ʥ'),
    ('ˈ', '↓'),
)]

# (ipa, lazy ipa2) pairs:
_lazy_ipa2 = [(re.compile(source), target) for source, target in (
    ('r', 'ɹ'),
    ('ð', 'z'),
    ('θ', 's'),
    ('ʒ', 'ʑ'),
    ('ʤ', 'dʑ'),
    ('ˈ', '↓'),
)]

# (ipa, ipa2) pairs:
_ipa_to_ipa2 = [(re.compile(source), target) for source, target in (
    ('r', 'ɹ'),
    ('ʤ', 'dʒ'),
    ('ʧ', 'tʃ'),
)]


def expand_abbreviations(text):
    """Spell out every abbreviation listed in `_abbreviations`."""
    for pattern, expansion in _abbreviations:
        text = pattern.sub(expansion, text)
    return text
def collapse_whitespace(text):
    """Replace every run of whitespace by a single space."""
    return re.sub(r'\s+', ' ', text)


def _remove_commas(m):
    """Drop thousands separators from a matched number."""
    digits = m.group(1)
    return digits.replace(',', '')


def _expand_decimal_point(m):
    """Read the decimal point of a matched number as ' point '."""
    number = m.group(1)
    return number.replace('.', ' point ')


def _expand_dollars(m):
    """Spell a matched dollar amount as '<n> dollar(s), <n> cent(s)'."""
    amount = m.group(1)
    fields = amount.split('.')
    if len(fields) > 2:
        return amount + ' dollars'  # Unexpected format

    dollars = int(fields[0]) if fields[0] else 0
    cents = int(fields[1]) if len(fields) > 1 and fields[1] else 0

    phrases = []
    if dollars:
        phrases.append('%s %s' % (dollars, 'dollar' if dollars == 1 else 'dollars'))
    if cents:
        phrases.append('%s %s' % (cents, 'cent' if cents == 1 else 'cents'))
    if not phrases:
        return 'zero dollars'
    return ', '.join(phrases)


def _expand_ordinal(m):
    """Spell a matched ordinal such as '3rd' in words."""
    return _inflect.number_to_words(m.group(0))


def _expand_number(m):
    """Spell a matched integer; values in (1000, 3000) are read like years."""
    num = int(m.group(0))
    if not (num > 1000 and num < 3000):
        return _inflect.number_to_words(num, andword='')
    if num == 2000:
        return 'two thousand'
    if num > 2000 and num < 2010:
        return 'two thousand ' + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        return _inflect.number_to_words(num // 100) + ' hundred'
    # read as two pairs of digits, e.g. "nineteen eighty four"
    return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')


def normalize_numbers(text):
    """Spell out all numbers in *text*; the order of the passes matters."""
    passes = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r'\1 pounds'),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, replacement in passes:
        text = re.sub(pattern, replacement, text)
    return text
def mark_dark_l(text):
    """Rewrite 'l' as dark 'ɫ' when no vowel follows it before the end of the word."""
    return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', r'ɫ\1', text)


def english_to_ipa(text):
    """Normalise English text and convert it to IPA with eng_to_ipa."""
    text = unidecode(text).lower()
    text = expand_abbreviations(text)
    text = normalize_numbers(text)
    phonemes = ipa.convert(text)
    return collapse_whitespace(phonemes)


def english_to_lazy_ipa(text):
    """English text -> IPA reduced with the `_lazy_ipa` substitutions."""
    text = english_to_ipa(text)
    for pattern, replacement in _lazy_ipa:
        text = pattern.sub(replacement, text)
    return text


def english_to_ipa2(text):
    """English text -> IPA with dark l marked and the `_ipa_to_ipa2` substitutions."""
    text = mark_dark_l(english_to_ipa(text))
    for pattern, replacement in _ipa_to_ipa2:
        text = pattern.sub(replacement, text)
    return text.replace('...', '…')


def english_to_lazy_ipa2(text):
    """English text -> IPA reduced with the `_lazy_ipa2` substitutions."""
    text = english_to_ipa(text)
    for pattern, replacement in _lazy_ipa2:
        text = pattern.sub(replacement, text)
    return text


# The demo call used to run unconditionally, so merely importing this module ran
# the whole G2P pipeline and printed to stdout. Keep it for direct execution only.
if __name__ == '__main__':
    print(english_to_lazy_ipa2("happy new-year! vits"))
# ---- /yin.py ----
# remove np from https://github.com/dhchoi99/NANSY/blob/master/models/yin.py
# adapted from https://github.com/patriceguyot/Yin
# https://github.com/NVIDIA/mellotron/blob/master/yin.py

import torch
import torch.nn.functional as F
from math import log2, ceil


def differenceFunction(x, N, tau_max):
    """
    Compute difference function of data x. This corresponds to equation (6) in [1]
    This solution is implemented directly with torch rfft.

    d[tau] = sum_j (x[j] - x[j + tau])**2 is evaluated as the sum of two energy
    terms (taken from a cumulative sum) minus twice the autocorrelation (taken
    from a zero-padded FFT).

    :param x: audio data (Tensor) of shape [B, T]
    :param N: length of data (unused, kept for signature compatibility)
    :param tau_max: integration window size
    :return: difference function, Tensor of shape [B, tau_max]
    """
    assert x.dim() == 2
    b, w = x.shape
    if w < tau_max:
        # The original call passed 'constant' positionally AND mode='reflect', which
        # raises TypeError. Zero padding is used because reflect padding is not
        # defined once the pad exceeds the signal length.
        deficit = tau_max - w
        x = F.pad(x, (deficit - deficit // 2, deficit // 2), mode='constant')
        w = tau_max
    x_cumsum = torch.cat([x.new_zeros(b, 1), (x * x).cumsum(dim=1)], dim=1)

    # pick an FFT size >= w + tau_max so that the circular correlation does not wrap
    size = w + tau_max
    p2 = (size // 32).bit_length()
    nice_numbers = (16, 18, 20, 24, 25, 27, 30, 32)
    size_pad = min(n * 2**p2 for n in nice_numbers if n * 2**p2 >= size)
    fc = torch.fft.rfft(x, size_pad)  # [B, F]
    # n must be given explicitly: size_pad can be odd (25, 27) for short inputs
    conv = torch.fft.irfft(fc * fc.conj(), n=size_pad)[:, :tau_max]

    # tensors do not support negative-step slicing, hence the flip;
    # column tau of `energy_head` is x_cumsum[:, w - tau]
    energy_head = torch.flip(x_cumsum[:, w - tau_max + 1:w + 1], dims=[-1])
    return energy_head + x_cumsum[:, w:w + 1] - x_cumsum[:, :tau_max] - 2 * conv


def differenceFunction_np(x, N, tau_max):
    """
    Compute difference function of data x. This corresponds to equation (6) in [1]
    This solution is implemented directly with Numpy fft.

    :param x: audio data, 1-D
    :param N: length of data (unused, kept for signature compatibility)
    :param tau_max: integration window size (clamped to the length of x)
    :return: difference function
    :rtype: numpy.ndarray
    """
    # this module no longer imports numpy at the top, so bring it in locally
    import numpy as np

    x = np.array(x, np.float64)
    w = x.size
    tau_max = min(tau_max, w)
    x_cumsum = np.concatenate((np.array([0.]), (x * x).cumsum()))
    size = w + tau_max
    p2 = (size // 32).bit_length()
    nice_numbers = (16, 18, 20, 24, 25, 27, 30, 32)
    size_pad = min(n * 2**p2 for n in nice_numbers if n * 2**p2 >= size)
    fc = np.fft.rfft(x, size_pad)
    # n must be given explicitly: size_pad can be odd (25, 27) for short inputs
    conv = np.fft.irfft(fc * fc.conjugate(), n=size_pad)[:tau_max]
    return x_cumsum[w:w - tau_max:-1] + x_cumsum[w] - x_cumsum[:tau_max] - 2 * conv
def cumulativeMeanNormalizedDifferenceFunction(df, N, eps=1e-8):
    """
    Compute cumulative mean normalized difference function (CMND).

    This corresponds to equation (8) in [1]

    :param df: Difference function, Tensor of shape [B, N]
    :param N: number of lags in df (must equal df.shape[1])
    :param eps: added to the running sum to avoid division by zero
    :return: cumulative mean normalized difference function, Tensor of shape [B, N]
    """
    B, _ = df.shape
    lags = torch.arange(1, N, device=df.device, dtype=df.dtype).view(1, -1)
    tail = df[:, 1:]
    cmndf = tail * lags / (tail.cumsum(dim=-1) + eps)
    # lag 0 is defined as 1
    first = torch.ones([B, 1], device=df.device, dtype=df.dtype)
    return torch.cat([first, cmndf], dim=-1)


def differenceFunctionTorch(xs: torch.Tensor, N, tau_max) -> torch.Tensor:
    """pytorch backend batch-wise differenceFunction
    has 1e-4 level error with input shape of (32, 22050*1.5)

    Args:
        xs: batch of frames, shape [B, w]; computed in float64.
        N: length of data (unused, kept for signature compatibility).
        tau_max: number of lags; clamped to w.

    Returns:
        float64 tensor of shape [B, min(tau_max, w)] with
        d[tau] = sum_j (xs[j] - xs[j + tau])**2.
    """
    xs = xs.double()
    w = xs.shape[-1]
    tau_max = min(tau_max, w)
    # leading zero column so that x_cumsum[:, k] is the energy of the first k samples
    x_cumsum = torch.cat((xs.new_zeros(xs.shape[0], 1),
                          (xs * xs).cumsum(dim=-1, dtype=torch.double)),
                         dim=-1)  # B x (w + 1)
    size = w + tau_max
    p2 = (size // 32).bit_length()
    nice_numbers = (16, 18, 20, 24, 25, 27, 30, 32)
    size_pad = min(n * 2**p2 for n in nice_numbers if n * 2**p2 >= size)

    fcs = torch.fft.rfft(xs, n=size_pad, dim=-1)
    # n must be given explicitly: size_pad can be odd (25, 27) for short inputs
    convs = torch.fft.irfft(fcs * fcs.conj(), n=size_pad)[:, :tau_max]
    # column tau of y1 is x_cumsum[:, w - tau]
    y1 = torch.flip(x_cumsum[:, w - tau_max + 1:w + 1], dims=[-1])
    y = y1 + x_cumsum[:, w].unsqueeze(-1) - x_cumsum[:, :tau_max] - 2 * convs
    return y
def cumulativeMeanNormalizedDifferenceFunctionTorch(dfs: torch.Tensor,
                                                    N,
                                                    eps=1e-8) -> torch.Tensor:
    """Batch-wise cumulative mean normalized difference function, in float64.

    Args:
        dfs: difference functions, shape [B, N].
        N: number of lags in `dfs` (must equal dfs.shape[1]).
        eps: added to the running sum to avoid division by zero.

    Returns:
        Tensor of shape [B, N]; lag 0 is defined as 1.
    """
    arange = torch.arange(1, N, device=dfs.device, dtype=torch.float64)
    cumsum = torch.cumsum(dfs[:, 1:], dim=-1, dtype=torch.float64)

    cmndfs = dfs[:, 1:] * arange / (cumsum + eps)
    # build the lag-0 column in the result dtype so the cat needs no type promotion
    first = torch.ones(cmndfs.shape[0], 1, device=dfs.device, dtype=cmndfs.dtype)
    return torch.cat((first, cmndfs), dim=-1)


if __name__ == '__main__':
    # Self-check: the numpy reference and the batched torch path must agree.
    import numpy as np

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    wav = torch.randn(32, int(22050 * 1.5)).to(device)
    wav_numpy = wav.detach().cpu().numpy()
    x = wav_numpy[0]

    w_len = 2048
    w_step = 256
    tau_max = 2048
    W = 2048

    startFrames = np.asarray(list(range(0, x.shape[-1] - w_len, w_step)))
    frames = np.asarray([x[..., t:t + W] for t in startFrames])
    frames_torch = torch.from_numpy(frames).to(device)

    cmndfs0 = []
    for frame in frames:
        # numpy reference for the difference function; the CMND helper expects a
        # [B, N] tensor, so wrap the single frame as a batch of one
        df = differenceFunction_np(frame, frame.shape[-1], tau_max)
        cmndf = cumulativeMeanNormalizedDifferenceFunction(
            torch.from_numpy(df).unsqueeze(0), tau_max)
        cmndfs0.append(cmndf.squeeze(0).numpy())
    cmndfs0 = np.asarray(cmndfs0)

    dfs = differenceFunctionTorch(frames_torch, frames_torch.shape[-1],
                                  tau_max)
    cmndfs1 = cumulativeMeanNormalizedDifferenceFunctionTorch(
        dfs, tau_max).detach().cpu().numpy()
    print(cmndfs0.shape, cmndfs1.shape)
    print(np.sum(np.abs(cmndfs0 - cmndfs1)))
# ---- /commons.py ----
# from https://github.com/jaywalnut310/vits
import math
import torch
from torch.nn import functional as F


def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise the weight of *m* from N(mean, std) if its class name contains "Conv"."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    """Return the one-sided padding int((kernel_size * dilation - dilation) / 2)."""
    return int((kernel_size * dilation - dilation) / 2)


# NOTE: `convert_pad_shape` used to be defined here and again, identically,
# further down this module; only the later definition is kept.


def intersperse(lst, item):
    """Return *lst* with *item* inserted before, between and after its elements."""
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def intersperse_with_language_id(text, lang, item):
    """Intersperse *text* with *item* and expand *lang* to the same length.

    Every inserted blank takes the language id of the token that follows it;
    the trailing blank takes the id of the last token.

    Args:
        text: sequence of token ids.
        lang: sequence of per-token language ids, same length as *text*
            and non-empty (an empty *lang* raises IndexError).
        item: the blank id to insert.

    Returns:
        (text_with_blanks, lang_with_blanks), both lists of length 2 * n + 1.
    """
    lang = list(lang)  # accept tuples and other sequences, not only lists
    n = len(text)
    _text = [item] * (2 * n + 1)
    _lang = [None] * (2 * n + 1)
    _text[1::2] = text
    _lang[1::2] = lang
    _lang[::2] = lang + [lang[-1]]

    return _text, _lang


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q) between diagonal Gaussians given means and log standard deviations."""
    kl = (logs_q - logs_p) - 0.5
    kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
    return kl
def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    # keep the uniform samples inside [1e-5, 0.99999) so both logs stay finite
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    """Gumbel noise with the size, dtype and device of *x*."""
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    """Cut `segment_size` frames out of every item of x [b, d, t], starting at ids_str[i]."""
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        ret[i] = x[i, :, idx_str:idx_str + segment_size]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Cut one randomly placed segment out of every item of x [b, d, t].

    Returns:
        (segments, ids_str): the slices and their start indices (long, >= 0).
    """
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device)
               * ids_str_max).to(dtype=torch.long)
    # items shorter than the segment would give a negative start; pin those to 0
    ids_str = ids_str.clamp(min=0)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def rand_slice_segments_for_cat(x, x_lengths=None, segment_size=4):
    """Like `rand_slice_segments`, but both halves of the batch share their offsets.

    Item i and item i + b // 2 use the same random fraction, so the batch size
    has to be even.
    """
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    half = torch.rand([b // 2]).to(device=x.device)
    ids_str = (torch.cat([half, half], dim=0)
               * ids_str_max).to(dtype=torch.long)
    ids_str = ids_str.clamp(min=0)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(
        length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Sinusoidal timing signal of shape (1, channels, length).

    The first channels // 2 rows are sines and the next channels // 2 rows are
    cosines over geometrically spaced timescales; an odd channel count gets one
    trailing row of zeros.
    """
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    # guard the denominator: with a single timescale (channels 2 or 3) the
    # original divided by zero
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale))
        / max(num_timescales - 1, 1)
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    """Add the sinusoidal timing signal to x of shape [b, channels, length]."""
    _, n_channels, n_frames = x.size()
    timing = get_timing_signal_1d(n_frames, n_channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return x + timing


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    """Concatenate the sinusoidal timing signal to x [b, channels, length] along *axis*."""
    _, n_channels, n_frames = x.size()
    timing = get_timing_signal_1d(n_frames, n_channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return torch.cat([x, timing], axis)


def subsequent_mask(length):
    """Lower-triangular mask of ones with shape (1, 1, length, length)."""
    lower = torch.tril(torch.ones(length, length))
    return lower.unsqueeze(0).unsqueeze(0)


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # Gated activation on the sum of both inputs: tanh of the first
    # n_channels[0] channels times sigmoid of the remaining channels.
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    """Flatten per-dimension [before, after] pairs into the last-dim-first list F.pad expects."""
    flat = []
    for pair in reversed(pad_shape):
        flat.extend(pair)
    return flat


def shift_1d(x):
    """Shift x one step right along the last dim, filling the first position with zero."""
    padded = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))
    return padded[:, :, :-1]


def sequence_mask(length, max_length=None):
    """Boolean mask [b, max_length] that is True at positions below each entry of *length*."""
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp gradients element-wise in place and return their total norm.

    The norm is measured on each gradient *before* it is clamped.

    Args:
        parameters: a single tensor or an iterable of tensors; entries whose
            ``grad`` is ``None`` are ignored.
        clip_value: gradients are clamped to ``[-clip_value, clip_value]``;
            ``None`` disables clamping (the norm is still computed).
        norm_type: order of the norm.  ``float('inf')`` yields the largest
            absolute gradient entry.

    Returns:
        The total gradient norm as a Python float (``0.0`` when no parameter
        has a gradient).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    # The sum-of-powers formula degenerates for the infinity norm
    # (``x ** inf`` followed by ``** 0``), so that case takes the maximum.
    is_inf_norm = norm_type == float('inf')
    total_norm = 0.0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type).item()
        if is_inf_norm:
            total_norm = max(total_norm, param_norm)
        else:
            total_norm += param_norm ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    if not is_inf_norm:
        total_norm = total_norm ** (1. / norm_type)
    return total_norm
def latin_to_hangul(text):
    '''Spell out Latin letters in Hangul.

    Applies every (pattern, hangul) pair of ``_latin_to_hangul`` in table
    order and returns the rewritten text.
    '''
    for pattern, hangul in _latin_to_hangul:
        text = pattern.sub(hangul, text)
    return text
def hangul_number(num, sino=True):
    '''Spell out a string of digits in Hangul.

    Reference https://github.com/Kyubyong/g2pK

    Args:
        num: digits, optionally containing ``,`` group separators.
        sino: when true, every position uses the Sino-Korean digit names;
            otherwise the ones and tens positions use the native Korean
            forms and higher positions still use Sino-Korean names.

    Returns:
        The spelled-out number.  Numbers with more than 16 significant
        digits (beyond the units this function knows) are read digit by
        digit instead of raising.
    '''
    num = re.sub(',', '', num)

    if num == '0':
        return '영'
    if not sino and num == '20':
        return '스무'

    digits = '123456789'
    digit2name = dict(zip(digits, '일이삼사오육칠팔구'))
    digit2mod = dict(zip(digits, '한 두 세 네 다섯 여섯 일곱 여덟 아홉'.split()))
    digit2dec = dict(zip(digits, '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔'.split()))
    # Unit suffix for each position, counted from the right.
    units = ('', '십', '백', '천',
             '만', '십', '백', '천',
             '억', '십', '백', '천',
             '조', '십', '백', '천')

    # Leading zeros never contribute any text, so they can be dropped.
    num = num.lstrip('0')
    if len(num) > len(units):
        # Previously this case failed with UnboundLocalError.
        return ''.join(
            '영' if d == '0' else digit2name.get(d, '') for d in num)

    spelledout = []
    for offset, digit in enumerate(num):
        i = len(num) - offset - 1
        if digit == '0':
            # A zero is silent, but at a four-digit boundary the group unit
            # (만/억/조) is still spoken if the three digits before it
            # produced any text.
            keep_unit = i % 4 == 0 and ''.join(spelledout[-3:]) != ''
            spelledout.append(units[i] if keep_unit else '')
            continue
        if i == 0:
            name = (digit2name if sino else digit2mod).get(digit, '')
        elif i == 1 and not sino:
            name = digit2dec.get(digit, '')
        else:
            name = digit2name.get(digit, '') + units[i]
            if i < 8:
                # Below 억 a leading 일 is dropped (일십 -> 십, 일만 -> 만).
                # From 억 upwards it is kept, as in the reference code.
                name = name.replace('일' + units[i], units[i])
        spelledout.append(name)
    return ''.join(spelledout)
def inference(self, text, speaker_id_val, seed, scope_shift, duration):
    """Synthesize audio for *text* and return it with its phoneme string.

    Args:
        text: input text, passed to ``self.get_phoneme``.
        speaker_id_val: speaker index fed to the model as ``sid``.
        seed: random seed; cast to ``int`` and given to
            ``torch.manual_seed`` so repeated calls are reproducible.
        scope_shift: cast to ``int`` and forwarded as ``scope_shift``.
        duration: forwarded as ``length_scale``.

    Returns:
        ``(phones, (sampling_rate, audio))`` where ``audio`` is a NumPy
        array taken from element ``[0, 0]`` of the decoder output.
    """
    seed = int(seed)
    scope_shift = int(scope_shift)
    torch.manual_seed(seed)
    # get_phoneme returns (symbol ids, language ids, cleaned text); the
    # second item is what the model receives alongside the symbols.
    text_norm, tone, phones = self.get_phoneme(text)
    x_tst = text_norm.to(self.device).unsqueeze(0)
    t_tst = tone.to(self.device).unsqueeze(0)
    x_tst_lengths = torch.LongTensor([text_norm.size(0)]).to(self.device)
    speaker_id = torch.LongTensor([speaker_id_val]).to(self.device)
    # Pure inference: no autograd graph is needed, so do not record one.
    with torch.no_grad():
        decoder_inputs, *_ = self.net_g.infer_pre_decoder(
            x_tst,
            t_tst,
            x_tst_lengths,
            sid=speaker_id,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=duration,
            scope_shift=scope_shift)
        audio = self.net_g.infer_decode_chunk(
            decoder_inputs, sid=speaker_id)[0, 0].data.cpu().float().numpy()
    return phones, (self.hps.data.sampling_rate, audio)
def parsearg(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: optional list of argument strings.  ``None`` (the default)
            lets argparse read the process arguments, exactly as before.

    Returns:
        An ``argparse.Namespace`` with ``config``, ``model``,
        ``checkpoint_path``, ``force_resume`` and ``dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c',
                        '--config',
                        type=str,
                        default="./configs/config_cjke.yaml",
                        help='Path to configuration file')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        default='PITS',
                        help='Model name')
    parser.add_argument('-r',
                        '--checkpoint_path',
                        type=str,
                        default='./logs/cjke/cjke_36800.pth',
                        help='Path to checkpoint for resume')
    parser.add_argument('-f',
                        '--force_resume',
                        type=str,
                        help='Path to checkpoint for force resume')
    parser.add_argument('-d',
                        '--dir',
                        type=str,
                        default='/DATA/audio/pits_samples',
                        help='root dir')
    return parser.parse_args(argv)
"transliteration_cleaners" for non-English text that can be transliterated to ASCII using 12 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 13 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 14 | the symbols in symbols.py to match your data). 15 | ''' 16 | 17 | 18 | # Regular expression matching whitespace: 19 | g2p = G2p() 20 | 21 | import re 22 | import inflect 23 | from unidecode import unidecode 24 | import eng_to_ipa as ipa 25 | _inflect = inflect.engine() 26 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 27 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 28 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 29 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 30 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 31 | _number_re = re.compile(r'[0-9]+') 32 | 33 | # List of (regular expression, replacement) pairs for abbreviations: 34 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 35 | ('mrs', 'misess'), 36 | ('mr', 'mister'), 37 | ('dr', 'doctor'), 38 | ('st', 'saint'), 39 | ('co', 'company'), 40 | ('jr', 'junior'), 41 | ('maj', 'major'), 42 | ('gen', 'general'), 43 | ('drs', 'doctors'), 44 | ('rev', 'reverend'), 45 | ('lt', 'lieutenant'), 46 | ('hon', 'honorable'), 47 | ('sgt', 'sergeant'), 48 | ('capt', 'captain'), 49 | ('esq', 'esquire'), 50 | ('ltd', 'limited'), 51 | ('col', 'colonel'), 52 | ('ft', 'fort'), 53 | ]] 54 | 55 | 56 | # List of (ipa, lazy ipa) pairs: 57 | _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 58 | ('r', 'ɹ'), 59 | ('æ', 'e'), 60 | ('ɑ', 'a'), 61 | ('ɔ', 'o'), 62 | ('ð', 'z'), 63 | ('θ', 's'), 64 | ('ɛ', 'e'), 65 | ('ɪ', 'i'), 66 | ('ʊ', 'u'), 67 | ('ʒ', 'ʥ'), 68 | ('ʤ', 'ʥ'), 69 | ('ˈ', '↓'), 70 | ]] 71 | 72 | # List of (ipa, lazy ipa2) pairs: 73 | _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 74 | ('r', 'ɹ'), 75 | ('ð', 'z'), 76 | ('θ', 's'), 77 | ('ʒ', 'ʑ'), 78 | ('ʤ', 'dʑ'), 79 | ('ˈ', '↓'), 80 | 
def _expand_dollars(m):
    """Rewrite a matched dollar amount as "<n> dollars, <n> cents" text.

    ``m.group(1)`` is the amount without the ``$`` sign.  The fractional
    part is read as hundredths of a dollar: a single digit is scaled
    ("1.5" is 50 cents, not 5) and digits past the second are dropped.
    Amounts with more than one decimal point are returned unchanged with
    " dollars" appended.
    """
    match = m.group(1)
    parts = match.split('.')
    if len(parts) > 2:
        return match + ' dollars'  # Unexpected format
    # Thousands separators would make int() raise, so drop them first.
    whole = parts[0].replace(',', '')
    fraction = parts[1].replace(',', '') if len(parts) > 1 else ''
    dollars = int(whole) if whole else 0
    cents = int(fraction[:2].ljust(2, '0')) if fraction else 0
    if dollars and cents:
        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
        cent_unit = 'cent' if cents == 1 else 'cents'
        return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
        return '%s %s' % (dollars, dollar_unit)
    elif cents:
        cent_unit = 'cent' if cents == 1 else 'cents'
        return '%s %s' % (cents, cent_unit)
    else:
        return 'zero dollars'
def convert_to_ipa(phones):
    """Convert a sequence of ARPAbet-style phoneme tokens to an IPA string.

    Each token is lower-cased; a trailing stress digit (``0``-``4``) is
    ignored for the lookup.  Tokens without a table entry (consonants that
    are already plain letters, spaces, punctuation) are emitted unchanged
    in their lower-cased form.
    """
    symbols = {"a": "ə", "ey": "eɪ", "aa": "ɑ", "ae": "æ", "ah": "ə", "ao": "ɔ",
               "aw": "aʊ", "ay": "aɪ", "ch": "ʧ", "dh": "ð", "eh": "ɛ", "er": "ər",
               "hh": "h", "ih": "ɪ", "jh": "ʤ", "ng": "ŋ", "ow": "oʊ", "oy": "ɔɪ",
               "sh": "ʃ", "th": "θ", "uh": "ʊ", "uw": "u", "zh": "ʒ", "iy": "i", "y": "j"}
    pieces = []
    for ph in phones:
        ph = ph.lower()
        # Strip the stress marker for the lookup only; an unknown token is
        # passed through as-is (stress digit included), as before.
        base = ph[:-1] if ph and ph[-1] in "01234" else ph
        pieces.append(symbols.get(base, ph))
    return "".join(pieces)
def english_to_lazy_ipa2(text):
    """Convert English text to IPA, then apply the ``_lazy_ipa2`` rewrites.

    The (pattern, replacement) pairs are applied one after another in
    table order to the output of ``english_to_ipa``.
    """
    phonemes = english_to_ipa(text)
    for pattern, substitute in _lazy_ipa2:
        phonemes = pattern.sub(substitute, phonemes)
    return phonemes
def replace_frac(match) -> str:
    """
    Verbalize a matched fraction as "<denominator>分之<numerator>",
    prefixed with "负" when the sign group is non-empty.
    Args:
        match (re.Match): groups 1-3 are sign, numerator, denominator
    Returns:
        str
    """
    minus, numerator, denominator = match.group(1, 2, 3)
    prefix = "负" if minus else ""
    numerator_text = num2str(numerator)
    denominator_text = num2str(denominator)
    return f"{prefix}{denominator_text}分之{numerator_text}"
f"{sign}{number}" 90 | return result 91 | 92 | 93 | # 编号-无符号整形 94 | # 00078 95 | RE_DEFAULT_NUM = re.compile(r'\d{3}\d*') 96 | 97 | 98 | def replace_default_num(match): 99 | """ 100 | Args: 101 | match (re.Match) 102 | Returns: 103 | str 104 | """ 105 | number = match.group(0) 106 | return verbalize_digit(number) 107 | 108 | 109 | # 数字表达式 110 | # 纯小数 111 | RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))') 112 | # 正整数 + 量词 113 | RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS) 114 | RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))') 115 | 116 | 117 | def replace_positive_quantifier(match) -> str: 118 | """ 119 | Args: 120 | match (re.Match) 121 | Returns: 122 | str 123 | """ 124 | number = match.group(1) 125 | match_2 = match.group(2) 126 | if match_2 == "+": 127 | match_2 = "多" 128 | match_2: str = match_2 if match_2 else "" 129 | quantifiers: str = match.group(3) 130 | number: str = num2str(number) 131 | result = f"{number}{match_2}{quantifiers}" 132 | return result 133 | 134 | 135 | def replace_number(match) -> str: 136 | """ 137 | Args: 138 | match (re.Match) 139 | Returns: 140 | str 141 | """ 142 | sign = match.group(1) 143 | number = match.group(2) 144 | pure_decimal = match.group(5) 145 | if pure_decimal: 146 | result = num2str(pure_decimal) 147 | else: 148 | sign: str = "负" if sign else "" 149 | number: str = num2str(number) 150 | result = f"{sign}{number}" 151 | return result 152 | 153 | 154 | # 范围表达式 155 | # match.group(1) and match.group(8) are copy from RE_NUMBER 156 | 157 | RE_RANGE = re.compile( 158 | r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))[-~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))') 159 | 160 | 161 | def replace_range(match) -> str: 162 | """ 163 | Args: 164 | match (re.Match) 165 | Returns: 166 | str 167 | """ 168 | first, second = match.group(1), match.group(8) 169 | first = RE_NUMBER.sub(replace_number, first) 170 | second = RE_NUMBER.sub(replace_number, second) 171 | result = f"{first}到{second}" 172 | return 
def verbalize_cardinal(value_string: str) -> str:
    """Read a string of digits as a Chinese cardinal number.

    An empty string yields an empty string; a string made only of zeros
    yields '零'.  A leading '一十' is shortened to '十'.
    """
    if not value_string:
        return ''

    # 000 -> '零' , 0 -> '零'
    significant = value_string.lstrip('0')
    if not significant:
        return DIGITS['0']

    symbols = _get_value(significant)
    # verbalized number starting with '一十*' is abbreviated as `十*`
    if symbols[:2] == [DIGITS['1'], UNITS[1]]:
        symbols = symbols[1:]
    return ''.join(symbols)
def str_replace(data):
    """Normalize a few punctuation marks in *data*.

    ``;`` becomes ``.``, ``:`` becomes ``,`` and both quote characters
    (``"`` and ``'``) become a space.  All other characters are kept.
    """
    # One translate pass replaces the index loop over two parallel lists;
    # no replacement character is itself a key, so the result is the same.
    table = str.maketrans({";": ".", ":": ",", "\"": " ", "'": " "})
    return data.translate(table)
def chinese_cleaners(text):
    '''Pipeline for Chinese text'''
    # chinese_to_bopomofo is not among the names this module imports from
    # text.mandarin at file level, so calling it raised NameError.
    # NOTE(review): assumes text.mandarin defines chinese_to_bopomofo, as
    # it does the sibling latin_to_bopomofo — confirm.
    from text.mandarin import chinese_to_bopomofo

    text = number_to_chinese(text)
    text = chinese_to_bopomofo(text)
    text = latin_to_bopomofo(text)
    # Close the utterance with '。' when it ends on a tone mark.
    text = re.sub(r'([ˉˊˇˋ˙])$', r'\1。', text)
    return text
def cjke_cleaners3(text: str):
    """Convert language-tagged text to IPA plus a per-character language id.

    Spans wrapped in ``[JA]``, ``[ZH]``, ``[EN]``, ``[P]`` or ``[KO]`` tags
    go through the matching converter; text outside any tag goes through
    ``others_to_ipa`` and is labelled ``lang_map["other"]``.

    Returns:
        ``(cleaned_text, lang_seq)`` with one ``lang_seq`` entry per
        character of ``cleaned_text``.
    """
    converters = {
        'P': pinyin_to_ipa,
        'JA': japanese_to_ipa2,
        'ZH': chinese_to_ipa,
        'EN': english_to_ipa2,
        'KO': korean_to_ipa,
    }
    source = str_replace(text).replace("\"", '')
    pieces = []
    lang_seq = []

    def emit(ipa, lang_id):
        # Record the converted text and one language id per character.
        pieces.append(ipa)
        lang_seq.extend([lang_id] * len(ipa))

    cursor = 0
    # find all text blocks enclosed in [JA], [ZH], [EN], [P], [KO]
    for tagged in re.finditer(r'\[(JA|ZH|EN|P|KO)\](.*?)\[\1\]', source):
        start, end = tagged.span()
        # insert text not enclosed in any blocks
        emit(others_to_ipa(source[cursor:start]), lang_map["other"])
        cursor = end
        language = tagged.group(1)
        emit(converters[language](tagged.group(2)), lang_map[language])
    # trailing text after the last tagged block
    emit(others_to_ipa(source[cursor:]), lang_map["other"])

    cleaned_text = ''.join(pieces)
    assert len(cleaned_text) == len(lang_seq)
    return cleaned_text, lang_seq
https://github.com/jaywalnut310/vits 2 | import numpy as np 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform( 13 | inputs, 14 | unnormalized_widths, 15 | unnormalized_heights, 16 | unnormalized_derivatives, 17 | inverse=False, 18 | tails=None, 19 | tail_bound=1., 20 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 21 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 22 | min_derivative=DEFAULT_MIN_DERIVATIVE 23 | ): 24 | 25 | if tails is None: 26 | spline_fn = rational_quadratic_spline 27 | spline_kwargs = {} 28 | else: 29 | spline_fn = unconstrained_rational_quadratic_spline 30 | spline_kwargs = { 31 | 'tails': tails, 32 | 'tail_bound': tail_bound 33 | } 34 | 35 | outputs, logabsdet = spline_fn( 36 | inputs=inputs, 37 | unnormalized_widths=unnormalized_widths, 38 | unnormalized_heights=unnormalized_heights, 39 | unnormalized_derivatives=unnormalized_derivatives, 40 | inverse=inverse, 41 | min_bin_width=min_bin_width, 42 | min_bin_height=min_bin_height, 43 | min_derivative=min_derivative, 44 | **spline_kwargs 45 | ) 46 | return outputs, logabsdet 47 | 48 | 49 | def searchsorted(bin_locations, inputs, eps=1e-6): 50 | bin_locations[..., -1] += eps 51 | return torch.sum( 52 | inputs[..., None] >= bin_locations, 53 | dim=-1 54 | ) - 1 55 | 56 | 57 | def unconstrained_rational_quadratic_spline( 58 | inputs, 59 | unnormalized_widths, 60 | unnormalized_heights, 61 | unnormalized_derivatives, 62 | inverse=False, 63 | tails='linear', 64 | tail_bound=1., 65 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 66 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 67 | min_derivative=DEFAULT_MIN_DERIVATIVE 68 | ): 69 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 70 | outside_interval_mask = ~inside_interval_mask 71 | 72 | outputs = torch.zeros_like(inputs) 73 | logabsdet = torch.zeros_like(inputs) 74 | 75 | if tails == 
'linear': 76 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 77 | constant = np.log(np.exp(1 - min_derivative) - 1) 78 | unnormalized_derivatives[..., 0] = constant 79 | unnormalized_derivatives[..., -1] = constant 80 | 81 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 82 | logabsdet[outside_interval_mask] = 0 83 | else: 84 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 85 | 86 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 87 | inputs=inputs[inside_interval_mask], 88 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 89 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 90 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 91 | inverse=inverse, 92 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 93 | min_bin_width=min_bin_width, 94 | min_bin_height=min_bin_height, 95 | min_derivative=min_derivative 96 | ) 97 | 98 | return outputs, logabsdet 99 | 100 | def rational_quadratic_spline( 101 | inputs, 102 | unnormalized_widths, 103 | unnormalized_heights, 104 | unnormalized_derivatives, 105 | inverse=False, 106 | left=0., right=1., bottom=0., top=1., 107 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 108 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 109 | min_derivative=DEFAULT_MIN_DERIVATIVE 110 | ): 111 | if torch.min(inputs) < left or torch.max(inputs) > right: 112 | raise ValueError('Input to a transform is not within its domain') 113 | 114 | num_bins = unnormalized_widths.shape[-1] 115 | 116 | if min_bin_width * num_bins > 1.0: 117 | raise ValueError('Minimal bin width too large for the number of bins') 118 | if min_bin_height * num_bins > 1.0: 119 | raise ValueError('Minimal bin height too large for the number of bins') 120 | 121 | widths = F.softmax(unnormalized_widths, dim=-1) 122 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 123 | cumwidths = 
torch.cumsum(widths, dim=-1) 124 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 125 | cumwidths = (right - left) * cumwidths + left 126 | cumwidths[..., 0] = left 127 | cumwidths[..., -1] = right 128 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 129 | 130 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 131 | 132 | heights = F.softmax(unnormalized_heights, dim=-1) 133 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 134 | cumheights = torch.cumsum(heights, dim=-1) 135 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 136 | cumheights = (top - bottom) * cumheights + bottom 137 | cumheights[..., 0] = bottom 138 | cumheights[..., -1] = top 139 | heights = cumheights[..., 1:] - cumheights[..., :-1] 140 | 141 | if inverse: 142 | bin_idx = searchsorted(cumheights, inputs)[..., None] 143 | else: 144 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 145 | 146 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 147 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 148 | 149 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 150 | delta = heights / widths 151 | input_delta = delta.gather(-1, bin_idx)[..., 0] 152 | 153 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 154 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 155 | 156 | input_heights = heights.gather(-1, bin_idx)[..., 0] 157 | 158 | if inverse: 159 | a = (((inputs - input_cumheights) * (input_derivatives 160 | + input_derivatives_plus_one 161 | - 2 * input_delta) 162 | + input_heights * (input_delta - input_derivatives))) 163 | b = (input_heights * input_derivatives 164 | - (inputs - input_cumheights) * (input_derivatives 165 | + input_derivatives_plus_one 166 | - 2 * input_delta)) 167 | c = - input_delta * (inputs - input_cumheights) 168 | 169 | discriminant = b.pow(2) - 4 * a * c 170 | assert (discriminant >= 0).all() 171 | 172 | root = (2 * c) / (-b 
- torch.sqrt(discriminant)) 173 | outputs = root * input_bin_widths + input_cumwidths 174 | 175 | theta_one_minus_theta = root * (1 - root) 176 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 177 | * theta_one_minus_theta) 178 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 179 | + 2 * input_delta * theta_one_minus_theta 180 | + input_derivatives * (1 - root).pow(2)) 181 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 182 | 183 | return outputs, -logabsdet 184 | else: 185 | theta = (inputs - input_cumwidths) / input_bin_widths 186 | theta_one_minus_theta = theta * (1 - theta) 187 | 188 | numerator = input_heights * (input_delta * theta.pow(2) 189 | + input_derivatives * theta_one_minus_theta) 190 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 191 | * theta_one_minus_theta) 192 | outputs = input_cumheights + numerator / denominator 193 | 194 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 195 | + 2 * input_delta * theta_one_minus_theta 196 | + input_derivatives * (1 - theta).pow(2)) 197 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 198 | 199 | return outputs, logabsdet 200 | -------------------------------------------------------------------------------- /text/mandarin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | from pypinyin import lazy_pinyin, BOPOMOFO 5 | import jieba 6 | import cn2an 7 | import logging 8 | 9 | 10 | # List of (Latin alphabet, bopomofo) pairs: 11 | from text.paddle_zh_frontend import zh_to_bopomofo, pinyin_to_bopomofo 12 | 13 | _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 14 | ('a', 'ㄟˉ'), 15 | ('b', 'ㄅㄧˋ'), 16 | ('c', 'ㄙㄧˉ'), 17 | ('d', 'ㄉㄧˋ'), 18 | ('e', 'ㄧˋ'), 19 | ('f', 'ㄝˊㄈㄨˋ'), 20 | 
('g', 'ㄐㄧˋ'), 21 | ('h', 'ㄝˇㄑㄩˋ'), 22 | ('i', 'ㄞˋ'), 23 | ('j', 'ㄐㄟˋ'), 24 | ('k', 'ㄎㄟˋ'), 25 | ('l', 'ㄝˊㄛˋ'), 26 | ('m', 'ㄝˊㄇㄨˋ'), 27 | ('n', 'ㄣˉ'), 28 | ('o', 'ㄡˉ'), 29 | ('p', 'ㄆㄧˉ'), 30 | ('q', 'ㄎㄧㄡˉ'), 31 | ('r', 'ㄚˋ'), 32 | ('s', 'ㄝˊㄙˋ'), 33 | ('t', 'ㄊㄧˋ'), 34 | ('u', 'ㄧㄡˉ'), 35 | ('v', 'ㄨㄧˉ'), 36 | ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'), 37 | ('x', 'ㄝˉㄎㄨˋㄙˋ'), 38 | ('y', 'ㄨㄞˋ'), 39 | ('z', 'ㄗㄟˋ') 40 | ]] 41 | 42 | # List of (bopomofo, romaji) pairs: 43 | _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [ 44 | ('ㄅㄛ', 'p⁼wo'), 45 | ('ㄆㄛ', 'pʰwo'), 46 | ('ㄇㄛ', 'mwo'), 47 | ('ㄈㄛ', 'fwo'), 48 | ('ㄅ', 'p⁼'), 49 | ('ㄆ', 'pʰ'), 50 | ('ㄇ', 'm'), 51 | ('ㄈ', 'f'), 52 | ('ㄉ', 't⁼'), 53 | ('ㄊ', 'tʰ'), 54 | ('ㄋ', 'n'), 55 | ('ㄌ', 'l'), 56 | ('ㄍ', 'k⁼'), 57 | ('ㄎ', 'kʰ'), 58 | ('ㄏ', 'h'), 59 | ('ㄐ', 'ʧ⁼'), 60 | ('ㄑ', 'ʧʰ'), 61 | ('ㄒ', 'ʃ'), 62 | ('ㄓ', 'ʦ`⁼'), 63 | ('ㄔ', 'ʦ`ʰ'), 64 | ('ㄕ', 's`'), 65 | ('ㄖ', 'ɹ`'), 66 | ('ㄗ', 'ʦ⁼'), 67 | ('ㄘ', 'ʦʰ'), 68 | ('ㄙ', 's'), 69 | ('ㄚ', 'a'), 70 | ('ㄛ', 'o'), 71 | ('ㄜ', 'ə'), 72 | ('ㄝ', 'e'), 73 | ('ㄞ', 'ai'), 74 | ('ㄟ', 'ei'), 75 | ('ㄠ', 'au'), 76 | ('ㄡ', 'ou'), 77 | ('ㄧㄢ', 'yeNN'), 78 | ('ㄢ', 'aNN'), 79 | ('ㄧㄣ', 'iNN'), 80 | ('ㄣ', 'əNN'), 81 | ('ㄤ', 'aNg'), 82 | ('ㄧㄥ', 'iNg'), 83 | ('ㄨㄥ', 'uNg'), 84 | ('ㄩㄥ', 'yuNg'), 85 | ('ㄥ', 'əNg'), 86 | ('ㄦ', 'əɻ'), 87 | ('ㄧ', 'i'), 88 | ('ㄨ', 'u'), 89 | ('ㄩ', 'ɥ'), 90 | ('ˉ', '→'), 91 | ('ˊ', '↑'), 92 | ('ˇ', '↓↑'), 93 | ('ˋ', '↓'), 94 | ('˙', ''), 95 | (',', ','), 96 | ('。', '.'), 97 | ('!', '!'), 98 | ('?', '?'), 99 | ('—', '-') 100 | ]] 101 | 102 | # List of (romaji, ipa) pairs: 103 | _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 104 | ('ʃy', 'ʃ'), 105 | ('ʧʰy', 'ʧʰ'), 106 | ('ʧ⁼y', 'ʧ⁼'), 107 | ('NN', 'n'), 108 | ('Ng', 'ŋ'), 109 | ('y', 'j'), 110 | ('h', 'x') 111 | ]] 112 | 113 | # List of (bopomofo, ipa) pairs: 114 | _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 115 | ('ㄅㄛ', 'p⁼wo'), 116 | ('ㄆㄛ', 'pʰwo'), 117 | ('ㄇㄛ', 'mwo'), 118 | ('ㄈㄛ', 
'fwo'), 119 | ('ㄅ', 'p⁼'), 120 | ('ㄆ', 'pʰ'), 121 | ('ㄇ', 'm'), 122 | ('ㄈ', 'f'), 123 | ('ㄉ', 't⁼'), 124 | ('ㄊ', 'tʰ'), 125 | ('ㄋ', 'n'), 126 | ('ㄌ', 'l'), 127 | ('ㄍ', 'k⁼'), 128 | ('ㄎ', 'kʰ'), 129 | ('ㄏ', 'x'), 130 | ('ㄐ', 'tʃ⁼'), 131 | ('ㄑ', 'tʃʰ'), 132 | ('ㄒ', 'ʃ'), 133 | ('ㄓ', 'ts`⁼'), 134 | ('ㄔ', 'ts`ʰ'), 135 | ('ㄕ', 's`'), 136 | ('ㄖ', 'ɹ`'), 137 | ('ㄗ', 'ts⁼'), 138 | ('ㄘ', 'tsʰ'), 139 | ('ㄙ', 's'), 140 | ('ㄚ', 'a'), 141 | ('ㄛ', 'o'), 142 | ('ㄜ', 'ə'), 143 | ('ㄝ', 'ɛ'), 144 | ('ㄞ', 'aɪ'), 145 | ('ㄟ', 'eɪ'), 146 | ('ㄠ', 'ɑʊ'), 147 | ('ㄡ', 'oʊ'), 148 | ('ㄧㄢ', 'jɛn'), 149 | ('ㄩㄢ', 'ɥæn'), 150 | ('ㄢ', 'an'), 151 | ('ㄧㄣ', 'in'), 152 | ('ㄩㄣ', 'ɥn'), 153 | ('ㄣ', 'ən'), 154 | ('ㄤ', 'ɑŋ'), 155 | ('ㄧㄥ', 'iŋ'), 156 | ('ㄨㄥ', 'ʊŋ'), 157 | ('ㄩㄥ', 'jʊŋ'), 158 | ('ㄥ', 'əŋ'), 159 | ('ㄦ', 'əɻ'), 160 | ('ㄧ', 'i'), 161 | ('ㄨ', 'u'), 162 | ('ㄩ', 'ɥ'), 163 | ('ˉ', '→'), 164 | ('ˊ', '↑'), 165 | ('ˇ', '↓↑'), 166 | ('ˋ', '↓'), 167 | ('˙', ''), 168 | (',', ','), 169 | ('。', '.'), 170 | ('!', '!'), 171 | ('?', '?'), 172 | ('—', '-') 173 | ]] 174 | 175 | # List of (bopomofo, ipa2) pairs: 176 | _bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 177 | ('ㄅㄛ', 'pwo'), 178 | ('ㄆㄛ', 'pʰwo'), 179 | ('ㄇㄛ', 'mwo'), 180 | ('ㄈㄛ', 'fwo'), 181 | ('ㄅ', 'p'), 182 | ('ㄆ', 'pʰ'), 183 | ('ㄇ', 'm'), 184 | ('ㄈ', 'f'), 185 | ('ㄉ', 't'), 186 | ('ㄊ', 'tʰ'), 187 | ('ㄋ', 'n'), 188 | ('ㄌ', 'l'), 189 | ('ㄍ', 'k'), 190 | ('ㄎ', 'kʰ'), 191 | ('ㄏ', 'h'), 192 | ('ㄐ', 'tɕ'), 193 | ('ㄑ', 'tɕʰ'), 194 | ('ㄒ', 'ɕ'), 195 | ('ㄓ', 'tʂ'), 196 | ('ㄔ', 'tʂʰ'), 197 | ('ㄕ', 'ʂ'), 198 | ('ㄖ', 'ɻ'), 199 | ('ㄗ', 'ts'), 200 | ('ㄘ', 'tsʰ'), 201 | ('ㄙ', 's'), 202 | ('ㄚ', 'a'), 203 | ('ㄛ', 'o'), 204 | ('ㄜ', 'ɤ'), 205 | ('ㄝ', 'ɛ'), 206 | ('ㄞ', 'aɪ'), 207 | ('ㄟ', 'eɪ'), 208 | ('ㄠ', 'ɑʊ'), 209 | ('ㄡ', 'oʊ'), 210 | ('ㄧㄢ', 'jɛn'), 211 | ('ㄩㄢ', 'yæn'), 212 | ('ㄢ', 'an'), 213 | ('ㄧㄣ', 'in'), 214 | ('ㄩㄣ', 'yn'), 215 | ('ㄣ', 'ən'), 216 | ('ㄤ', 'ɑŋ'), 217 | ('ㄧㄥ', 'iŋ'), 218 | ('ㄨㄥ', 'ʊŋ'), 219 | ('ㄩㄥ', 'jʊŋ'), 220 | ('ㄥ', 'ɤŋ'), 221 | 
('ㄦ', 'əɻ'), 222 | ('ㄧ', 'i'), 223 | ('ㄨ', 'u'), 224 | ('ㄩ', 'y'), 225 | ('ˉ', '˥'), 226 | ('ˊ', '˧˥'), 227 | ('ˇ', '˨˩˦'), 228 | ('ˋ', '˥˩'), 229 | ('˙', ''), 230 | (',', ','), 231 | ('。', '.'), 232 | ('!', '!'), 233 | ('?', '?'), 234 | ('—', '-') 235 | ]] 236 | 237 | 238 | def number_to_chinese(text): 239 | numbers = re.findall(r'\d+(?:\.?\d+)?', text) 240 | for number in numbers: 241 | text = text.replace(number, cn2an.an2cn(number), 1) 242 | return text 243 | 244 | 245 | 246 | def latin_to_bopomofo(text): 247 | for regex, replacement in _latin_to_bopomofo: 248 | text = re.sub(regex, replacement, text) 249 | return text 250 | 251 | 252 | def bopomofo_to_romaji(text): 253 | for regex, replacement in _bopomofo_to_romaji: 254 | text = re.sub(regex, replacement, text) 255 | return text 256 | 257 | 258 | def bopomofo_to_ipa(text): 259 | for regex, replacement in _bopomofo_to_ipa: 260 | text = re.sub(regex, replacement, text) 261 | return text 262 | 263 | 264 | def bopomofo_to_ipa2(text): 265 | for regex, replacement in _bopomofo_to_ipa2: 266 | text = re.sub(regex, replacement, text) 267 | return text 268 | 269 | 270 | def chinese_to_romaji(text): 271 | text = number_to_chinese(text) 272 | text = zh_to_bopomofo(text) 273 | # text = chinese_to_bopomofo(text) 274 | text = latin_to_bopomofo(text) 275 | text = bopomofo_to_romaji(text) 276 | text = re.sub('i([aoe])', r'y\1', text) 277 | text = re.sub('u([aoəe])', r'w\1', text) 278 | text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', 279 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') 280 | text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) 281 | return text 282 | 283 | 284 | def chinese_to_lazy_ipa(text): 285 | text = chinese_to_romaji(text) 286 | for regex, replacement in _romaji_to_ipa: 287 | text = re.sub(regex, replacement, text) 288 | return text 289 | 290 | 291 | def chinese_to_ipa(text): 292 | text = number_to_chinese(text) 293 | text = zh_to_bopomofo(text) 294 | text = latin_to_bopomofo(text) 295 | text = bopomofo_to_ipa(text) 
296 | text = re.sub('i([aoe])', r'j\1', text) 297 | text = re.sub('u([aoəe])', r'w\1', text) 298 | text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', 299 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') 300 | text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) 301 | return text 302 | 303 | def pinyin_to_ipa(text): 304 | text = pinyin_to_bopomofo(text) 305 | text = latin_to_bopomofo(text) 306 | text = bopomofo_to_ipa(text) 307 | text = re.sub('i([aoe])', r'j\1', text) 308 | text = re.sub('u([aoəe])', r'w\1', text) 309 | text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)', 310 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`') 311 | text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text) 312 | text = text.replace("%", " %").replace( "$", " $") 313 | return text 314 | 315 | def chinese_to_ipa2(text): 316 | text = number_to_chinese(text) 317 | text = zh_to_bopomofo(text) 318 | text = latin_to_bopomofo(text) 319 | text = bopomofo_to_ipa2(text) 320 | text = re.sub(r'i([aoe])', r'j\1', text) 321 | text = re.sub(r'u([aoəe])', r'w\1', text) 322 | text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text) 323 | text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text) 324 | return text 325 | 326 | 327 | if __name__ == '__main__': 328 | # test_text = "[JA]こんにちは。こんにちは\}{ll[JA]abc你好[ZH]你好[ZH][EN]Hello你好.vits![EN][JA]こんにちは。[JA]" 329 | # text = "借还款,他只是一个纸老虎,开户行,奥大家好33啊我是Ab3s,?萨达撒abst 123、~~、、 但是、、、A B C D!" 330 | text = "奥大家,好33啊,こんにちは我是Ab3s,?萨达撒abst 123、~~、、 但*是、、、A B C D!" 
331 | text = "奥大家,好33啊" 332 | text = "他只是一个纸老虎。" 333 | # text = '[JA]シシ…すご,,いじゃんシシラシャミョンありがとうえーっとシンモーレシンモーレじゃんシシラシャミョン[JA]' 334 | # text="大家好#要筛#,你好#也#要筛#。" 335 | # text = "嗯?什么东西…沉甸甸的…下午1:00,今天是2022/5/10" 336 | # text = "A I人工智能" 337 | # text = '[P]pin1 yin1 zhen1 hao3 wan2[P]扎堆儿-#' 338 | # text = "[JA]それより仮屋さん,例の件ですが――[JA]" 339 | # text = "早上好,今天是2020/10/29,最低温度是-3°C。" 340 | # # text = "…………" 341 | # print(text_to_sequence(text)) 342 | # 343 | # print(time.time()-t) 344 | print(chinese_to_lazy_ipa(text)) -------------------------------------------------------------------------------- /text/frontend/arpabet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from text.frontend.phonectic import Phonetics 15 | """ 16 | A phonology system with ARPABET symbols and limited punctuations. The G2P 17 | conversion is done by g2p_en. 18 | 19 | Note that g2p_en does not handle words with hypen well. So make sure the input 20 | sentence is first normalized. 21 | """ 22 | from text.frontend.vocab import Vocab 23 | from g2p_en import G2p 24 | 25 | 26 | class ARPABET(Phonetics): 27 | """A phonology for English that uses ARPABET as the phoneme vocabulary. 28 | See http://www.speech.cs.cmu.edu/cgi-bin/cmudict for more details. 
29 | Phoneme Example Translation 30 | ------- ------- ----------- 31 | AA odd AA D 32 | AE at AE T 33 | AH hut HH AH T 34 | AO ought AO T 35 | AW cow K AW 36 | AY hide HH AY D 37 | B be B IY 38 | CH cheese CH IY Z 39 | D dee D IY 40 | DH thee DH IY 41 | EH Ed EH D 42 | ER hurt HH ER T 43 | EY ate EY T 44 | F fee F IY 45 | G green G R IY N 46 | HH he HH IY 47 | IH it IH T 48 | IY eat IY T 49 | JH gee JH IY 50 | K key K IY 51 | L lee L IY 52 | M me M IY 53 | N knee N IY 54 | NG ping P IH NG 55 | OW oat OW T 56 | OY toy T OY 57 | P pee P IY 58 | R read R IY D 59 | S sea S IY 60 | SH she SH IY 61 | T tea T IY 62 | TH theta TH EY T AH 63 | UH hood HH UH D 64 | UW two T UW 65 | V vee V IY 66 | W we W IY 67 | Y yield Y IY L D 68 | Z zee Z IY 69 | ZH seizure S IY ZH ER 70 | """ 71 | phonemes = [ 72 | 'AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH', 'ER', 73 | 'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 74 | 'OY', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UW', 'UH', 'V', 'W', 'Y', 'Z', 75 | 'ZH' 76 | ] 77 | punctuations = [',', '.', '?', '!'] 78 | symbols = phonemes + punctuations 79 | _stress_to_no_stress_ = { 80 | 'AA0': 'AA', 81 | 'AA1': 'AA', 82 | 'AA2': 'AA', 83 | 'AE0': 'AE', 84 | 'AE1': 'AE', 85 | 'AE2': 'AE', 86 | 'AH0': 'AH', 87 | 'AH1': 'AH', 88 | 'AH2': 'AH', 89 | 'AO0': 'AO', 90 | 'AO1': 'AO', 91 | 'AO2': 'AO', 92 | 'AW0': 'AW', 93 | 'AW1': 'AW', 94 | 'AW2': 'AW', 95 | 'AY0': 'AY', 96 | 'AY1': 'AY', 97 | 'AY2': 'AY', 98 | 'EH0': 'EH', 99 | 'EH1': 'EH', 100 | 'EH2': 'EH', 101 | 'ER0': 'ER', 102 | 'ER1': 'ER', 103 | 'ER2': 'ER', 104 | 'EY0': 'EY', 105 | 'EY1': 'EY', 106 | 'EY2': 'EY', 107 | 'IH0': 'IH', 108 | 'IH1': 'IH', 109 | 'IH2': 'IH', 110 | 'IY0': 'IY', 111 | 'IY1': 'IY', 112 | 'IY2': 'IY', 113 | 'OW0': 'OW', 114 | 'OW1': 'OW', 115 | 'OW2': 'OW', 116 | 'OY0': 'OY', 117 | 'OY1': 'OY', 118 | 'OY2': 'OY', 119 | 'UH0': 'UH', 120 | 'UH1': 'UH', 121 | 'UH2': 'UH', 122 | 'UW0': 'UW', 123 | 'UW1': 'UW', 124 | 'UW2': 'UW' 125 | } 126 
| 127 | def __init__(self): 128 | self.backend = G2p() 129 | self.vocab = Vocab(self.phonemes + self.punctuations) 130 | 131 | def _remove_vowels(self, phone): 132 | return self._stress_to_no_stress_.get(phone, phone) 133 | 134 | def phoneticize(self, sentence, add_start_end=False): 135 | """ Normalize the input text sequence and convert it into pronunciation sequence. 136 | Args: 137 | sentence (str): The input text sequence. 138 | 139 | Returns: 140 | List[str]: The list of pronunciation sequence. 141 | """ 142 | phonemes = [ 143 | self._remove_vowels(item) for item in self.backend(sentence) 144 | ] 145 | if add_start_end: 146 | start = self.vocab.start_symbol 147 | end = self.vocab.end_symbol 148 | phonemes = [start] + phonemes + [end] 149 | phonemes = [item for item in phonemes if item in self.vocab.stoi] 150 | return phonemes 151 | 152 | def numericalize(self, phonemes): 153 | """ Convert pronunciation sequence into pronunciation id sequence. 154 | 155 | Args: 156 | phonemes (List[str]): The list of pronunciation sequence. 157 | 158 | Returns: 159 | List[int]: The list of pronunciation id sequence. 160 | """ 161 | ids = [self.vocab.lookup(item) for item in phonemes] 162 | return ids 163 | 164 | def reverse(self, ids): 165 | """ Reverse the list of pronunciation id sequence to a list of pronunciation sequence. 166 | 167 | Args: 168 | ids( List[int]): The list of pronunciation id sequence. 169 | 170 | Returns: 171 | List[str]: 172 | The list of pronunciation sequence. 173 | """ 174 | return [self.vocab.reverse(i) for i in ids] 175 | 176 | def __call__(self, sentence, add_start_end=False): 177 | """ Convert the input text sequence into pronunciation id sequence. 178 | 179 | Args: 180 | sentence (str): The input text sequence. 181 | 182 | Returns: 183 | List[str]: The list of pronunciation id sequence. 
184 | """ 185 | return self.numericalize( 186 | self.phoneticize(sentence, add_start_end=add_start_end)) 187 | 188 | @property 189 | def vocab_size(self): 190 | """ Vocab size. 191 | """ 192 | # 47 = 39 phones + 4 punctuations + 4 special tokens 193 | return len(self.vocab) 194 | 195 | 196 | class ARPABETWithStress(Phonetics): 197 | phonemes = [ 198 | 'AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2', 'AO0', 199 | 'AO1', 'AO2', 'AW0', 'AW1', 'AW2', 'AY0', 'AY1', 'AY2', 'B', 'CH', 'D', 200 | 'DH', 'EH0', 'EH1', 'EH2', 'ER0', 'ER1', 'ER2', 'EY0', 'EY1', 'EY2', 201 | 'F', 'G', 'HH', 'IH0', 'IH1', 'IH2', 'IY0', 'IY1', 'IY2', 'JH', 'K', 202 | 'L', 'M', 'N', 'NG', 'OW0', 'OW1', 'OW2', 'OY0', 'OY1', 'OY2', 'P', 'R', 203 | 'S', 'SH', 'T', 'TH', 'UH0', 'UH1', 'UH2', 'UW0', 'UW1', 'UW2', 'V', 204 | 'W', 'Y', 'Z', 'ZH' 205 | ] 206 | punctuations = [',', '.', '?', '!'] 207 | symbols = phonemes + punctuations 208 | 209 | def __init__(self): 210 | self.backend = G2p() 211 | self.vocab = Vocab(self.phonemes + self.punctuations) 212 | 213 | def phoneticize(self, sentence, add_start_end=False): 214 | """ Normalize the input text sequence and convert it into pronunciation sequence. 215 | 216 | Args: 217 | sentence (str): The input text sequence. 218 | 219 | Returns: 220 | List[str]: The list of pronunciation sequence. 221 | """ 222 | phonemes = self.backend(sentence) 223 | if add_start_end: 224 | start = self.vocab.start_symbol 225 | end = self.vocab.end_symbol 226 | phonemes = [start] + phonemes + [end] 227 | phonemes = [item for item in phonemes if item in self.vocab.stoi] 228 | return phonemes 229 | 230 | def numericalize(self, phonemes): 231 | """ Convert pronunciation sequence into pronunciation id sequence. 232 | 233 | Args: 234 | phonemes (List[str]): The list of pronunciation sequence. 235 | 236 | Returns: 237 | List[int]: The list of pronunciation id sequence. 
238 | """ 239 | ids = [self.vocab.lookup(item) for item in phonemes] 240 | return ids 241 | 242 | def reverse(self, ids): 243 | """ Reverse the list of pronunciation id sequence to a list of pronunciation sequence. 244 | Args: 245 | ids (List[int]): The list of pronunciation id sequence. 246 | 247 | Returns: 248 | List[str]: The list of pronunciation sequence. 249 | """ 250 | return [self.vocab.reverse(i) for i in ids] 251 | 252 | def __call__(self, sentence, add_start_end=False): 253 | """ Convert the input text sequence into pronunciation id sequence. 254 | Args: 255 | sentence (str): The input text sequence. 256 | 257 | Returns: 258 | List[str]: The list of pronunciation id sequence. 259 | """ 260 | return self.numericalize( 261 | self.phoneticize(sentence, add_start_end=add_start_end)) 262 | 263 | @property 264 | def vocab_size(self): 265 | """ Vocab size. 266 | """ 267 | # 77 = 69 phones + 4 punctuations + 4 special tokens 268 | return len(self.vocab) 269 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/jaywalnut310/vits 2 | import os 3 | import sys 4 | import logging 5 | import subprocess 6 | import torch 7 | import numpy as np 8 | from omegaconf import OmegaConf 9 | from scipy.io.wavfile import read 10 | 11 | MATPLOTLIB_FLAG = False 12 | 13 | logging.basicConfig( 14 | stream=sys.stdout, 15 | level=logging.INFO, 16 | format='[%(levelname)s|%(filename)s:%(lineno)s][%(asctime)s] >>> %(message)s' 17 | ) 18 | logger = logging 19 | 20 | 21 | def load_checkpoint(checkpoint_path, rank=0, model_g=None, model_d=None, optim_g=None, optim_d=None): 22 | assert os.path.isfile(checkpoint_path) 23 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 24 | iteration = checkpoint_dict['iteration'] 25 | learning_rate = checkpoint_dict['learning_rate'] 26 | config = checkpoint_dict['config'] 27 | 28 | if 
model_g is not None: 29 | model_g, optim_g = load_model( 30 | model_g, 31 | checkpoint_dict['model_g'], 32 | optim_g, 33 | checkpoint_dict['optimizer_g']) 34 | 35 | if model_d is not None: 36 | model_d, optim_d = load_model( 37 | model_d, 38 | checkpoint_dict['model_d'], 39 | optim_d, 40 | checkpoint_dict['optimizer_d']) 41 | if rank == 0: 42 | logger.info( 43 | "Loaded checkpoint '{}' (iteration {})".format( 44 | checkpoint_path, 45 | iteration 46 | ) 47 | ) 48 | return model_g, model_d, optim_g, optim_d, learning_rate, iteration, config 49 | 50 | def load_checkpoint_diffsize(checkpoint_path, rank=0, model_g=None, model_d=None): 51 | assert os.path.isfile(checkpoint_path) 52 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 53 | iteration = checkpoint_dict['iteration'] 54 | learning_rate = checkpoint_dict['learning_rate'] 55 | config = checkpoint_dict['config'] 56 | 57 | if model_g is not None: 58 | model_g = load_model_diffsize( 59 | model_g, 60 | checkpoint_dict['model_g']) 61 | if model_d is not None: 62 | model_d = load_model_diffsize( 63 | model_d, 64 | checkpoint_dict['model_d']) 65 | if rank == 0: 66 | logger.info( 67 | "Loaded checkpoint '{}' (iteration {})".format( 68 | checkpoint_path, 69 | iteration 70 | ) 71 | ) 72 | del checkpoint_dict 73 | return model_g, model_d, learning_rate, iteration, config 74 | 75 | def load_model_diffsize(model, model_state_dict): 76 | if hasattr(model, 'module'): 77 | state_dict = model.module.state_dict() 78 | else: 79 | state_dict = model.state_dict() 80 | 81 | for k, v in model_state_dict.items(): 82 | if k in state_dict and state_dict[k].size() == v.size(): 83 | state_dict[k] = v 84 | 85 | if hasattr(model, 'module'): 86 | model.module.load_state_dict(state_dict, strict=False) 87 | else: 88 | model.load_state_dict(state_dict, strict=False) 89 | 90 | return model 91 | 92 | 93 | 94 | def load_model(model, model_state_dict, optim, optim_state_dict): 95 | if optim is not None: 96 | 
optim.load_state_dict(optim_state_dict) 97 | 98 | if hasattr(model, 'module'): 99 | state_dict = model.module.state_dict() 100 | else: 101 | state_dict = model.state_dict() 102 | 103 | for k, v in model_state_dict.items(): 104 | if k in state_dict and state_dict[k].size() == v.size(): 105 | state_dict[k] = v 106 | 107 | if hasattr(model, 'module'): 108 | model.module.load_state_dict(state_dict) 109 | else: 110 | model.load_state_dict(state_dict) 111 | 112 | return model, optim 113 | 114 | 115 | def save_checkpoint(net_g, optim_g, net_d, optim_d, hps, epoch, learning_rate, save_path): 116 | 117 | def get_state_dict(model): 118 | if hasattr(model, 'module'): 119 | state_dict = model.module.state_dict() 120 | else: 121 | state_dict = model.state_dict() 122 | return state_dict 123 | 124 | torch.save({'model_g': get_state_dict(net_g), 125 | 'model_d': get_state_dict(net_d), 126 | 'optimizer_g': optim_g.state_dict(), 127 | 'optimizer_d': optim_d.state_dict(), 128 | 'config': str(hps), 129 | 'iteration': epoch, 130 | 'learning_rate': learning_rate}, save_path) 131 | 132 | 133 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): 134 | for k, v in scalars.items(): 135 | writer.add_scalar(k, v, global_step) 136 | for k, v in histograms.items(): 137 | writer.add_histogram(k, v, global_step) 138 | for k, v in images.items(): 139 | writer.add_image(k, v, global_step, dataformats='HWC') 140 | for k, v in audios.items(): 141 | writer.add_audio(k, v, global_step, audio_sampling_rate) 142 | 143 | 144 | def plot_spectrogram_to_numpy(spectrogram): 145 | global MATPLOTLIB_FLAG 146 | if not MATPLOTLIB_FLAG: 147 | import matplotlib 148 | matplotlib.use("Agg") 149 | MATPLOTLIB_FLAG = True 150 | mpl_logger = logging.getLogger('matplotlib') 151 | mpl_logger.setLevel(logging.WARNING) 152 | import matplotlib.pylab as plt 153 | import numpy as np 154 | 155 | fig, ax = plt.subplots(figsize=(10, 2)) 156 | im = ax.imshow(spectrogram, 
aspect="auto", origin="lower", 157 | interpolation='none') 158 | plt.colorbar(im, ax=ax) 159 | plt.xlabel("Frames") 160 | plt.ylabel("Channels") 161 | plt.tight_layout() 162 | 163 | fig.canvas.draw() 164 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 165 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 166 | plt.close() 167 | return data 168 | 169 | 170 | def plot_alignment_to_numpy(alignment, info=None): 171 | global MATPLOTLIB_FLAG 172 | if not MATPLOTLIB_FLAG: 173 | import matplotlib 174 | matplotlib.use("Agg") 175 | MATPLOTLIB_FLAG = True 176 | mpl_logger = logging.getLogger('matplotlib') 177 | mpl_logger.setLevel(logging.WARNING) 178 | import matplotlib.pylab as plt 179 | import numpy as np 180 | 181 | fig, ax = plt.subplots(figsize=(6, 4)) 182 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', 183 | interpolation='none') 184 | fig.colorbar(im, ax=ax) 185 | xlabel = 'Decoder timestep' 186 | if info is not None: 187 | xlabel += '\n\n' + info 188 | plt.xlabel(xlabel) 189 | plt.ylabel('Encoder timestep') 190 | plt.tight_layout() 191 | 192 | fig.canvas.draw() 193 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 194 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 195 | plt.close() 196 | return data 197 | 198 | 199 | def load_wav_to_torch(full_path): 200 | sampling_rate, wav = read(full_path) 201 | 202 | if len(wav.shape) == 2: 203 | wav = wav[:, 0] 204 | 205 | if wav.dtype == np.int16: 206 | wav = wav / 32768.0 207 | elif wav.dtype == np.int32: 208 | wav = wav / 2147483648.0 209 | elif wav.dtype == np.uint8: 210 | wav = (wav - 128) / 128.0 211 | wav = wav.astype(np.float32) 212 | return torch.FloatTensor(wav), sampling_rate 213 | 214 | 215 | def load_filepaths_and_text(filename, split="|"): 216 | with open(filename, encoding='utf-8') as f: 217 | filepaths_and_text = [line.strip().split(split) for line in f] 218 | return filepaths_and_text 219 | 220 | 221 | def 
def get_hparams(args, init=True):
    """Load the training configuration and prepare the model directory.

    Args:
        args: Object providing ``config`` (path of the YAML config) and
            ``model`` (run name) attributes.
        init: When true, a copy of the config is written to
            ``<model_dir>/config.yaml``.

    Returns:
        An ``HParams`` built from the config, with ``model_name`` and
        ``model_dir`` added.
    """
    config = OmegaConf.load(args.config)
    hparams = HParams(**config)
    model_dir = os.path.join(hparams.train.log_path, args.model)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(model_dir, exist_ok=True)
    hparams.model_name = args.model
    hparams.model_dir = model_dir

    if init:
        OmegaConf.save(config, os.path.join(model_dir, "config.yaml"))

    return hparams


def get_hparams_from_file(config_path):
    """Load the config at ``config_path`` and wrap it in ``HParams``."""
    config = OmegaConf.load(config_path)
    return HParams(**config)


def check_git_hash(model_dir):
    """Warn when the current git commit differs from the one stored for a run.

    The commit of the first call is written to ``<model_dir>/githash``;
    later calls compare against it.  Nothing is done when the directory
    containing this source file has no ``.git`` entry.

    Args:
        model_dir: Directory that holds (or will hold) the ``githash`` file.
    """
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning("{} is not a git repository, therefore hash value comparison will be ignored.".format(
            source_dir
        ))
        return

    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        # Context managers close the handle deterministically instead of
        # relying on garbage collection.
        with open(path) as f:
            saved_hash = f.read()
        if saved_hash != cur_hash:
            logger.warning("git hash values are different. {}(saved) != {}(current)".format(
                saved_hash[:8], cur_hash[:8]))
    else:
        with open(path, "w") as f:
            f.write(cur_hash)


def get_logger(model_dir, filename="train.log"):
    """Return a DEBUG-level logger that writes to ``<model_dir>/<filename>``.

    The logger is named after the last path component of ``model_dir`` and
    is also stored in the module-level ``logger`` global.  Calling the
    function again for the same file reuses the existing handler instead of
    attaching another one (which would duplicate every log line).

    Args:
        model_dir: Directory for the log file; created when missing.
        filename: Name of the log file inside ``model_dir``.

    Returns:
        The configured ``logging.Logger``.
    """
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    os.makedirs(model_dir, exist_ok=True)
    log_path = os.path.abspath(os.path.join(model_dir, filename))
    already_attached = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not already_attached:
        formatter = logging.Formatter(
            "%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
        h = logging.FileHandler(log_path)
        h.setLevel(logging.DEBUG)
        h.setFormatter(formatter)
        logger.addHandler(h)
    return logger


class HParams():
    """Attribute-style container for hyper-parameters.

    Keyword arguments become attributes; values that are ``dict`` instances
    (including subclasses) are converted recursively.  The object also
    supports the mapping-like access used around the code base
    (``h["key"]``, ``"key" in h``, ``len(h)``, ``keys/items/values``).
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
class Frontend():
    """Mandarin text front end: normalisation, grapheme-to-phoneme, erhua.

    The pipeline is ``get_phonemes`` -> text normalisation -> ``_g2p``
    (jieba segmentation, initials/finals lookup, tone sandhi, optional
    erhua merging).  Two g2p back ends are supported: ``"pypinyin"`` and
    ``"g2pM"``.
    """

    def __init__(self,
                 g2p_model="pypinyin",
                 phone_vocab_path=None,
                 tone_vocab_path=None):
        """Build the front end.

        Args:
            g2p_model: ``"g2pM"`` selects the g2pM model plus the generated
                lexicon; any other value initialises pypinyin.
            phone_vocab_path: Optional path to a ``<phone> <id>`` file.
            tone_vocab_path: Optional path to a ``<tone> <id>`` file.
        """
        self.tone_modifier = ToneSandhi()
        self.text_normalizer = TextNormalizer()

        self.punc = ['!', '?', '…', ",", ".", "#", '-', "%", "$"]
        # g2p_model can be pypinyin and g2pM
        self.g2p_model = g2p_model
        self.add_word_sep = True
        if self.g2p_model == "g2pM":
            self.g2pM_model = G2pM()
            self.pinyin2phone = generate_lexicon(
                with_tone=True, with_erhua=False)
        else:
            self.__init__pypinyin()
        self.must_erhua = {"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿"}
        self.not_erhua = {
            "虐儿", "为儿", "护儿", "瞒儿", "救儿", "替儿", "有儿", "一儿", "我儿", "俺儿", "妻儿",
            "拐儿", "聋儿", "乞儿", "患儿", "幼儿", "孤儿", "婴儿", "婴幼儿", "连体儿", "脑瘫儿",
            "流浪儿", "体弱儿", "混血儿", "蜜雪儿", "舫儿", "祖儿", "美儿", "应采儿", "可儿", "侄儿",
            "孙儿", "侄孙儿", "女儿", "男儿", "红孩儿", "花儿", "虫儿", "马儿", "鸟儿", "猪儿", "猫儿",
            "狗儿"
        }
        self.vocab_phones = {}
        self.vocab_tones = {}
        if phone_vocab_path:
            with open(phone_vocab_path, 'rt', encoding='utf-8') as f:
                for phn, phn_idx in (line.strip().split() for line in f):
                    self.vocab_phones[phn] = int(phn_idx)
        if tone_vocab_path:
            with open(tone_vocab_path, 'rt', encoding='utf-8') as f:
                for tone, tone_idx in (line.strip().split() for line in f):
                    self.vocab_tones[tone] = int(tone_idx)
        print("initialized zh frontend")

    def __init__pypinyin(self):
        """Load the large pinyin phrase dictionary and local overrides."""
        large_pinyin.load()
        # Project-specific phrase readings can be registered here with
        # load_phrases_dict({phrase: [[pinyin], ...]}).

        # Adjust the order of the candidate readings for this character.
        load_single_dict({ord(u'地'): u'de,di4'})

    def _get_initials_finals(self, word: str) -> List[List[str]]:
        """Return ``(initials, finals)`` for ``word``, one entry per syllable.

        With pypinyin the finals carry a trailing tone digit and the ``i``
        after z/c/s and zh/ch/sh/r is rewritten to ``ii`` / ``iii``.  With
        g2pM the pinyin is looked up in ``self.pinyin2phone``; tokens that
        are not in the lexicon (e.g. punctuation) are copied to both lists.
        """
        initials = []
        finals = []
        if self.g2p_model == "pypinyin":
            orig_initials = lazy_pinyin(
                word, neutral_tone_with_five=True, style=Style.INITIALS)
            orig_finals = lazy_pinyin(
                word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
            for c, v in zip(orig_initials, orig_finals):
                if re.match(r'i\d', v):
                    if c in ['z', 'c', 's']:
                        v = re.sub('i', 'ii', v)
                    elif c in ['zh', 'ch', 'sh', 'r']:
                        v = re.sub('i', 'iii', v)
                initials.append(c)
                finals.append(v)
        elif self.g2p_model == "g2pM":
            pinyins = self.g2pM_model(word, tone=True, char_split=False)
            for pinyin in pinyins:
                pinyin = pinyin.replace("u:", "v")
                if pinyin in self.pinyin2phone:
                    initial_final_list = self.pinyin2phone[pinyin].split(" ")
                    if len(initial_final_list) == 2:
                        initials.append(initial_final_list[0])
                        finals.append(initial_final_list[1])
                    elif len(initial_final_list) == 1:
                        # Syllable without an initial: the only element is
                        # the final.  (Index 1 used to raise IndexError.)
                        initials.append('')
                        finals.append(initial_final_list[0])
                else:
                    # If it's not pinyin (possibly punctuation) or no conversion is required
                    initials.append(pinyin)
                    finals.append(pinyin)
        return initials, finals

    # if merge_sentences, merge all sentences into one phone sequence
    def _g2p(self,
             sentences: List[str],
             merge_sentences: bool=True,
             with_erhua: bool=True) -> List[List[str]]:
        """Convert normalised sentences to phone sequences.

        Args:
            sentences: Normalised sentences.
            merge_sentences: When true the per-sentence phone lists are
                concatenated into a single list.
            with_erhua: When true erhua finals are merged via
                ``_merge_erhua``.

        Returns:
            A list of phone lists (a single list when merging).
        """
        phones_list = []
        for seg in sentences:
            phones = []
            # Replace all English words in the sentence
            seg = re.sub('[a-zA-Z]+', '', seg)
            seg_cut = psg.lcut(seg)
            initials = []
            finals = []
            seg_cut = self.tone_modifier.pre_merge_for_modify(seg_cut)
            for word, pos in seg_cut:
                if self.add_word_sep and word == "#":
                    continue
                if pos == 'eng':
                    continue
                sub_initials, sub_finals = self._get_initials_finals(word)
                sub_finals = self.tone_modifier.modified_tone(word, pos,
                                                              sub_finals)
                if with_erhua:
                    sub_initials, sub_finals = self._merge_erhua(
                        sub_initials, sub_finals, word, pos)
                initials.append(sub_initials)
                finals.append(sub_finals)
                if self.add_word_sep and word not in self.punc:
                    # "#" marks a word boundary.
                    initials.append(["#"])
                    finals.append(["#"])

            # Flatten in linear time (sum(list_of_lists, []) is quadratic).
            initials = [p for sub in initials for p in sub]
            finals = [p for sub in finals for p in sub]

            for c, v in zip(initials, finals):
                # NOTE: post process for pypinyin outputs
                # we discriminate i, ii and iii
                if c:
                    phones.append(c)
                if v and v not in self.punc:
                    phones.append(v)

            phones_list.append(phones)
        if merge_sentences:
            merge_list = [p for phones in phones_list for p in phones]
            # rm the last 'sp' to avoid the noise at the end
            # cause in the training data, no 'sp' in the end
            # Guard against an empty result (e.g. input with no Chinese
            # text), which used to raise IndexError here.
            if merge_list and merge_list[-1] == 'sp':
                merge_list = merge_list[:-1]
            phones_list = [merge_list]
        return phones_list

    def _merge_erhua(self,
                     initials: List[str],
                     finals: List[str],
                     word: str,
                     pos: str) -> List[List[str]]:
        """Fold a trailing "儿" syllable into the preceding final.

        The inputs are returned unchanged when the word is excluded from
        erhua (``not_erhua`` or POS in a/j/nr, unless in ``must_erhua``) or
        when the number of finals differs from the number of characters.

        Returns:
            ``(initials, finals)`` after merging.
        """
        if word not in self.must_erhua and (word in self.not_erhua or
                                            pos in {"a", "j", "nr"}):
            return initials, finals
        # Cases such as "……" yield a different number of finals than
        # characters; return them untouched.
        if len(finals) != len(word):
            return initials, finals

        new_initials = []
        new_finals = []
        last = len(finals) - 1
        for i, phn in enumerate(finals):
            if i == last and word[i] == "儿" and phn in {
                    "er2", "er5"
            } and word[-2:] not in self.not_erhua and new_finals:
                # Insert "r" before the tone digit of the previous final.
                new_finals[-1] = new_finals[-1][:-1] + "r" + new_finals[-1][-1]
            else:
                new_finals.append(phn)
                new_initials.append(initials[i])
        return new_initials, new_finals

    def _p2id(self, phonemes: List[str]) -> np.ndarray:
        """Map phones to ids; phones missing from the vocab become ``sp``."""
        # replace unk phone with sp
        phonemes = [
            phn if phn in self.vocab_phones else "sp" for phn in phonemes
        ]
        phone_ids = [self.vocab_phones[item] for item in phonemes]
        return np.array(phone_ids, np.int64)

    def _t2id(self, tones: List[str]) -> np.ndarray:
        """Map tones to ids; tones missing from the vocab become ``"0"``."""
        tones = [tone if tone in self.vocab_tones else "0" for tone in tones]
        tone_ids = [self.vocab_tones[item] for item in tones]
        return np.array(tone_ids, np.int64)

    def _get_phone_tone(self, phonemes: List[str],
                        get_tone_ids: bool=False) -> List[List[str]]:
        """Split merged erhua phones that are missing from the phone vocab.

        With ``get_tone_ids`` (and a loaded tone vocab) the trailing tone
        digit is separated from every phone and returned in ``tones``;
        otherwise ``tones`` stays empty and phones keep their tone digit.

        Returns:
            ``(phones, tones)``.
        """
        phones = []
        tones = []
        if get_tone_ids and self.vocab_tones:
            for full_phone in phonemes:
                # split tone from finals
                match = re.match(r'^(\w+)([012345])$', full_phone)
                if match:
                    phone = match.group(1)
                    tone = match.group(2)
                    # if the merged erhua not in the vocab
                    # assume that the input is ['iaor3'] and 'iaor' not in self.vocab_phones, we split 'iaor' into ['iao','er']
                    # and the tones accordingly change from ['3'] to ['3','2'], while '2' is the tone of 'er2'
                    if len(phone) >= 2 and phone != "er" and phone[
                            -1] == 'r' and phone not in self.vocab_phones and phone[:
                                                                                    -1] in self.vocab_phones:
                        phones.append(phone[:-1])
                        phones.append("er")
                        tones.append(tone)
                        tones.append("2")
                    else:
                        phones.append(phone)
                        tones.append(tone)
                else:
                    phones.append(full_phone)
                    tones.append('0')
        else:
            for phone in phonemes:
                # if the merged erhua not in the vocab
                # assume that the input is ['iaor3'] and 'iaor' not in self.vocab_phones, change ['iaor3'] to ['iao3','er2']
                if len(phone) >= 3 and phone[:-1] != "er" and phone[
                        -2] == 'r' and phone not in self.vocab_phones and (
                            phone[:-2] + phone[-1]) in self.vocab_phones:
                    phones.append((phone[:-2] + phone[-1]))
                    phones.append("er2")
                else:
                    phones.append(phone)
        return phones, tones

    def get_phonemes(self,
                     sentence: str,
                     merge_sentences: bool=True,
                     with_erhua: bool=False,
                     robot: bool=False,
                     print_info: bool=False) -> List[List[str]]:
        """Normalise ``sentence`` and convert it to phones.

        Args:
            sentence: Raw input text.
            merge_sentences: Forwarded to ``_g2p``.
            with_erhua: Forwarded to ``_g2p``.
            robot: When true every tone digit 1-5 is replaced by ``1``
                (except for ``er2``).
            print_info: When true the intermediate results are printed.

        Returns:
            The list of phone lists produced by ``_g2p``.
        """
        sentence = sentence.replace("嗯", "恩")
        sentences = self.text_normalizer.normalize(sentence)
        phonemes = self._g2p(
            sentences, merge_sentences=merge_sentences, with_erhua=with_erhua)
        # change all tones to `1`
        if robot:
            new_phonemes = []
            for phone_seq in phonemes:
                new_sentence = []
                for item in phone_seq:
                    # `er` only have tone `2`
                    if item[-1] in "12345" and item != "er2":
                        item = item[:-1] + "1"
                    new_sentence.append(item)
                new_phonemes.append(new_sentence)
            phonemes = new_phonemes
        if print_info:
            print("----------------------------")
            print("text norm results:")
            print(sentences)
            print("----------------------------")
            print("g2p results:")
            print(phonemes)
            print("----------------------------")
        return phonemes
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
    """
        1) loads audio, speaker_id, text pairs
        2) normalizes text and converts them to sequences of integers
        3) computes spectrograms from audio files.

    Spectrograms and yingrams are cached next to each wav file as
    ``*.spec.pt`` / ``*.ying.pt``.
    """

    def __init__(self, audiopaths_sid_text, hparams, pt_run=False):
        """Read the file list, shuffle it deterministically and filter it.

        Args:
            audiopaths_sid_text: Path of a ``path|speaker|text|lang`` list.
            hparams: Data hyper-parameters (sampling rate, STFT sizes,
                speakers, data path, pitch settings, ...).
            pt_run: When true every item is processed once up front and the
                cached ``.pt`` files are regenerated.
        """
        self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text)
        self.text_cleaners = hparams.text_cleaners
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length

        self.lang = hparams.languages

        self.add_blank = hparams.add_blank
        self.min_text_len = 1
        self.max_text_len = 190

        self.speaker_dict = {
            speaker: idx
            for idx, speaker in enumerate(hparams.speakers)
        }
        self.data_path = hparams.data_path

        self.pitch = Pitch(sr=hparams.sampling_rate,
                           W=hparams.tau_max,
                           tau_max=hparams.tau_max,
                           midi_start=hparams.midi_start,
                           midi_end=hparams.midi_end,
                           octave_range=hparams.octave_range)

        random.seed(1234)
        random.shuffle(self.audiopaths_sid_text)
        self._filter()
        if pt_run:
            for _audiopaths_sid_text in self.audiopaths_sid_text:
                _ = self.get_audio_text_speaker_pair(_audiopaths_sid_text,
                                                     True)

    def _filter(self):
        """
        Filter text & store spec lengths
        """
        # Store spectrogram lengths for Bucketing
        # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
        # spec_length = wav_length // hop_length

        audiopaths_sid_text_new = []
        lengths = []
        for audiopath, spk, text, lang in self.audiopaths_sid_text:
            if not (self.min_text_len <= len(text) <= self.max_text_len):
                continue
            audiopath = os.path.join(self.data_path, audiopath)
            if not os.path.exists(audiopath):
                print(audiopath, "not exist!")
                continue
            try:
                # Only checks that the file can be decoded.
                load_wav_to_torch(audiopath)
            except Exception:
                # A bare ``except:`` would also swallow KeyboardInterrupt
                # and SystemExit.
                print(audiopath, "load error!")
                continue
            audiopaths_sid_text_new.append([audiopath, spk, text, lang])
            lengths.append(
                os.path.getsize(audiopath) // (2 * self.hop_length))
        self.audiopaths_sid_text = audiopaths_sid_text_new
        self.lengths = lengths

    def get_audio_text_speaker_pair(self, audiopath_sid_text, pt_run=False):
        """Build one item ``(text, spec, ying, wav, sid, lang)`` from a row."""
        # separate filename, speaker_id and text
        audiopath, spk, text, lang = audiopath_sid_text
        text, lang = self.get_text(text, lang)
        spec, ying, wav = self.get_audio(audiopath, pt_run)
        sid = self.get_sid(self.speaker_dict[spk])
        return (text, spec, ying, wav, sid, lang)

    def get_audio(self, filename, pt_run=False):
        """Load the waveform and its (cached) spectrogram and yingram.

        Args:
            filename: Path of the wav file.
            pt_run: When true the cached tensors are ignored and rewritten.

        Returns:
            ``(spec, ying, audio_norm)``; ``audio_norm`` is the waveform
            with a leading channel dimension.

        Raises:
            ValueError: If the file's sampling rate differs from the
                configured one.
        """
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            # The template has three placeholders; passing only two
            # arguments raised IndexError instead of this ValueError.
            raise ValueError("{} {} SR doesn't match target {} SR".format(
                filename, sampling_rate, self.sampling_rate))
        audio_norm = audio.unsqueeze(0)
        spec_filename = filename.replace(".wav", ".spec.pt")
        ying_filename = filename.replace(".wav", ".ying.pt")
        if os.path.exists(spec_filename) and not pt_run:
            spec = torch.load(spec_filename, map_location='cpu')
        else:
            spec = spectrogram_torch(audio_norm,
                                     self.filter_length,
                                     self.sampling_rate,
                                     self.hop_length,
                                     self.win_length,
                                     center=False)
            spec = torch.squeeze(spec, 0)
            torch.save(spec, spec_filename)
        if os.path.exists(ying_filename) and not pt_run:
            ying = torch.load(ying_filename, map_location='cpu')
        else:
            n_samples = audio_norm.shape[1]
            left_pad = self.filter_length - self.hop_length
            # Right padding: same base amount, plus what is needed to reach
            # a multiple of hop_length, plus one extra hop when the length
            # is already a multiple.
            right_pad = (self.filter_length - self.hop_length +
                         (-n_samples) % self.hop_length +
                         self.hop_length * (n_samples % self.hop_length == 0))
            wav = torch.nn.functional.pad(
                audio_norm.unsqueeze(0),
                (left_pad, right_pad),
                mode='constant').squeeze(0)
            ying = self.pitch.yingram(wav)[0]
            torch.save(ying, ying_filename)
        return spec, ying, audio_norm

    def get_text(self, text, lang):
        """Convert cleaned text and space-separated language ids to tensors."""
        text_norm = cleaned_text_to_sequence(text)
        lang = [int(i) for i in lang.split(" ")]
        if self.add_blank:
            text_norm, lang = commons.intersperse_with_language_id(
                text_norm, lang, 0)
        text_norm = torch.LongTensor(text_norm)
        lang = torch.LongTensor(lang)
        return text_norm, lang

    def get_sid(self, sid):
        """Wrap a speaker id in a one-element ``LongTensor``."""
        sid = torch.LongTensor([int(sid)])
        return sid

    def __getitem__(self, index):
        return self.get_audio_text_speaker_pair(
            self.audiopaths_sid_text[index])

    def __len__(self):
        return len(self.audiopaths_sid_text)
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler
                               ):
    """
    Maintain similar input lengths in a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.

    It removes samples which are not included in the boundaries.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 boundaries,
                 num_replicas=None,
                 rank=None,
                 shuffle=True):
        """Create the sampler.

        Args:
            dataset: Dataset exposing a ``lengths`` sequence.
            batch_size: Number of samples per batch and replica.
            boundaries: Increasing bucket boundaries; the list passed in is
                not modified.
            num_replicas, rank, shuffle: Forwarded to ``DistributedSampler``.
        """
        super().__init__(dataset,
                         num_replicas=num_replicas,
                         rank=rank,
                         shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        # _create_buckets() pops entries for empty buckets; work on a copy
        # so the caller's list is left untouched.
        self.boundaries = list(boundaries)

        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas

    def _create_buckets(self):
        """Assign sample indices to buckets and drop empty buckets.

        Returns:
            ``(buckets, num_samples_per_bucket)`` where the second list
            holds each bucket size rounded up to a multiple of
            ``num_replicas * batch_size``.
        """
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i, length in enumerate(self.lengths):
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)

        # Walk backwards so that popping does not shift pending indices.
        for i in range(len(buckets) - 1, -1, -1):
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i + 1)

        num_samples_per_bucket = []
        total_batch_size = self.num_replicas * self.batch_size
        for bucket in buckets:
            len_bucket = len(bucket)
            rem = (total_batch_size -
                   (len_bucket % total_batch_size)) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket

    def __iter__(self):
        """Yield the batches (lists of dataset indices) for this replica."""
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        if self.shuffle:
            for bucket in self.buckets:
                indices.append(
                    torch.randperm(len(bucket), generator=g).tolist())
        else:
            for bucket in self.buckets:
                indices.append(list(range(len(bucket))))

        batches = []
        for i in range(len(self.buckets)):
            bucket = self.buckets[i]
            len_bucket = len(bucket)
            ids_bucket = indices[i]
            num_samples_bucket = self.num_samples_per_bucket[i]

            # add extra samples to make it evenly divisible
            rem = num_samples_bucket - len_bucket
            ids_bucket = ids_bucket + ids_bucket * \
                (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]

            # subsample
            ids_bucket = ids_bucket[self.rank::self.num_replicas]

            # batching
            for j in range(len(ids_bucket) // self.batch_size):
                batch = [
                    bucket[idx]
                    for idx in ids_bucket[j * self.batch_size:(j + 1) *
                                          self.batch_size]
                ]
                batches.append(batch)

        if self.shuffle:
            batch_ids = torch.randperm(len(batches), generator=g).tolist()
            batches = [batches[i] for i in batch_ids]
        self.batches = batches

        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)

    def _bisect(self, x, lo=0, hi=None):
        """Return the bucket index ``i`` with ``b[i] < x <= b[i+1]`` or -1."""
        if hi is None:
            hi = len(self.boundaries) - 1

        if hi > lo:
            mid = (hi + lo) // 2
            if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
                return mid
            elif x <= self.boundaries[mid]:
                return self._bisect(x, lo, mid)
            else:
                return self._bisect(x, mid + 1, hi)
        else:
            return -1

    def __len__(self):
        return self.num_samples // self.batch_size
class LayerNorm(nn.Module):
    """Layer normalisation over dimension 1 with learnable scale and shift.

    Dimension 1 is swapped with the last dimension, ``F.layer_norm`` is
    applied over ``channels`` entries, and the result is swapped back.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class ConvReluNorm(nn.Module):
    """Stack of Conv1d -> LayerNorm -> ReLU -> Dropout with a residual output.

    The final 1x1 projection is zero-initialised, so a freshly constructed
    module returns its (masked) input unchanged.  The residual addition
    requires ``out_channels`` to be compatible with ``in_channels``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # The message used to say "larger than 0", contradicting the check.
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p_dropout))
        for _ in range(n_layers-1):
            self.conv_layers.append(nn.Conv1d(
                hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        """Apply the stack to ``x``; ``x_mask`` is multiplied in at each layer."""
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
class WN(torch.nn.Module):
    """Stack of weight-normalised dilated convolutions with gated activations.

    Each layer feeds a gated activation (via
    ``commons.fused_add_tanh_sigmoid_multiply``) into a 1x1 convolution
    whose output is split into a residual part, added back to the running
    input, and a skip part, accumulated into the output.  The last layer
    produces the skip part only.  An optional conditioning input ``g`` is
    projected once and sliced per layer.
    """

    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
        super(WN, self).__init__()
        assert(kernel_size % 2 == 1)
        self.hidden_channels = hidden_channels
        # A stray trailing comma used to store this as the tuple
        # ``(kernel_size,)``.
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')

        for i in range(n_layers):
            dilation = dilation_rate ** i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels,
                2*hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(
                hidden_channels, res_skip_channels, 1
            )
            res_skip_layer = torch.nn.utils.weight_norm(
                res_skip_layer, name='weight'
            )
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        """Run the stack.

        Args:
            x: Input with ``hidden_channels`` entries in dimension 1.
            x_mask: Mask multiplied into the residual path and the output.
            g: Optional conditioning input with ``gin_channels`` entries in
                dimension 1.

        Returns:
            The masked sum of all skip outputs (same shape as ``x``).
        """
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset:cond_offset+2*self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = commons.fused_add_tanh_sigmoid_multiply(
                x_in,
                g_l,
                n_channels_tensor
            )
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                res_acts = res_skip_acts[:, :self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels:, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution."""
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
kernel_size, 1, dilation=1, 220 | padding=get_padding(kernel_size, 1))), 221 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 222 | padding=get_padding(kernel_size, 1))) 223 | ]) 224 | self.convs2.apply(init_weights) 225 | 226 | def forward(self, x, x_mask=None): 227 | for c1, c2 in zip(self.convs1, self.convs2): 228 | xt = F.leaky_relu(x, LRELU_SLOPE) 229 | if x_mask is not None: 230 | xt = xt * x_mask 231 | xt = c1(xt) 232 | xt = F.leaky_relu(xt, LRELU_SLOPE) 233 | if x_mask is not None: 234 | xt = xt * x_mask 235 | xt = c2(xt) 236 | x = xt + x 237 | if x_mask is not None: 238 | x = x * x_mask 239 | return x 240 | 241 | def remove_weight_norm(self): 242 | for l in self.convs1: 243 | remove_weight_norm(l) 244 | for l in self.convs2: 245 | remove_weight_norm(l) 246 | 247 | 248 | class ResBlock2(torch.nn.Module): 249 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 250 | super(ResBlock2, self).__init__() 251 | self.convs = nn.ModuleList([ 252 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 253 | padding=get_padding(kernel_size, dilation[0]))), 254 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 255 | padding=get_padding(kernel_size, dilation[1]))) 256 | ]) 257 | self.convs.apply(init_weights) 258 | 259 | def forward(self, x, x_mask=None): 260 | for c in self.convs: 261 | xt = F.leaky_relu(x, LRELU_SLOPE) 262 | if x_mask is not None: 263 | xt = xt * x_mask 264 | xt = c(xt) 265 | x = xt + x 266 | if x_mask is not None: 267 | x = x * x_mask 268 | return x 269 | 270 | def remove_weight_norm(self): 271 | for l in self.convs: 272 | remove_weight_norm(l) 273 | 274 | 275 | class Log(nn.Module): 276 | def forward(self, x, x_mask, reverse=False, **kwargs): 277 | if not reverse: 278 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 279 | logdet = torch.sum(-y, [1, 2]) 280 | return y, logdet 281 | else: 282 | x = torch.exp(x) * x_mask 283 | return x 284 | 285 | 286 | class 
Flip(nn.Module): 287 | def forward(self, x, *args, reverse=False, **kwargs): 288 | x = torch.flip(x, [1]) 289 | if not reverse: 290 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 291 | return x, logdet 292 | else: 293 | return x 294 | 295 | 296 | class ElementwiseAffine(nn.Module): 297 | def __init__(self, channels): 298 | super().__init__() 299 | self.channels = channels 300 | self.m = nn.Parameter(torch.zeros(channels, 1)) 301 | self.logs = nn.Parameter(torch.zeros(channels, 1)) 302 | 303 | def forward(self, x, x_mask, reverse=False, **kwargs): 304 | if not reverse: 305 | y = self.m + torch.exp(self.logs) * x 306 | y = y * x_mask 307 | logdet = torch.sum(self.logs * x_mask, [1, 2]) 308 | return y, logdet 309 | else: 310 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 311 | return x 312 | 313 | 314 | class ResidualCouplingLayer(nn.Module): 315 | def __init__( 316 | self, 317 | channels, 318 | hidden_channels, 319 | kernel_size, 320 | dilation_rate, 321 | n_layers, 322 | p_dropout=0, 323 | gin_channels=0, 324 | mean_only=False 325 | ): 326 | assert channels % 2 == 0, "channels should be divisible by 2" 327 | super().__init__() 328 | self.channels = channels 329 | self.hidden_channels = hidden_channels 330 | self.kernel_size = kernel_size 331 | self.dilation_rate = dilation_rate 332 | self.n_layers = n_layers 333 | self.half_channels = channels // 2 334 | self.mean_only = mean_only 335 | 336 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 337 | self.enc = WN( 338 | hidden_channels, 339 | kernel_size, 340 | dilation_rate, 341 | n_layers, 342 | p_dropout=p_dropout, 343 | gin_channels=gin_channels 344 | ) 345 | self.post = nn.Conv1d( 346 | hidden_channels, self.half_channels * (2 - mean_only), 1 347 | ) 348 | self.post.weight.data.zero_() 349 | self.post.bias.data.zero_() 350 | 351 | def forward(self, x, x_mask, g=None, reverse=False): 352 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 353 | h = self.pre(x0) * x_mask 354 
| h = self.enc(h, x_mask, g=g) 355 | stats = self.post(h) * x_mask 356 | if not self.mean_only: 357 | m, logs = torch.split(stats, [self.half_channels] * 2, 1) 358 | else: 359 | m = stats 360 | logs = torch.zeros_like(m) 361 | 362 | if not reverse: 363 | x1 = m + x1 * torch.exp(logs) * x_mask 364 | x = torch.cat([x0, x1], 1) 365 | logdet = torch.sum(logs, [1, 2]) 366 | return x, logdet 367 | else: 368 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 369 | x = torch.cat([x0, x1], 1) 370 | return x 371 | 372 | 373 | class ConvFlow(nn.Module): 374 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): 375 | super().__init__() 376 | self.in_channels = in_channels 377 | self.filter_channels = filter_channels 378 | self.kernel_size = kernel_size 379 | self.n_layers = n_layers 380 | self.num_bins = num_bins 381 | self.tail_bound = tail_bound 382 | self.half_channels = in_channels // 2 383 | 384 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 385 | self.convs = DDSConv( 386 | filter_channels, kernel_size, n_layers, p_dropout=0. 387 | ) 388 | self.proj = nn.Conv1d( 389 | filter_channels, self.half_channels * (num_bins * 3 - 1), 1 390 | ) 391 | self.proj.weight.data.zero_() 392 | self.proj.bias.data.zero_() 393 | 394 | def forward(self, x, x_mask, g=None, reverse=False): 395 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 396 | h = self.pre(x0) 397 | h = self.convs(h, x_mask, g=g) 398 | h = self.proj(h) * x_mask 399 | 400 | b, c, t = x0.shape 401 | # [b, cx?, t] -> [b, c, t, ?] 
402 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) 403 | 404 | unnormalized_widths = h[..., :self.num_bins] / \ 405 | math.sqrt(self.filter_channels) 406 | unnormalized_heights = h[..., self.num_bins:2 * self.num_bins] / \ 407 | math.sqrt(self.filter_channels) 408 | unnormalized_derivatives = h[..., 2 * self.num_bins:] 409 | 410 | x1, logabsdet = piecewise_rational_quadratic_transform( 411 | x1, 412 | unnormalized_widths, 413 | unnormalized_heights, 414 | unnormalized_derivatives, 415 | inverse=reverse, 416 | tails='linear', 417 | tail_bound=self.tail_bound 418 | ) 419 | 420 | x = torch.cat([x0, x1], 1) * x_mask 421 | logdet = torch.sum(logabsdet * x_mask, [1, 2]) 422 | if not reverse: 423 | return x, logdet 424 | else: 425 | return x 426 | --------------------------------------------------------------------------------