├── LICENSE ├── PL_BERT_ja ├── Configs │ └── config.yml ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── __pycache__ │ ├── convert_label.cpython-38.pyc │ ├── model.cpython-38.pyc │ ├── phonemize.cpython-38.pyc │ └── text_utils.cpython-38.pyc ├── convert_label.py ├── dataloader.py ├── model.py ├── phonemize.py ├── preprocess.py ├── simple_loader.py ├── text │ ├── __pycache__ │ │ ├── cmudict.cpython-38.pyc │ │ ├── pinyin.cpython-38.pyc │ │ └── symbols.cpython-38.pyc │ ├── cmudict.py │ ├── numbers.py │ ├── pinyin.py │ └── symbols.py ├── text_utils.py ├── train.py └── utils.py ├── README.md ├── attentions.py ├── commons.py ├── configs └── jvnv_base.json ├── data_utils.py ├── inference.py ├── losses.py ├── mel_processing.py ├── models.py ├── modules.py ├── monotonic_align ├── __init__.py ├── __pycache__ │ └── __init__.cpython-38.pyc ├── build │ ├── lib.linux-x86_64-cpython-38 │ │ └── monotonic_align │ │ │ └── core.cpython-38-x86_64-linux-gnu.so │ └── temp.linux-x86_64-cpython-38 │ │ └── core.o ├── core.c ├── core.pyx ├── monotonic_align │ └── core.cpython-38-x86_64-linux-gnu.so └── setup.py ├── preprocess_ja.py ├── requirements.txt ├── text ├── LICENSE ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── cleaners.cpython-38.pyc │ └── symbols.cpython-38.pyc ├── cleaners.py └── symbols.py ├── train_ms.py ├── transforms.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /PL_BERT_ja/Configs/config.yml: -------------------------------------------------------------------------------- 1 | log_dir: "Checkpoint" 2 | mixed_precision: "fp16" 3 | data_folder: "./dataset/wikipedia-ja/processed_data" 4 | batch_size: 8 5 | save_interval: 500000 6 | log_interval: 10 7 | num_process: 1 # number of GPUs 8 | num_steps: 10000000 9 | 10 | dataset_params: 11 | tokenizer: "cl-tohoku/bert-base-japanese-whole-word-masking" 12 | token_separator: " " # token used for phoneme separator (space) 13 | token_mask: "M" # token used for phoneme mask (M) 14 | word_separator: 3 # token used for word separator ([SEP]) 15 | token_maps: "token_maps.pkl" # token map path 16 | 17 | max_mel_length: 512 # max phoneme length 18 | 19 | word_mask_prob: 0.15 # probability to mask the entire word 20 | phoneme_mask_prob: 0.1 # probability to mask each phoneme 21 | replace_prob: 0.2 # probablity to replace phonemes 22 | 23 | model_params: 24 | vocab_size: 379 25 | hidden_size: 768 26 | num_attention_heads: 12 27 | intermediate_size: 2048 28 | max_position_embeddings: 512 29 | num_hidden_layers: 12 30 | dropout: 0.1 31 | -------------------------------------------------------------------------------- /PL_BERT_ja/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Aaron (Yinghao) Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
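For reference, the Configs/config.yml shown above is read with plain PyYAML throughout this repository (train.py, preprocess.py and the dataloader __main__ all call yaml.safe_load); a minimal sketch of loading it and pulling out the two parameter groups, assuming the keys listed in the config above:

```python
import yaml

# Load the PL-BERT training configuration (same pattern as train.py / preprocess.py).
with open("Configs/config.yml") as f:
    config = yaml.safe_load(f)

dataset_params = config["dataset_params"]  # tokenizer name, masking probabilities, token map path
model_params = config["model_params"]      # ALBERT-style encoder hyperparameters

print(config["batch_size"], dataset_params["word_mask_prob"], model_params["hidden_size"])
```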
22 | -------------------------------------------------------------------------------- /PL_BERT_ja/Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = "https://pypi.org/simple" 3 | verify_ssl = true 4 | name = "pypi" 5 | 6 | [packages] 7 | pandas = "*" 8 | singleton-decorator = "*" 9 | datasets = "*" 10 | transformers = "*" 11 | accelerate = "*" 12 | nltk = "*" 13 | phonemizer = "*" 14 | sacremoses = "*" 15 | pebble = "*" 16 | jupyterlab = "*" 17 | pathlib = "*" 18 | apache-beam = "*" 19 | tqdm = "*" 20 | fugashi = "*" 21 | ipadic = "*" 22 | mwparserfromhell = "*" 23 | install = "*" 24 | mecab-python3 = "*" 25 | cython = "*" 26 | pyopenjtalk = "*" 27 | tensorboard = "*" 28 | 29 | [dev-packages] 30 | 31 | [requires] 32 | python_version = "3.8" 33 | python_full_version = "3.8.7" 34 | -------------------------------------------------------------------------------- /PL_BERT_ja/__pycache__/convert_label.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnetonne814/PL-Bert-VITS2/a14feb38af363bf6e3c0797a0c9cb7e79557a1e7/PL_BERT_ja/__pycache__/convert_label.cpython-38.pyc -------------------------------------------------------------------------------- /PL_BERT_ja/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnetonne814/PL-Bert-VITS2/a14feb38af363bf6e3c0797a0c9cb7e79557a1e7/PL_BERT_ja/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /PL_BERT_ja/__pycache__/phonemize.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnetonne814/PL-Bert-VITS2/a14feb38af363bf6e3c0797a0c9cb7e79557a1e7/PL_BERT_ja/__pycache__/phonemize.cpython-38.pyc -------------------------------------------------------------------------------- /PL_BERT_ja/__pycache__/text_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnetonne814/PL-Bert-VITS2/a14feb38af363bf6e3c0797a0c9cb7e79557a1e7/PL_BERT_ja/__pycache__/text_utils.cpython-38.pyc -------------------------------------------------------------------------------- /PL_BERT_ja/convert_label.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import sys 4 | 5 | 6 | class ExtentionException(Exception): 7 | pass 8 | 9 | class EmptyLabelException(Exception): 10 | pass 11 | 12 | 13 | class Segment: 14 | """ 15 | a unit of speech (i.e. 
phoneme, mora) 16 | """ 17 | def __init__(self, tStart, tEnd, label): 18 | self.tStart = tStart 19 | self.tEnd = tEnd 20 | self.label = label 21 | 22 | def __add__(self, other): 23 | return Segment(self.tStart, other.tEnd, self.label + other.label) 24 | 25 | def can_follow(self, other): 26 | """ 27 | return True if Segment self can follow Segment other in one mora, 28 | otherwise return False 29 | example: (other, self) 30 | True: ('s', 'a'), ('sh', 'i'), ('ky', 'o:'), ('t', 's') 31 | False: ('a', 'q'), ('a', 's'), ('u', 'e'), ('s', 'ha') 32 | """ 33 | vowels = ['a', 'i', 'u', 'e', 'o', 'a:', 'i:', 'u:', 'e:', 'o:'] 34 | consonants = ['w', 'r', 't', 'y', 'p', 's', 'd', 'f', 'g', 'h', 'j', 35 | 'k', 'z', 'c', 'b', 'n', 'm'] 36 | only_consonants = lambda x: all([c in consonants for c in x]) 37 | if only_consonants(other.label) and self.label in vowels: 38 | return True 39 | if only_consonants(other.label) and only_consonants(self.label): 40 | return True 41 | return False 42 | 43 | def to_textgrid_lines(self, segmentIndex): 44 | label = '' if self.label in ['silB', 'silE'] else self.label 45 | return [f' intervals [{segmentIndex}]:', 46 | f' xmin = {self.tStart} ', 47 | f' xmax = {self.tEnd} ', 48 | f' text = "{label}" '] 49 | 50 | 51 | 52 | def openjtalk2julius(p3): 53 | if p3 in ['A','I','U',"E", "O"]: 54 | return p3.lower() 55 | if p3 == 'cl': 56 | return 'q' 57 | if p3 == 'pau': 58 | return 'sp' 59 | return p3 60 | 61 | def read_lab(filename): 62 | """ 63 | read label file (.lab) generated by Julius segmentation kit and 64 | return SegmentationLabel object 65 | """ 66 | try: 67 | if not re.search(r'\.lab$', filename): 68 | raise ExtentionException("read_lab supports only .lab") 69 | except ExtentionException as e: 70 | print(e) 71 | return None 72 | 73 | with open(filename, 'r') as f: 74 | labeldata = [line.split() for line in f if line != ''] 75 | segments = [Segment(tStart=float(line[0])/10e6, tEnd=float(line[1])/10e6, 76 | label=openjtalk2julius(re.search(r"\-(.*?)\+", line[2]).group(1))) for line in labeldata] 77 | return SegmentationLabel(segments) 78 | 79 | 80 | class SegmentationLabel: 81 | """ 82 | list of segments 83 | """ 84 | def __init__(self, segments, separatedByMora=False): 85 | self.segments = segments 86 | self.separatedByMora = separatedByMora 87 | 88 | def by_moras(self): 89 | """ 90 | return new SegmentationLabel object whose segment are moras 91 | """ 92 | if self.separatedByMora == True: 93 | return self 94 | 95 | moraSegments = [] 96 | curMoraSegment = None 97 | for segment in self.segments: 98 | if curMoraSegment is None: 99 | curMoraSegment = segment 100 | elif segment.can_follow(curMoraSegment): 101 | curMoraSegment += segment 102 | else: 103 | moraSegments.append(curMoraSegment) 104 | curMoraSegment = segment 105 | if curMoraSegment: 106 | moraSegments.append(curMoraSegment) 107 | return SegmentationLabel(moraSegments, separatedByMora=True) 108 | 109 | def _textgrid_headers(self): 110 | segmentKind = 'mora' if self.separatedByMora else 'phones' 111 | return ['File type = "ooTextFile"', 112 | 'Object class = "TextGrid"', 113 | ' ', 114 | 'xmin = 0 ', 115 | f'xmax = {self.segments[-1].tEnd} ', 116 | 'tiers? 
', 117 | 'size = 1 ', 118 | 'item []: ', 119 | ' item [1]: ', 120 | ' class = "IntervalTier" ', 121 | f' name = "{segmentKind}" ', 122 | ' xmin = 0 ', 123 | f' xmax = {self.segments[-1].tEnd} ', 124 | f' intervals: size = {len(self.segments)} '] 125 | 126 | def to_textgrid(self, textgridFileName): 127 | """ 128 | save to .TextGrid file, which is available for Praat 129 | """ 130 | try: 131 | if not self.segments: 132 | raise EmptyLabelException(f'warning: no label data found in ' 133 | f'{textgridFileName}') 134 | except EmptyLabelException as e: 135 | print(e) 136 | return 137 | 138 | textgridLines = self._textgrid_headers() 139 | for i, segment in enumerate(self.segments): 140 | textgridLines.extend(segment.to_textgrid_lines(i + 1)) 141 | with open(textgridFileName, 'w') as f: 142 | f.write('\n'.join(textgridLines)) 143 | 144 | 145 | if __name__ == '__main__': 146 | args = sys.argv 147 | if len(args) >= 2: 148 | mainDirectory = args[1] 149 | else: 150 | mainDirectory = os.curdir 151 | 152 | answer = None 153 | while not answer in ['y', 'Y', 'n', 'N']: 154 | answer = input('change segmentation unit to mora?'\ 155 | ' (default:phoneme) y/n:') 156 | choosesMora = answer in ['y', 'Y'] 157 | 158 | for dirPath, dirNames, fileNames in os.walk(mainDirectory): 159 | labFileNames = [n for n in fileNames if re.search(r'\.lab$', n)] 160 | 161 | for labFileName in labFileNames: 162 | label = read_lab(os.path.join(dirPath, labFileName)) 163 | if choosesMora: 164 | label = label.by_moras() 165 | textgridFileName = re.sub(r"\.lab$", ".TextGrid", labFileName) 166 | label.to_textgrid(os.path.join(dirPath, textgridFileName)) 167 | -------------------------------------------------------------------------------- /PL_BERT_ja/dataloader.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | 3 | import os 4 | import os.path as osp 5 | import time 6 | import random 7 | import numpy as np 8 | import random 9 | import string 10 | import pickle 11 | 12 | from datasets import load_from_disk 13 | import torch 14 | from torch import nn 15 | import torch.nn.functional as F 16 | from torch.utils.data import DataLoader 17 | import yaml 18 | 19 | from text_utils import TextCleaner 20 | 21 | import logging 22 | logger = logging.getLogger(__name__) 23 | logger.setLevel(logging.DEBUG) 24 | 25 | np.random.seed(1) 26 | random.seed(1) 27 | 28 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 29 | 30 | 31 | class FilePathDataset(torch.utils.data.Dataset): 32 | def __init__(self, dataset, 33 | token_maps="token_maps.pkl", 34 | tokenizer="cl-tohoku/bert-base-japanese-whole-word-masking", 35 | word_separator=3, 36 | token_separator=" ", 37 | token_mask="M", 38 | max_mel_length=512, 39 | word_mask_prob=0.15, 40 | phoneme_mask_prob=0.1, 41 | replace_prob=0.2): 42 | 43 | self.data = dataset 44 | self.max_mel_length = max_mel_length 45 | self.word_mask_prob = word_mask_prob 46 | self.phoneme_mask_prob = phoneme_mask_prob 47 | self.replace_prob = replace_prob 48 | self.text_cleaner = TextCleaner() 49 | 50 | self.word_separator = word_separator 51 | self.token_separator = token_separator 52 | self.token_mask = token_mask 53 | 54 | with open(token_maps, 'rb') as handle: 55 | self.token_maps = pickle.load(handle) 56 | 57 | def __len__(self): 58 | return len(self.data) 59 | 60 | def __getitem__(self, idx): 61 | 62 | phonemes = self.data[idx]['phonemes'] 63 | input_ids = self.data[idx]['input_ids'] 64 | 65 | words = [] 66 | labels = "" 67 | phoneme = "" 68 | 69 | 
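        # BERT-style masking at the word level (probabilities from config.yml):
        # each (phoneme-string, word-id) pair is selected for prediction with
        # probability word_mask_prob; a selected word's phonemes are replaced by
        # token_mask characters ("M") most of the time, swapped for random
        # phonemes with probability phoneme_mask_prob, or kept unchanged
        # otherwise, and their positions are recorded in masked_index as targets.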
phoneme_list = ''.join(phonemes) 70 | masked_index = [] 71 | for z in zip(phonemes, input_ids): 72 | z = list(z) 73 | 74 | words.extend([z[1]] * len(z[0])) 75 | # words.append(self.word_separator) 76 | 77 | labels += z[0] 78 | # labels += self.token_separator 79 | 80 | if np.random.rand() < self.word_mask_prob: 81 | if np.random.rand() < self.replace_prob: 82 | if np.random.rand() < (self.phoneme_mask_prob / self.replace_prob): 83 | phoneme += ''.join([phoneme_list[np.random.randint(0, len(phoneme_list))] for _ in range(len(z[0]))]) # randomized 84 | else: 85 | phoneme += z[0] 86 | else: 87 | phoneme += self.token_mask * len(z[0]) # masked 88 | 89 | masked_index.extend((np.arange(len(phoneme) - len(z[0]), len(phoneme))).tolist()) 90 | else: 91 | phoneme += z[0] 92 | # phoneme += self.token_separator 93 | 94 | mel_length = len(phoneme) 95 | masked_idx = np.array(masked_index) 96 | masked_index = [] 97 | if mel_length > self.max_mel_length: 98 | random_start = np.random.randint(0, mel_length - self.max_mel_length) 99 | phoneme = phoneme[random_start:random_start + self.max_mel_length] 100 | words = words[random_start:random_start + self.max_mel_length] 101 | labels = labels[random_start:random_start + self.max_mel_length] 102 | 103 | for m in masked_idx: 104 | if m >= random_start and m < random_start + self.max_mel_length: 105 | masked_index.append(m - random_start) 106 | 107 | phoneme = self.text_cleaner(phoneme) 108 | labels = self.text_cleaner(labels) 109 | words = [self.token_maps[w]['token'] for w in words] 110 | 111 | assert len(phoneme) == len(words) 112 | assert len(phoneme) == len(labels) 113 | 114 | phonemes = torch.LongTensor(phoneme) 115 | labels = torch.LongTensor(labels) 116 | words = torch.LongTensor(words) 117 | 118 | return phonemes, words, labels, masked_index 119 | 120 | class Collater(object): 121 | """ 122 | Args: 123 | adaptive_batch_size (bool): if true, decrease batch size when long data comes. 
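        (Not implemented here: __init__ only takes return_wave. __call__ sorts the
        batch by sequence length, zero-pads words/labels/phonemes to the longest
        item, and returns (words, labels, phonemes, input_lengths, masked_indices).)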
124 | """ 125 | 126 | def __init__(self, return_wave=False): 127 | self.text_pad_index = 0 128 | self.return_wave = return_wave 129 | 130 | 131 | def __call__(self, batch): 132 | # batch[0] = wave, mel, text, f0, speakerid 133 | batch_size = len(batch) 134 | 135 | # sort by mel length 136 | lengths = [b[1].shape[0] for b in batch] 137 | batch_indexes = np.argsort(lengths)[::-1] 138 | batch = [batch[bid] for bid in batch_indexes] 139 | 140 | max_text_length = max([b[1].shape[0] for b in batch]) 141 | 142 | words = torch.zeros((batch_size, max_text_length)).long() 143 | labels = torch.zeros((batch_size, max_text_length)).long() 144 | phonemes = torch.zeros((batch_size, max_text_length)).long() 145 | input_lengths = [] 146 | masked_indices = [] 147 | for bid, (phoneme, word, label, masked_index) in enumerate(batch): 148 | 149 | text_size = phoneme.size(0) 150 | words[bid, :text_size] = word 151 | labels[bid, :text_size] = label 152 | phonemes[bid, :text_size] = phoneme 153 | input_lengths.append(text_size) 154 | masked_indices.append(masked_index) 155 | 156 | return words, labels, phonemes, input_lengths, masked_indices 157 | 158 | 159 | def build_dataloader(df, 160 | validation=False, 161 | batch_size=4, 162 | num_workers=1, 163 | device=torch.device("cpu"), 164 | collate_config={}, 165 | dataset_config={}): 166 | 167 | dataset = FilePathDataset(df, **dataset_config) 168 | collate_fn = Collater(**collate_config) 169 | data_loader = DataLoader(dataset, 170 | batch_size=batch_size, 171 | shuffle=(not validation), 172 | num_workers=num_workers, 173 | drop_last=(not validation), 174 | collate_fn=collate_fn, 175 | pin_memory=(device != torch.device("cpu"))) 176 | 177 | return data_loader 178 | 179 | 180 | if __name__ == '__main__': 181 | config_path = "Configs/config.yml" # you can change it to anything else 182 | config = yaml.safe_load(open(config_path)) 183 | dataset = load_from_disk(config['data_folder']) 184 | train_loader = build_dataloader(dataset, batch_size=1, num_workers=0, dataset_config=config['dataset_params']) 185 | print(len(dataset)) 186 | print(len(train_loader)) 187 | _, (words, labels, phonemes, input_lengths, masked_indices) = next(enumerate(train_loader)) 188 | print(words, labels, phonemes, input_lengths, masked_indices) 189 | -------------------------------------------------------------------------------- /PL_BERT_ja/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class MultiTaskModel(nn.Module): 6 | def __init__(self, model, num_tokens, num_vocab, hidden_size): 7 | super().__init__() 8 | 9 | self.encoder = model 10 | self.mask_predictor = nn.Linear(hidden_size, num_tokens) 11 | self.word_predictor = nn.Linear(hidden_size, num_vocab) 12 | 13 | def forward(self, phonemes, attention_mask=None): 14 | output = self.encoder(phonemes, attention_mask=attention_mask, output_hidden_states=True) 15 | tokens_pred = self.mask_predictor(output.last_hidden_state) 16 | words_pred = self.word_predictor(output.last_hidden_state) 17 | 18 | return tokens_pred, words_pred, output 19 | -------------------------------------------------------------------------------- /PL_BERT_ja/phonemize.py: -------------------------------------------------------------------------------- 1 | import string 2 | 3 | import pyopenjtalk 4 | import unicodedata 5 | 6 | from PL_BERT_ja.convert_label import openjtalk2julius 7 | 8 | 9 | # 実装上の問題から2文字の音素を1文字の適当な記号にマッピング 10 | _japanese = 
['ky','sp', 'sh', 'ch', 'ts','ty', 'ry', 'ny', 'by', 'hy', 'gy', 'kw', 'gw', 'kj', 'gj', 'my', 'py','dy'] 11 | japanese = ['$', '%', '&', '「', '」', '=', '~', '^', '|', '[', ']', '{', '}', '*', '+', '#', '<', '>'] 12 | _japanese2japanese = { 13 | 'ky': '$', 14 | 'sp': '%', 15 | 'sh': '&', 16 | 'ch': '「', 17 | 'ts': '」', 18 | 'ty': '=', 19 | 'ry': '~', 20 | 'ny': '^', 21 | 'by': '|', 22 | 'hy': '[', 23 | 'gy': ']', 24 | 'kw': '{', 25 | 'gw': '}', 26 | 'kj': '*', 27 | 'gj': '+', 28 | 'my': '#', 29 | 'py': '<', 30 | 'dy': '>', 31 | } 32 | 33 | 34 | def global_phonemize(text: str): 35 | if text == "?" or text == "!": 36 | return text 37 | if text == "。" or text == "、": 38 | return "_" # padとして扱う 39 | phonemes = pyopenjtalk.g2p(text).split(' ') 40 | phonemes = [openjtalk2julius(p) for p in phonemes if p != ''] 41 | for i in range(len(phonemes)): 42 | phoneme = phonemes[i] 43 | if phoneme in _japanese: 44 | phonemes[i] = _japanese2japanese[phoneme] 45 | return phonemes 46 | 47 | 48 | def phonemize(text, tokenizer): 49 | text = unicodedata.normalize("NFKC", text) 50 | words = tokenizer.tokenize(text) 51 | input_ids_ = tokenizer.convert_tokens_to_ids(words) 52 | 53 | phonemes = [] 54 | input_ids = [] 55 | for i in range(len(words)): 56 | word = words[i] 57 | input_id = input_ids_[i] 58 | phoneme = global_phonemize(word.replace('#', '')) 59 | if len(phoneme) != 0: 60 | phonemes.append(''.join(phoneme)) 61 | input_ids.append(input_id) 62 | 63 | assert len(input_ids) == len(phonemes) 64 | return {'input_ids' : input_ids, 'phonemes': phonemes} -------------------------------------------------------------------------------- /PL_BERT_ja/preprocess.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import TimeoutError 2 | import os 3 | from pebble import ProcessPool 4 | import pickle 5 | 6 | from dataloader import build_dataloader as build_trainloader 7 | import datasets 8 | from datasets import load_from_disk, concatenate_datasets 9 | import pathlib 10 | import phonemizer 11 | import torch 12 | from tqdm import tqdm 13 | from transformers import BertJapaneseTokenizer 14 | import yaml 15 | 16 | from simple_loader import FilePathDataset, build_dataloader 17 | from phonemize import phonemize 18 | 19 | device = "cuda" if torch.cuda.is_available() else "cpu" 20 | 21 | 22 | def process_shard(i): 23 | directory = root_directory + "/shard_" + str(i) 24 | if os.path.exists(directory): 25 | print("Shard %d already exists!" % i) 26 | return 27 | print('Processing shard %d ...' 
% i) 28 | shard = dataset.shard(num_shards=num_shards, index=i) 29 | processed_dataset = shard.map(lambda t: phonemize(t['text'], tokenizer), remove_columns=['text']) 30 | if not os.path.exists(directory): 31 | os.makedirs(directory) 32 | processed_dataset.save_to_disk(directory) 33 | 34 | 35 | if __name__ == '__main__': 36 | ##### config ##### 37 | config_path = "Configs/config.yml" # you can change it to anything else 38 | config = yaml.safe_load(open(config_path)) 39 | 40 | ##### set tokenizer ##### 41 | tokenizer = BertJapaneseTokenizer.from_pretrained(config['dataset_params']['tokenizer']) 42 | 43 | ##### download dataset ##### 44 | # comment out the following line in hogehoge/datasets/wikipedia/wikipedia.py 45 | # | "Distribute" >> beam.transforms.Reshuffle() 46 | datasets.config.DOWNLOADED_DATASETS_PATH = pathlib.Path("./dataset/wikipedia-ja") 47 | dataset = datasets.load_dataset( 48 | 'wikipedia', language="ja", date="20230601", beam_runner="DirectRunner", 49 | cache_dir="./dataset/wikipedia-ja/.cache" 50 | ) 51 | dataset = dataset['train'] 52 | 53 | ##### make shards ##### 54 | root_directory = "./wiki_phoneme" 55 | num_shards = 50000 56 | max_workers = 20 # change this to the number of CPU cores your machine has 57 | with ProcessPool(max_workers=max_workers) as pool: 58 | pool.map(process_shard, range(num_shards), timeout=60) 59 | 60 | ##### correct shards ##### 61 | output = [dI for dI in os.listdir(root_directory) if os.path.isdir(os.path.join(root_directory,dI))] 62 | datasets = [] 63 | for o in output: 64 | directory = root_directory + "/" + o 65 | try: 66 | shard = load_from_disk(directory) 67 | datasets.append(shard) 68 | print("%s loaded" % o) 69 | except: 70 | continue 71 | dataset = concatenate_datasets(datasets) 72 | dataset.save_to_disk(config['data_folder']) 73 | print('Dataset saved to %s' % config['data_folder']) 74 | 75 | ##### Remove unneccessary tokens from the pre-trained tokenizer ##### 76 | dataset = load_from_disk(config['data_folder']) 77 | file_data = FilePathDataset(dataset) 78 | loader = build_dataloader(file_data, num_workers=20, batch_size=128, device=device) 79 | 80 | special_token = config['dataset_params']['word_separator'] 81 | 82 | unique_index = [special_token] 83 | for _, batch in enumerate(tqdm(loader)): 84 | unique_index.extend(batch["input_ids"]) 85 | unique_index = list(set(unique_index)) 86 | 87 | token_maps = {} 88 | for t in tqdm(unique_index): 89 | word = tokenizer.decode([t]) 90 | token_maps[t] = {'word': word, 'token': unique_index.index(t)} 91 | 92 | with open(config['dataset_params']['token_maps'], 'wb') as handle: 93 | pickle.dump(token_maps, handle) 94 | print('Token mapper saved to %s' % config['dataset_params']['token_maps']) 95 | -------------------------------------------------------------------------------- /PL_BERT_ja/simple_loader.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | 3 | import os 4 | import os.path as osp 5 | import time 6 | import random 7 | import numpy as np 8 | import random 9 | 10 | import string 11 | 12 | import torch 13 | from torch import nn 14 | import torch.nn.functional as F 15 | from torch.utils.data import DataLoader 16 | 17 | import logging 18 | logger = logging.getLogger(__name__) 19 | logger.setLevel(logging.DEBUG) 20 | 21 | np.random.seed(1) 22 | random.seed(1) 23 | 24 | class FilePathDataset(torch.utils.data.Dataset): 25 | def __init__(self, df): 26 | self.data = df 27 | 28 | def __len__(self): 29 | return len(self.data) 30 | 31 | def 
__getitem__(self, idx): 32 | input_ids = self.data[idx] 33 | 34 | return input_ids 35 | 36 | class Collater(object): 37 | """ 38 | Args: 39 | adaptive_batch_size (bool): if true, decrease batch size when long data comes. 40 | """ 41 | 42 | def __init__(self, return_wave=False): 43 | self.text_pad_index = 0 44 | self.return_wave = return_wave 45 | 46 | 47 | def __call__(self, batch): 48 | # batch[0] = wave, mel, text, f0, speakerid 49 | batch_size = len(batch) 50 | input_ids = [] 51 | phonemes = [] 52 | 53 | for bid, (data) in enumerate(batch): 54 | 55 | input_ids.extend(data['input_ids']) 56 | phonemes.extend(data['phonemes']) 57 | 58 | return {"input_ids": input_ids, "phonemes": phonemes} 59 | 60 | def build_dataloader(df, 61 | validation=False, 62 | batch_size=4, 63 | num_workers=1, 64 | device='cpu', 65 | collate_config={}, 66 | dataset_config={}): 67 | 68 | dataset = FilePathDataset(df, **dataset_config) 69 | collate_fn = Collater(**collate_config) 70 | data_loader = DataLoader(dataset, 71 | batch_size=batch_size, 72 | shuffle=(not validation), 73 | num_workers=num_workers, 74 | drop_last=(not validation), 75 | collate_fn=collate_fn, 76 | pin_memory=(device != 'cpu')) 77 | 78 | return data_loader -------------------------------------------------------------------------------- /PL_BERT_ja/text/__pycache__/cmudict.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnetonne814/PL-Bert-VITS2/a14feb38af363bf6e3c0797a0c9cb7e79557a1e7/PL_BERT_ja/text/__pycache__/cmudict.cpython-38.pyc -------------------------------------------------------------------------------- /PL_BERT_ja/text/__pycache__/pinyin.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnetonne814/PL-Bert-VITS2/a14feb38af363bf6e3c0797a0c9cb7e79557a1e7/PL_BERT_ja/text/__pycache__/pinyin.cpython-38.pyc -------------------------------------------------------------------------------- /PL_BERT_ja/text/__pycache__/symbols.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnetonne814/PL-Bert-VITS2/a14feb38af363bf6e3c0797a0c9cb7e79557a1e7/PL_BERT_ja/text/__pycache__/symbols.cpython-38.pyc -------------------------------------------------------------------------------- /PL_BERT_ja/text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | "AA", 8 | "AA0", 9 | "AA1", 10 | "AA2", 11 | "AE", 12 | "AE0", 13 | "AE1", 14 | "AE2", 15 | "AH", 16 | "AH0", 17 | "AH1", 18 | "AH2", 19 | "AO", 20 | "AO0", 21 | "AO1", 22 | "AO2", 23 | "AW", 24 | "AW0", 25 | "AW1", 26 | "AW2", 27 | "AY", 28 | "AY0", 29 | "AY1", 30 | "AY2", 31 | "B", 32 | "CH", 33 | "D", 34 | "DH", 35 | "EH", 36 | "EH0", 37 | "EH1", 38 | "EH2", 39 | "ER", 40 | "ER0", 41 | "ER1", 42 | "ER2", 43 | "EY", 44 | "EY0", 45 | "EY1", 46 | "EY2", 47 | "F", 48 | "G", 49 | "HH", 50 | "IH", 51 | "IH0", 52 | "IH1", 53 | "IH2", 54 | "IY", 55 | "IY0", 56 | "IY1", 57 | "IY2", 58 | "JH", 59 | "K", 60 | "L", 61 | "M", 62 | "N", 63 | "NG", 64 | "OW", 65 | "OW0", 66 | "OW1", 67 | "OW2", 68 | "OY", 69 | "OY0", 70 | "OY1", 71 | "OY2", 72 | "P", 73 | "R", 74 | "S", 75 | "SH", 76 | "T", 77 | "TH", 78 | "UH", 79 | "UH0", 80 | "UH1", 81 | "UH2", 82 | "UW", 83 | "UW0", 84 | "UW1", 85 | "UW2", 86 | "V", 87 | "W", 88 | "Y", 89 | 
"Z", 90 | "ZH", 91 | ] 92 | 93 | _valid_symbol_set = set(valid_symbols) 94 | 95 | 96 | class CMUDict: 97 | """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" 98 | 99 | def __init__(self, file_or_path, keep_ambiguous=True): 100 | if isinstance(file_or_path, str): 101 | with open(file_or_path, encoding="latin-1") as f: 102 | entries = _parse_cmudict(f) 103 | else: 104 | entries = _parse_cmudict(file_or_path) 105 | if not keep_ambiguous: 106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 107 | self._entries = entries 108 | 109 | def __len__(self): 110 | return len(self._entries) 111 | 112 | def lookup(self, word): 113 | """Returns list of ARPAbet pronunciations of the given word.""" 114 | return self._entries.get(word.upper()) 115 | 116 | 117 | _alt_re = re.compile(r"\([0-9]+\)") 118 | 119 | 120 | def _parse_cmudict(file): 121 | cmudict = {} 122 | for line in file: 123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): 124 | parts = line.split(" ") 125 | word = re.sub(_alt_re, "", parts[0]) 126 | pronunciation = _get_pronunciation(parts[1]) 127 | if pronunciation: 128 | if word in cmudict: 129 | cmudict[word].append(pronunciation) 130 | else: 131 | cmudict[word] = [pronunciation] 132 | return cmudict 133 | 134 | 135 | def _get_pronunciation(s): 136 | parts = s.strip().split(" ") 137 | for part in parts: 138 | if part not in _valid_symbol_set: 139 | return None 140 | return " ".join(parts) 141 | -------------------------------------------------------------------------------- /PL_BERT_ja/text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return "%s %s" % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return "%s %s" % (cents, cent_unit) 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words( 60 | num, 
andword="", zero="oh", group=2 61 | ).replace(", ", " ") 62 | else: 63 | return _inflect.number_to_words(num, andword="") 64 | 65 | 66 | def normalize_numbers(text): 67 | text = re.sub(_comma_number_re, _remove_commas, text) 68 | text = re.sub(_pounds_re, r"\1 pounds", text) 69 | text = re.sub(_dollars_re, _expand_dollars, text) 70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 71 | text = re.sub(_ordinal_re, _expand_ordinal, text) 72 | text = re.sub(_number_re, _expand_number, text) 73 | return text 74 | -------------------------------------------------------------------------------- /PL_BERT_ja/text/pinyin.py: -------------------------------------------------------------------------------- 1 | initials = [ 2 | "b", 3 | "c", 4 | "ch", 5 | "d", 6 | "f", 7 | "g", 8 | "h", 9 | "j", 10 | "k", 11 | "l", 12 | "m", 13 | "n", 14 | "p", 15 | "q", 16 | "r", 17 | "s", 18 | "sh", 19 | "t", 20 | "w", 21 | "x", 22 | "y", 23 | "z", 24 | "zh", 25 | ] 26 | finals = [ 27 | "a1", 28 | "a2", 29 | "a3", 30 | "a4", 31 | "a5", 32 | "ai1", 33 | "ai2", 34 | "ai3", 35 | "ai4", 36 | "ai5", 37 | "an1", 38 | "an2", 39 | "an3", 40 | "an4", 41 | "an5", 42 | "ang1", 43 | "ang2", 44 | "ang3", 45 | "ang4", 46 | "ang5", 47 | "ao1", 48 | "ao2", 49 | "ao3", 50 | "ao4", 51 | "ao5", 52 | "e1", 53 | "e2", 54 | "e3", 55 | "e4", 56 | "e5", 57 | "ei1", 58 | "ei2", 59 | "ei3", 60 | "ei4", 61 | "ei5", 62 | "en1", 63 | "en2", 64 | "en3", 65 | "en4", 66 | "en5", 67 | "eng1", 68 | "eng2", 69 | "eng3", 70 | "eng4", 71 | "eng5", 72 | "er1", 73 | "er2", 74 | "er3", 75 | "er4", 76 | "er5", 77 | "i1", 78 | "i2", 79 | "i3", 80 | "i4", 81 | "i5", 82 | "ia1", 83 | "ia2", 84 | "ia3", 85 | "ia4", 86 | "ia5", 87 | "ian1", 88 | "ian2", 89 | "ian3", 90 | "ian4", 91 | "ian5", 92 | "iang1", 93 | "iang2", 94 | "iang3", 95 | "iang4", 96 | "iang5", 97 | "iao1", 98 | "iao2", 99 | "iao3", 100 | "iao4", 101 | "iao5", 102 | "ie1", 103 | "ie2", 104 | "ie3", 105 | "ie4", 106 | "ie5", 107 | "ii1", 108 | "ii2", 109 | "ii3", 110 | "ii4", 111 | "ii5", 112 | "iii1", 113 | "iii2", 114 | "iii3", 115 | "iii4", 116 | "iii5", 117 | "in1", 118 | "in2", 119 | "in3", 120 | "in4", 121 | "in5", 122 | "ing1", 123 | "ing2", 124 | "ing3", 125 | "ing4", 126 | "ing5", 127 | "iong1", 128 | "iong2", 129 | "iong3", 130 | "iong4", 131 | "iong5", 132 | "iou1", 133 | "iou2", 134 | "iou3", 135 | "iou4", 136 | "iou5", 137 | "o1", 138 | "o2", 139 | "o3", 140 | "o4", 141 | "o5", 142 | "ong1", 143 | "ong2", 144 | "ong3", 145 | "ong4", 146 | "ong5", 147 | "ou1", 148 | "ou2", 149 | "ou3", 150 | "ou4", 151 | "ou5", 152 | "u1", 153 | "u2", 154 | "u3", 155 | "u4", 156 | "u5", 157 | "ua1", 158 | "ua2", 159 | "ua3", 160 | "ua4", 161 | "ua5", 162 | "uai1", 163 | "uai2", 164 | "uai3", 165 | "uai4", 166 | "uai5", 167 | "uan1", 168 | "uan2", 169 | "uan3", 170 | "uan4", 171 | "uan5", 172 | "uang1", 173 | "uang2", 174 | "uang3", 175 | "uang4", 176 | "uang5", 177 | "uei1", 178 | "uei2", 179 | "uei3", 180 | "uei4", 181 | "uei5", 182 | "uen1", 183 | "uen2", 184 | "uen3", 185 | "uen4", 186 | "uen5", 187 | "uo1", 188 | "uo2", 189 | "uo3", 190 | "uo4", 191 | "uo5", 192 | "v1", 193 | "v2", 194 | "v3", 195 | "v4", 196 | "v5", 197 | "van1", 198 | "van2", 199 | "van3", 200 | "van4", 201 | "van5", 202 | "ve1", 203 | "ve2", 204 | "ve3", 205 | "ve4", 206 | "ve5", 207 | "vn1", 208 | "vn2", 209 | "vn3", 210 | "vn4", 211 | "vn5", 212 | ] 213 | valid_symbols = initials + finals + ["rr"] -------------------------------------------------------------------------------- /PL_BERT_ja/text/symbols.py: 
-------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """ 7 | 8 | from PL_BERT_ja.text import cmudict, pinyin 9 | 10 | _pad = "_" 11 | _punctuation = "!'(),.:;? " 12 | _special = "-" 13 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 14 | _silences = ["@sp", "@spn", "@sil"] 15 | _japanese = ['ky','sp', 'sh', 'ch', 'ts','ty', 'ry', 'ny', 'by', 'hy', 'gy', 'kw', 'gw', 'kj', 'gj', 'my', 'py','dy'] 16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 17 | _arpabet = ["@" + s for s in cmudict.valid_symbols] 18 | _pinyin = ["@" + s for s in pinyin.valid_symbols] 19 | 20 | # Export all symbols: 21 | symbols = ( 22 | [_pad] 23 | + list(_special) 24 | + list(_punctuation) 25 | + list(_letters) 26 | + _arpabet 27 | + _pinyin 28 | + _silences 29 | + _japanese 30 | ) 31 | -------------------------------------------------------------------------------- /PL_BERT_ja/text_utils.py: -------------------------------------------------------------------------------- 1 | from PL_BERT_ja.text import cmudict, pinyin 2 | import pyopenjtalk 3 | import PL_BERT_ja.phonemize 4 | 5 | _pad = "_" 6 | _punctuation = "!'(),.:;? " 7 | _special = "-" 8 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 9 | _silences = ["@sp", "@spn", "@sil"] 10 | # _japanese = ['ky','sp', 'sh', 'ch', 'ts','ty', 'ry', 'ny', 'by', 'hy', 'gy', 'kw', 'gw', 'kj', 'gj', 'my', 'py','dy'] 11 | japanese = ['$', '%', '&', '「', '」', '=', '~', '^', '|', '[', ']', '{', '}', '*', '+', '#', '<', '>'] 12 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 13 | _arpabet = ["@" + s for s in cmudict.valid_symbols] 14 | _pinyin = ["@" + s for s in pinyin.valid_symbols] 15 | 16 | # Export all symbols: 17 | symbols = ( 18 | [_pad] 19 | + list(_special) 20 | + list(_punctuation) 21 | + list(_letters) 22 | + _arpabet 23 | + _pinyin 24 | + _silences 25 | + japanese 26 | ) 27 | 28 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 29 | 30 | class TextCleaner: 31 | def __init__(self, dummy=None): 32 | self.word_index_dictionary = symbol_to_id 33 | def __call__(self, text): 34 | indexes = [] 35 | japanese = False 36 | for char in text: 37 | try: 38 | indexes.append(self.word_index_dictionary[char]) 39 | except: 40 | if char == "。" or char == "、": 41 | indexes.append(0) # padとして扱う 42 | 43 | return indexes 44 | 45 | 46 | if __name__ == '__main__': 47 | print(pyopenjtalk.g2p("こんにちは。")) 48 | print(symbols) 49 | cleaner = TextCleaner() 50 | -------------------------------------------------------------------------------- /PL_BERT_ja/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import pickle 4 | import shutil 5 | 6 | from datasets import load_from_disk 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from torch.utils.tensorboard import SummaryWriter 11 | # from transformers import BertConfig, BertModel 12 | from transformers import AlbertConfig, AlbertModel 13 | from transformers import BertJapaneseTokenizer 14 | import yaml 15 | 16 | from dataloader import 
build_dataloader 17 | from model import MultiTaskModel 18 | from utils import length_to_mask, scan_checkpoint 19 | 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | 23 | def train(): 24 | curr_steps = 0 25 | 26 | dataset = load_from_disk(config["data_folder"]) 27 | 28 | log_dir = config['log_dir'] 29 | if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True) 30 | shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path))) 31 | 32 | log_for_tensorboard = 'logs' 33 | if not osp.exists(log_for_tensorboard): os.makedirs(log_for_tensorboard, exist_ok=True) 34 | train_logger = SummaryWriter(log_for_tensorboard) 35 | 36 | batch_size = config["batch_size"] 37 | train_loader = build_dataloader( 38 | dataset, 39 | batch_size=batch_size, 40 | dataset_config=config['dataset_params'], 41 | num_workers=8, 42 | device=device, 43 | ) 44 | 45 | albert_base_configuration = AlbertConfig(**config['model_params']) 46 | bert_ = AlbertModel(albert_base_configuration).to(device) 47 | num_vocab = max([m['token'] for m in token_maps.values()]) + 1 # 30923 + 1 48 | bert = MultiTaskModel( 49 | bert_, 50 | num_vocab=num_vocab, 51 | num_tokens=config['model_params']['vocab_size'], 52 | hidden_size=config['model_params']['hidden_size'] 53 | ).to(device) 54 | 55 | load = True 56 | try: 57 | files = os.listdir(log_dir) 58 | ckpts = [] 59 | for f in files: 60 | if f.endswith(".pth.tar"): 61 | ckpts.append(f) 62 | 63 | iters = [int(f.split('.')[0]) for f in ckpts if os.path.isfile(os.path.join(log_dir, f))] 64 | iters = sorted(iters)[-1] 65 | except: 66 | iters = 0 67 | load = False 68 | 69 | optimizer = torch.optim.AdamW(bert.parameters(), lr=4e-6) 70 | 71 | if load: 72 | checkpoint = torch.load(os.path.join(log_dir, "{}.pth.tar".format(iters))) 73 | bert.load_state_dict(checkpoint['model'], strict=False) 74 | optimizer.load_state_dict(checkpoint['optimizer']) 75 | 76 | print('Start training...') 77 | bert.train() 78 | 79 | running_loss = 0 80 | epoch = 0 81 | while True: 82 | for _, batch in enumerate(train_loader): 83 | curr_steps += 1 84 | 85 | words, labels, phonemes, input_lengths, masked_indices = batch 86 | words, labels, phonemes = words.to(device), labels.to(device), phonemes.to(device) 87 | text_mask = length_to_mask(torch.Tensor(input_lengths)).to(device) 88 | 89 | tokens_pred, words_pred = bert(phonemes, attention_mask=(~text_mask).int()) 90 | 91 | loss_vocab = 0 92 | for _s2s_pred, _text_input, _text_length, _masked_indices in zip(words_pred, words, input_lengths, masked_indices): 93 | loss_vocab += criterion(_s2s_pred[:_text_length], 94 | _text_input[:_text_length]) 95 | loss_vocab /= words.size(0) 96 | 97 | loss_token = 0 98 | sizes = 0 99 | for _s2s_pred, _text_input, _text_length, _masked_indices in zip(tokens_pred, labels, input_lengths, masked_indices): 100 | if len(_masked_indices) > 0: 101 | _text_input = _text_input[:_text_length][_masked_indices] 102 | loss_tmp = criterion(_s2s_pred[:_text_length][_masked_indices], 103 | _text_input[:_text_length]) 104 | loss_token += loss_tmp 105 | sizes += 1 106 | loss_token /= sizes 107 | 108 | loss = loss_vocab + loss_token 109 | 110 | optimizer.zero_grad() 111 | loss.backward() 112 | optimizer.step() 113 | 114 | running_loss += loss.item() 115 | 116 | iters = iters + 1 117 | if (iters+1) % log_interval == 0: 118 | total_loss = running_loss / log_interval 119 | print('Step [%d/%d], Loss: %.5f, Vocab Loss: %.5f, Token Loss: %.5f'%(iters+1, num_steps, total_loss, loss_vocab, loss_token)) 120 | 
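                # total_loss is running_loss averaged over the last log_interval steps;
                # the loss_vocab / loss_token values logged below come from the most
                # recent batch only.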
train_logger.add_scalar("Total Loss", total_loss, iters+1) 121 | train_logger.add_scalar("Vocab Loss", loss_vocab, iters+1) 122 | train_logger.add_scalar("Token Loss", loss_token, iters+1) 123 | running_loss = 0 124 | 125 | if (iters+1) % save_interval == 0: 126 | torch.save( 127 | { 128 | "model": bert.state_dict(), 129 | "step": iters, 130 | "optimizer": optimizer.state_dict(), 131 | }, 132 | os.path.join(log_dir, "{}.pth.tar".format(iters+1)), 133 | ) 134 | 135 | if curr_steps > num_steps: 136 | print(f"epoch: {epoch}") 137 | return 138 | 139 | epoch += 1 140 | print(f"epoch: {epoch}") 141 | 142 | 143 | if __name__ == '__main__': 144 | config_path = "Configs/config.yml" # you can change it to anything else 145 | config = yaml.safe_load(open(config_path)) 146 | 147 | with open(config['dataset_params']['token_maps'], 'rb') as handle: 148 | token_maps = pickle.load(handle) 149 | 150 | tokenizer = BertJapaneseTokenizer.from_pretrained(config['dataset_params']['tokenizer']) 151 | 152 | criterion = nn.CrossEntropyLoss().to(device) 153 | 154 | best_loss = float('inf') # best test loss 155 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 156 | loss_train_record = list([]) 157 | loss_test_record = list([]) 158 | 159 | num_steps = config['num_steps'] 160 | log_interval = config['log_interval'] 161 | save_interval = config['save_interval'] 162 | 163 | train() -------------------------------------------------------------------------------- /PL_BERT_ja/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import torch 4 | 5 | def scan_checkpoint(cp_dir): 6 | pattern = os.path.join(cp_dir) 7 | cp_list = glob.glob(pattern) 8 | print(cp_list) 9 | if len(cp_list) == 0: 10 | return None 11 | return sorted(cp_list)[-1] 12 | 13 | def length_to_mask(lengths): 14 | mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) 15 | mask = torch.gt(mask+1, lengths.unsqueeze(1)) 16 | return mask -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Phoneme-Level BERT-VITS2 (48000Hz 日本語版) 2 | 3 | このリポジトリは、 48000Hzの日本語音声を学習および出力できるように編集した[VITS2](https://github.com/daniilrobnikov/vits2)に、 4 | [Phoneme-Level Japanese BERT](https://github.com/yl4579/PL-BERT)の中間潜在表現を用いた音声合成モデルです。 5 | 6 | ## 1. 環境構築 7 | 8 | Anacondaによる実行環境構築を想定する。 9 | 10 | 0. Anacondaで"PLBERTVITS2"という名前の仮想環境を作成する。[y]or nを聞かれたら[y]を入力する。 11 | ```sh 12 | conda create -n PLBERTVITS2 python=3.8 13 | ``` 14 | 0. 仮想環境を有効化する。 15 | ```sh 16 | conda activate PLBERTVITS2 17 | ``` 18 | 0. このレポジトリをクローンする(もしくはDownload Zipでダウンロードする) 19 | 20 | ```sh 21 | git clone https://github.com/tonnetonne814/PL-Bert-VITS2.git 22 | cd PL-Bert-VITS2 # フォルダへ移動 23 | ``` 24 | 25 | 0. [https://pytorch.org/](https://pytorch.org/)のURLよりPyTorchをインストールする。 26 | 27 | ```sh 28 | pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 # cuda11.7 linuxの例 29 | ``` 30 | 31 | 0. その他、必要なパッケージをインストールする。 32 | ```sh 33 | pip install -r requirements.txt 34 | ``` 35 | 0. Monotonoic Alignment Searchをビルドする。 36 | ```sh 37 | cd monotonic_align 38 | mkdir monotonic_align 39 | python setup.py build_ext --inplace 40 | cd .. 41 | ``` 42 | 1. [PL-BERT-ja](https://github.com/kyamauchi1023/PL-BERT-ja?tab=readme-ov-file)より、日本語版のPhoneme−Level Bertの事前学習モデルをダウンロード及び展開する。 43 | 44 | ## 2. 
44 | ## 2. Preparing the dataset 45 | 46 | Training and generation of 48000Hz audio with the [JVNV Speech dataset](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvnv_corpus?authuser=0) is assumed. 47 | 48 | 1. Download and extract the [JVNV Speech dataset](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvnv_corpus?authuser=0). 49 | 50 | 1. Specify the jvnv_ver1 folder from the extracted dataset and the plb-ja_10000000-steps folder, then run the following. 51 | ```sh 52 | python3 ./preprocess_ja.py --jvnv_dir ./path/to/jvnv_ver1/ --pl_bert_dir ./path/to/plb-ja_10000000-steps 53 | ``` 54 | 55 | 56 | ## 3. Edit the json in the [configs](configs) folder 57 | The main parameters are described below. Edit them if necessary. 58 | | Category | Parameter | Description | 59 | |:-----:|:-----------------:|:---------------------------------------------------------:| 60 | | train | log_interval | Compute and log the losses every N steps | 61 | | train | eval_interval | Evaluate the model every N steps | 62 | | train | epochs | Number of passes over the whole training dataset | 63 | | train | batch_size | Number of training samples used per parameter update | 64 | | data | training_files | Path to the training filelist text file | 65 | | data | validation_files | Path to the validation filelist text file | 66 | 67 | 68 | ## 4. Training 69 | Start training with the following command. 70 | > ⚠ If a CUDA Out of Memory error occurs, reduce batch_size in config.json. 71 | 72 | ```sh 73 | python train_ms.py --config configs/jvnv_base.json -m JVNV_Dataset 74 | ``` 75 | 76 | 77 | Training progress is also printed to the terminal, but checking it with tensorboard lets you listen to the generated audio and visually inspect the spectrograms and each loss curve. 78 | ```sh 79 | tensorboard --logdir logs 80 | ``` 81 | 82 | ## 5. Inference 83 | Start inference with the following command, specifying the path to config.json, the generator checkpoint path, and the PL-BERT-ja checkpoints folder. 84 | ```sh 85 | python3 inference.py --model_ckpt_path ./path/to/ckpt.pth --model_cnfg_path ./path/to/config.json --pl_bert_dir /path/to/plb-ja_10000000-steps 86 | ``` 87 | After selecting the device to use in the terminal, enter text and the audio is generated. The audio is played back automatically and saved to the infer_logs folder (created automatically if it does not exist). 88 | 89 | ## Pretrained models 90 | - To be added later. 91 | 92 | 93 | ## References 94 | - https://github.com/kyamauchi1023/PL-BERT-ja 95 | - https://github.com/fishaudio/Bert-VITS2 96 | - https://github.com/yl4579/PL-BERT 97 | - https://github.com/daniilrobnikov/vits2 98 | -------------------------------------------------------------------------------- /attentions.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn.utils import remove_weight_norm, weight_norm 8 | 9 | import commons 10 | import modules 11 | from modules import LayerNorm 12 | 13 | 14 | class Encoder(nn.Module): # backward compatible vits2 encoder 15 | def __init__( 16 | self, 17 | hidden_channels, 18 | filter_channels, 19 | n_heads, 20 | n_layers, 21 | kernel_size=1, 22 | p_dropout=0.0, 23 | window_size=4, 24 | **kwargs 25 | ): 26 | super().__init__() 27 | self.hidden_channels = hidden_channels 28 | self.filter_channels = filter_channels 29 | self.n_heads = n_heads 30 | self.n_layers = n_layers 31 | self.kernel_size = kernel_size 32 | self.p_dropout = p_dropout 33 | self.window_size = window_size 34 | 35 | self.drop = nn.Dropout(p_dropout) 36 | self.attn_layers = nn.ModuleList() 37 | self.norm_layers_1 = nn.ModuleList() 38 | self.ffn_layers = nn.ModuleList() 39 | self.norm_layers_2 = nn.ModuleList() 40 | # if kwargs has spk_emb_dim, then add a linear layer to project spk_emb_dim to hidden_channels 41 | self.cond_layer_idx = self.n_layers 42 | if "gin_channels" in kwargs: 43 | self.gin_channels = kwargs["gin_channels"] 44 | if self.gin_channels != 0: 45 | self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels) 46 | # vits2 
says 3rd block, so idx is 2 by default 47 | self.cond_layer_idx = ( 48 | kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2 49 | ) 50 | assert ( 51 | self.cond_layer_idx < self.n_layers 52 | ), "cond_layer_idx should be less than n_layers" 53 | 54 | for i in range(self.n_layers): 55 | self.attn_layers.append( 56 | MultiHeadAttention( 57 | hidden_channels, 58 | hidden_channels, 59 | n_heads, 60 | p_dropout=p_dropout, 61 | window_size=window_size, 62 | ) 63 | ) 64 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 65 | self.ffn_layers.append( 66 | FFN( 67 | hidden_channels, 68 | hidden_channels, 69 | filter_channels, 70 | kernel_size, 71 | p_dropout=p_dropout, 72 | ) 73 | ) 74 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 75 | 76 | def forward(self, x, x_mask, g=None): 77 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 78 | x = x * x_mask 79 | for i in range(self.n_layers): 80 | if i == self.cond_layer_idx and g is not None: 81 | g = self.spk_emb_linear(g.transpose(1, 2)) 82 | g = g.transpose(1, 2) 83 | x = x + g 84 | x = x * x_mask 85 | y = self.attn_layers[i](x, x, attn_mask) 86 | y = self.drop(y) 87 | x = self.norm_layers_1[i](x + y) 88 | 89 | y = self.ffn_layers[i](x, x_mask) 90 | y = self.drop(y) 91 | x = self.norm_layers_2[i](x + y) 92 | x = x * x_mask 93 | return x 94 | 95 | 96 | class Decoder(nn.Module): 97 | def __init__( 98 | self, 99 | hidden_channels, 100 | filter_channels, 101 | n_heads, 102 | n_layers, 103 | kernel_size=1, 104 | p_dropout=0.0, 105 | proximal_bias=False, 106 | proximal_init=True, 107 | **kwargs 108 | ): 109 | super().__init__() 110 | self.hidden_channels = hidden_channels 111 | self.filter_channels = filter_channels 112 | self.n_heads = n_heads 113 | self.n_layers = n_layers 114 | self.kernel_size = kernel_size 115 | self.p_dropout = p_dropout 116 | self.proximal_bias = proximal_bias 117 | self.proximal_init = proximal_init 118 | 119 | self.drop = nn.Dropout(p_dropout) 120 | self.self_attn_layers = nn.ModuleList() 121 | self.norm_layers_0 = nn.ModuleList() 122 | self.encdec_attn_layers = nn.ModuleList() 123 | self.norm_layers_1 = nn.ModuleList() 124 | self.ffn_layers = nn.ModuleList() 125 | self.norm_layers_2 = nn.ModuleList() 126 | for i in range(self.n_layers): 127 | self.self_attn_layers.append( 128 | MultiHeadAttention( 129 | hidden_channels, 130 | hidden_channels, 131 | n_heads, 132 | p_dropout=p_dropout, 133 | proximal_bias=proximal_bias, 134 | proximal_init=proximal_init, 135 | ) 136 | ) 137 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 138 | self.encdec_attn_layers.append( 139 | MultiHeadAttention( 140 | hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout 141 | ) 142 | ) 143 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 144 | self.ffn_layers.append( 145 | FFN( 146 | hidden_channels, 147 | hidden_channels, 148 | filter_channels, 149 | kernel_size, 150 | p_dropout=p_dropout, 151 | causal=True, 152 | ) 153 | ) 154 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 155 | 156 | def forward(self, x, x_mask, h, h_mask): 157 | """ 158 | x: decoder input 159 | h: encoder output 160 | """ 161 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( 162 | device=x.device, dtype=x.dtype 163 | ) 164 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 165 | x = x * x_mask 166 | for i in range(self.n_layers): 167 | y = self.self_attn_layers[i](x, x, self_attn_mask) 168 | y = self.drop(y) 169 | x = self.norm_layers_0[i](x + y) 170 | 171 | y = self.encdec_attn_layers[i](x, h, 
encdec_attn_mask) 172 | y = self.drop(y) 173 | x = self.norm_layers_1[i](x + y) 174 | 175 | y = self.ffn_layers[i](x, x_mask) 176 | y = self.drop(y) 177 | x = self.norm_layers_2[i](x + y) 178 | x = x * x_mask 179 | return x 180 | 181 | 182 | class MultiHeadAttention(nn.Module): 183 | def __init__( 184 | self, 185 | channels, 186 | out_channels, 187 | n_heads, 188 | p_dropout=0.0, 189 | window_size=None, 190 | heads_share=True, 191 | block_length=None, 192 | proximal_bias=False, 193 | proximal_init=False, 194 | ): 195 | super().__init__() 196 | assert channels % n_heads == 0 197 | 198 | self.channels = channels 199 | self.out_channels = out_channels 200 | self.n_heads = n_heads 201 | self.p_dropout = p_dropout 202 | self.window_size = window_size 203 | self.heads_share = heads_share 204 | self.block_length = block_length 205 | self.proximal_bias = proximal_bias 206 | self.proximal_init = proximal_init 207 | self.attn = None 208 | 209 | self.k_channels = channels // n_heads 210 | self.conv_q = nn.Conv1d(channels, channels, 1) 211 | self.conv_k = nn.Conv1d(channels, channels, 1) 212 | self.conv_v = nn.Conv1d(channels, channels, 1) 213 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 214 | self.drop = nn.Dropout(p_dropout) 215 | 216 | if window_size is not None: 217 | n_heads_rel = 1 if heads_share else n_heads 218 | rel_stddev = self.k_channels**-0.5 219 | self.emb_rel_k = nn.Parameter( 220 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 221 | * rel_stddev 222 | ) 223 | self.emb_rel_v = nn.Parameter( 224 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 225 | * rel_stddev 226 | ) 227 | 228 | nn.init.xavier_uniform_(self.conv_q.weight) 229 | nn.init.xavier_uniform_(self.conv_k.weight) 230 | nn.init.xavier_uniform_(self.conv_v.weight) 231 | if proximal_init: 232 | with torch.no_grad(): 233 | self.conv_k.weight.copy_(self.conv_q.weight) 234 | self.conv_k.bias.copy_(self.conv_q.bias) 235 | 236 | def forward(self, x, c, attn_mask=None): 237 | q = self.conv_q(x) 238 | k = self.conv_k(c) 239 | v = self.conv_v(c) 240 | 241 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 242 | 243 | x = self.conv_o(x) 244 | return x 245 | 246 | def attention(self, query, key, value, mask=None): 247 | # reshape [b, d, t] -> [b, n_h, t, d_k] 248 | b, d, t_s, t_t = (*key.size(), query.size(2)) 249 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 250 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 251 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 252 | 253 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 254 | if self.window_size is not None: 255 | assert ( 256 | t_s == t_t 257 | ), "Relative attention is only available for self-attention." 258 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 259 | rel_logits = self._matmul_with_relative_keys( 260 | query / math.sqrt(self.k_channels), key_relative_embeddings 261 | ) 262 | scores_local = self._relative_position_to_absolute_position(rel_logits) 263 | scores = scores + scores_local 264 | if self.proximal_bias: 265 | assert t_s == t_t, "Proximal bias is only available for self-attention." 
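# NOTE: the proximal bias added below is -log(1 + |i - j|) for query position i
# and key position j, which softly encourages each frame to attend to nearby
# frames. For a length-3 sequence the bias matrix is approximately:
#   [[ 0.000, -0.693, -1.099],
#    [-0.693,  0.000, -0.693],
#    [-1.099, -0.693,  0.000]]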
266 | scores = scores + self._attention_bias_proximal(t_s).to( 267 | device=scores.device, dtype=scores.dtype 268 | ) 269 | if mask is not None: 270 | scores = scores.masked_fill(mask == 0, -1e4) 271 | if self.block_length is not None: 272 | assert ( 273 | t_s == t_t 274 | ), "Local attention is only available for self-attention." 275 | block_mask = ( 276 | torch.ones_like(scores) 277 | .triu(-self.block_length) 278 | .tril(self.block_length) 279 | ) 280 | scores = scores.masked_fill(block_mask == 0, -1e4) 281 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 282 | p_attn = self.drop(p_attn) 283 | output = torch.matmul(p_attn, value) 284 | if self.window_size is not None: 285 | relative_weights = self._absolute_position_to_relative_position(p_attn) 286 | value_relative_embeddings = self._get_relative_embeddings( 287 | self.emb_rel_v, t_s 288 | ) 289 | output = output + self._matmul_with_relative_values( 290 | relative_weights, value_relative_embeddings 291 | ) 292 | output = ( 293 | output.transpose(2, 3).contiguous().view(b, d, t_t) 294 | ) # [b, n_h, t_t, d_k] -> [b, d, t_t] 295 | return output, p_attn 296 | 297 | def _matmul_with_relative_values(self, x, y): 298 | """ 299 | x: [b, h, l, m] 300 | y: [h or 1, m, d] 301 | ret: [b, h, l, d] 302 | """ 303 | ret = torch.matmul(x, y.unsqueeze(0)) 304 | return ret 305 | 306 | def _matmul_with_relative_keys(self, x, y): 307 | """ 308 | x: [b, h, l, d] 309 | y: [h or 1, m, d] 310 | ret: [b, h, l, m] 311 | """ 312 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 313 | return ret 314 | 315 | def _get_relative_embeddings(self, relative_embeddings, length): 316 | max_relative_position = 2 * self.window_size + 1 317 | # Pad first before slice to avoid using cond ops. 318 | pad_length = max(length - (self.window_size + 1), 0) 319 | slice_start_position = max((self.window_size + 1) - length, 0) 320 | slice_end_position = slice_start_position + 2 * length - 1 321 | if pad_length > 0: 322 | padded_relative_embeddings = F.pad( 323 | relative_embeddings, 324 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), 325 | ) 326 | else: 327 | padded_relative_embeddings = relative_embeddings 328 | used_relative_embeddings = padded_relative_embeddings[ 329 | :, slice_start_position:slice_end_position 330 | ] 331 | return used_relative_embeddings 332 | 333 | def _relative_position_to_absolute_position(self, x): 334 | """ 335 | x: [b, h, l, 2*l-1] 336 | ret: [b, h, l, l] 337 | """ 338 | batch, heads, length, _ = x.size() 339 | # Concat columns of pad to shift from relative to absolute indexing. 340 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) 341 | 342 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 343 | x_flat = x.view([batch, heads, length * 2 * length]) 344 | x_flat = F.pad( 345 | x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) 346 | ) 347 | 348 | # Reshape and slice out the padded elements. 
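# NOTE: after the two pads above, the flattened tensor can be viewed as
# [b, h, l+1, 2*l-1]; keeping the first l "rows" and columns l-1 onward turns
# relative offsets into absolute key indices. For l = 2 (offsets -1, 0, +1),
# query 0 keeps (offset 0 -> key 0, offset +1 -> key 1) and
# query 1 keeps (offset -1 -> key 0, offset 0 -> key 1).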
349 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ 350 | :, :, :length, length - 1 : 351 | ] 352 | return x_final 353 | 354 | def _absolute_position_to_relative_position(self, x): 355 | """ 356 | x: [b, h, l, l] 357 | ret: [b, h, l, 2*l-1] 358 | """ 359 | batch, heads, length, _ = x.size() 360 | # padd along column 361 | x = F.pad( 362 | x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) 363 | ) 364 | x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) 365 | # add 0's in the beginning that will skew the elements after reshape 366 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 367 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] 368 | return x_final 369 | 370 | def _attention_bias_proximal(self, length): 371 | """Bias for self-attention to encourage attention to close positions. 372 | Args: 373 | length: an integer scalar. 374 | Returns: 375 | a Tensor with shape [1, 1, length, length] 376 | """ 377 | r = torch.arange(length, dtype=torch.float32) 378 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 379 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 380 | 381 | 382 | class FFN(nn.Module): 383 | def __init__( 384 | self, 385 | in_channels, 386 | out_channels, 387 | filter_channels, 388 | kernel_size, 389 | p_dropout=0.0, 390 | activation=None, 391 | causal=False, 392 | ): 393 | super().__init__() 394 | self.in_channels = in_channels 395 | self.out_channels = out_channels 396 | self.filter_channels = filter_channels 397 | self.kernel_size = kernel_size 398 | self.p_dropout = p_dropout 399 | self.activation = activation 400 | self.causal = causal 401 | 402 | if causal: 403 | self.padding = self._causal_padding 404 | else: 405 | self.padding = self._same_padding 406 | 407 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 408 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 409 | self.drop = nn.Dropout(p_dropout) 410 | 411 | def forward(self, x, x_mask): 412 | x = self.conv_1(self.padding(x * x_mask)) 413 | if self.activation == "gelu": 414 | x = x * torch.sigmoid(1.702 * x) 415 | else: 416 | x = torch.relu(x) 417 | x = self.drop(x) 418 | x = self.conv_2(self.padding(x * x_mask)) 419 | return x * x_mask 420 | 421 | def _causal_padding(self, x): 422 | if self.kernel_size == 1: 423 | return x 424 | pad_l = self.kernel_size - 1 425 | pad_r = 0 426 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 427 | x = F.pad(x, commons.convert_pad_shape(padding)) 428 | return x 429 | 430 | def _same_padding(self, x): 431 | if self.kernel_size == 1: 432 | return x 433 | pad_l = (self.kernel_size - 1) // 2 434 | pad_r = self.kernel_size // 2 435 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 436 | x = F.pad(x, commons.convert_pad_shape(padding)) 437 | return x 438 | 439 | 440 | class Depthwise_Separable_Conv1D(nn.Module): 441 | def __init__( 442 | self, 443 | in_channels, 444 | out_channels, 445 | kernel_size, 446 | stride=1, 447 | padding=0, 448 | dilation=1, 449 | bias=True, 450 | padding_mode="zeros", # TODO: refine this type 451 | device=None, 452 | dtype=None, 453 | ): 454 | super().__init__() 455 | self.depth_conv = nn.Conv1d( 456 | in_channels=in_channels, 457 | out_channels=in_channels, 458 | kernel_size=kernel_size, 459 | groups=in_channels, 460 | stride=stride, 461 | padding=padding, 462 | dilation=dilation, 463 | bias=bias, 464 | padding_mode=padding_mode, 465 | device=device, 466 | dtype=dtype, 467 | ) 468 | 
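# NOTE: the depthwise conv above (groups=in_channels) plus the 1x1 pointwise
# conv below factorise a dense Conv1d. Ignoring bias terms, the weight counts
# are roughly:
#   dense:     in_channels * out_channels * kernel_size
#   separable: in_channels * kernel_size + in_channels * out_channels
# e.g. 192 -> 192 channels with kernel_size=5: 184,320 vs 37,824 weights.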
self.point_conv = nn.Conv1d( 469 | in_channels=in_channels, 470 | out_channels=out_channels, 471 | kernel_size=1, 472 | bias=bias, 473 | device=device, 474 | dtype=dtype, 475 | ) 476 | 477 | def forward(self, input): 478 | return self.point_conv(self.depth_conv(input)) 479 | 480 | def weight_norm(self): 481 | self.depth_conv = weight_norm(self.depth_conv, name="weight") 482 | self.point_conv = weight_norm(self.point_conv, name="weight") 483 | 484 | def remove_weight_norm(self): 485 | self.depth_conv = remove_weight_norm(self.depth_conv, name="weight") 486 | self.point_conv = remove_weight_norm(self.point_conv, name="weight") 487 | 488 | 489 | class Depthwise_Separable_TransposeConv1D(nn.Module): 490 | def __init__( 491 | self, 492 | in_channels, 493 | out_channels, 494 | kernel_size, 495 | stride=1, 496 | padding=0, 497 | output_padding=0, 498 | bias=True, 499 | dilation=1, 500 | padding_mode="zeros", # TODO: refine this type 501 | device=None, 502 | dtype=None, 503 | ): 504 | super().__init__() 505 | self.depth_conv = nn.ConvTranspose1d( 506 | in_channels=in_channels, 507 | out_channels=in_channels, 508 | kernel_size=kernel_size, 509 | groups=in_channels, 510 | stride=stride, 511 | output_padding=output_padding, 512 | padding=padding, 513 | dilation=dilation, 514 | bias=bias, 515 | padding_mode=padding_mode, 516 | device=device, 517 | dtype=dtype, 518 | ) 519 | self.point_conv = nn.Conv1d( 520 | in_channels=in_channels, 521 | out_channels=out_channels, 522 | kernel_size=1, 523 | bias=bias, 524 | device=device, 525 | dtype=dtype, 526 | ) 527 | 528 | def forward(self, input): 529 | return self.point_conv(self.depth_conv(input)) 530 | 531 | def weight_norm(self): 532 | self.depth_conv = weight_norm(self.depth_conv, name="weight") 533 | self.point_conv = weight_norm(self.point_conv, name="weight") 534 | 535 | def remove_weight_norm(self): 536 | remove_weight_norm(self.depth_conv, name="weight") 537 | remove_weight_norm(self.point_conv, name="weight") 538 | 539 | 540 | def weight_norm_modules(module, name="weight", dim=0): 541 | if isinstance(module, Depthwise_Separable_Conv1D) or isinstance( 542 | module, Depthwise_Separable_TransposeConv1D 543 | ): 544 | module.weight_norm() 545 | return module 546 | else: 547 | return weight_norm(module, name, dim) 548 | 549 | 550 | def remove_weight_norm_modules(module, name="weight"): 551 | if isinstance(module, Depthwise_Separable_Conv1D) or isinstance( 552 | module, Depthwise_Separable_TransposeConv1D 553 | ): 554 | module.remove_weight_norm() 555 | else: 556 | remove_weight_norm(module, name) 557 | 558 | 559 | class FFT(nn.Module): 560 | def __init__( 561 | self, 562 | hidden_channels, 563 | filter_channels, 564 | n_heads, 565 | n_layers=1, 566 | kernel_size=1, 567 | p_dropout=0.0, 568 | proximal_bias=False, 569 | proximal_init=True, 570 | isflow=False, 571 | **kwargs 572 | ): 573 | super().__init__() 574 | self.hidden_channels = hidden_channels 575 | self.filter_channels = filter_channels 576 | self.n_heads = n_heads 577 | self.n_layers = n_layers 578 | self.kernel_size = kernel_size 579 | self.p_dropout = p_dropout 580 | self.proximal_bias = proximal_bias 581 | self.proximal_init = proximal_init 582 | if isflow and "gin_channels" in kwargs and kwargs["gin_channels"] > 0: 583 | cond_layer = torch.nn.Conv1d( 584 | kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1 585 | ) 586 | self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1) 587 | self.cond_layer = weight_norm_modules(cond_layer, name="weight") 588 | self.gin_channels = 
kwargs["gin_channels"] 589 | self.drop = nn.Dropout(p_dropout) 590 | self.self_attn_layers = nn.ModuleList() 591 | self.norm_layers_0 = nn.ModuleList() 592 | self.ffn_layers = nn.ModuleList() 593 | self.norm_layers_1 = nn.ModuleList() 594 | for i in range(self.n_layers): 595 | self.self_attn_layers.append( 596 | MultiHeadAttention( 597 | hidden_channels, 598 | hidden_channels, 599 | n_heads, 600 | p_dropout=p_dropout, 601 | proximal_bias=proximal_bias, 602 | proximal_init=proximal_init, 603 | ) 604 | ) 605 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 606 | self.ffn_layers.append( 607 | FFN( 608 | hidden_channels, 609 | hidden_channels, 610 | filter_channels, 611 | kernel_size, 612 | p_dropout=p_dropout, 613 | causal=True, 614 | ) 615 | ) 616 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 617 | 618 | def forward(self, x, x_mask, g=None): 619 | """ 620 | x: decoder input 621 | h: encoder output 622 | """ 623 | if g is not None: 624 | g = self.cond_layer(g) 625 | 626 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( 627 | device=x.device, dtype=x.dtype 628 | ) 629 | x = x * x_mask 630 | for i in range(self.n_layers): 631 | if g is not None: 632 | x = self.cond_pre(x) 633 | cond_offset = i * 2 * self.hidden_channels 634 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] 635 | x = commons.fused_add_tanh_sigmoid_multiply( 636 | x, g_l, torch.IntTensor([self.hidden_channels]) 637 | ) 638 | y = self.self_attn_layers[i](x, x, self_attn_mask) 639 | y = self.drop(y) 640 | x = self.norm_layers_0[i](x + y) 641 | 642 | y = self.ffn_layers[i](x, x_mask) 643 | y = self.drop(y) 644 | x = self.norm_layers_1[i](x + y) 645 | x = x * x_mask 646 | return x 647 | -------------------------------------------------------------------------------- /commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size * dilation - dilation) / 2) 16 | 17 | 18 | def convert_pad_shape(pad_shape): 19 | l = pad_shape[::-1] 20 | pad_shape = [item for sublist in l for item in sublist] 21 | return pad_shape 22 | 23 | 24 | def intersperse(lst, item): 25 | result = [item] * (len(lst) * 2 + 1) 26 | result[1::2] = lst 27 | return result 28 | 29 | 30 | def kl_divergence(m_p, logs_p, m_q, logs_q): 31 | """KL(P||Q)""" 32 | kl = (logs_q - logs_p) - 0.5 33 | kl += ( 34 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 35 | ) 36 | return kl 37 | 38 | 39 | def rand_gumbel(shape): 40 | """Sample from the Gumbel distribution, protect from overflows.""" 41 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 42 | return -torch.log(-torch.log(uniform_samples)) 43 | 44 | 45 | def rand_gumbel_like(x): 46 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 47 | return g 48 | 49 | 50 | def slice_segments(x, ids_str, segment_size=4): 51 | ret = torch.zeros_like(x[:, :, :segment_size]) 52 | for i in range(x.size(0)): 53 | idx_str = ids_str[i] 54 | idx_end = idx_str + segment_size 55 | ret[i] = x[i, :, idx_str:idx_end] 56 | return ret 57 | 58 | 59 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 60 | b, d, t = x.size() 61 | if x_lengths is None: 
62 | x_lengths = t 63 | ids_str_max = x_lengths - segment_size + 1 64 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 65 | ret = slice_segments(x, ids_str, segment_size) 66 | return ret, ids_str 67 | 68 | 69 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): 70 | position = torch.arange(length, dtype=torch.float) 71 | num_timescales = channels // 2 72 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( 73 | num_timescales - 1 74 | ) 75 | inv_timescales = min_timescale * torch.exp( 76 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment 77 | ) 78 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 79 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 80 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 81 | signal = signal.view(1, channels, length) 82 | return signal 83 | 84 | 85 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 86 | b, channels, length = x.size() 87 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 88 | return x + signal.to(dtype=x.dtype, device=x.device) 89 | 90 | 91 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 92 | b, channels, length = x.size() 93 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 94 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 95 | 96 | 97 | def subsequent_mask(length): 98 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 99 | return mask 100 | 101 | 102 | @torch.jit.script 103 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 104 | n_channels_int = n_channels[0] 105 | in_act = input_a + input_b 106 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 107 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 108 | acts = t_act * s_act 109 | return acts 110 | 111 | 112 | def convert_pad_shape(pad_shape): 113 | l = pad_shape[::-1] 114 | pad_shape = [item for sublist in l for item in sublist] 115 | return pad_shape 116 | 117 | 118 | def shift_1d(x): 119 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 120 | return x 121 | 122 | 123 | def sequence_mask(length, max_length=None): 124 | if max_length is None: 125 | max_length = length.max() 126 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 127 | return x.unsqueeze(0) < length.unsqueeze(1) 128 | 129 | 130 | def generate_path(duration, mask): 131 | """ 132 | duration: [b, 1, t_x] 133 | mask: [b, 1, t_y, t_x] 134 | """ 135 | device = duration.device 136 | 137 | b, _, t_y, t_x = mask.shape 138 | cum_duration = torch.cumsum(duration, -1) 139 | 140 | cum_duration_flat = cum_duration.view(b * t_x) 141 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 142 | path = path.view(b, t_x, t_y) 143 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 144 | path = path.unsqueeze(1).transpose(2, 3) * mask 145 | return path 146 | 147 | 148 | def clip_grad_value_(parameters, clip_value, norm_type=2): 149 | if isinstance(parameters, torch.Tensor): 150 | parameters = [parameters] 151 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 152 | norm_type = float(norm_type) 153 | if clip_value is not None: 154 | clip_value = float(clip_value) 155 | 156 | total_norm = 0 157 | for p in parameters: 158 | param_norm = p.grad.data.norm(norm_type) 159 | total_norm += param_norm.item() ** 
norm_type 160 | if clip_value is not None: 161 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 162 | total_norm = total_norm ** (1.0 / norm_type) 163 | return total_norm 164 | -------------------------------------------------------------------------------- /configs/jvnv_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 32, 11 | "fp16_run": true, 12 | "lr_decay": 0.999875, 13 | "segment_size": 16320, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "use_mel_posterior_encoder": false, 21 | "training_files":"filelists/jvnv_ver1_train.txt", 22 | "validation_files":"filelists/jvnv_ver1_val.txt", 23 | "text_cleaners":[], 24 | "max_wav_value": 1.0, 25 | "sampling_rate": 48000, 26 | "filter_length": 2048, 27 | "hop_length": 480, 28 | "win_length": 2048, 29 | "n_mel_channels": 128, 30 | "mel_fmin": 0.0, 31 | "mel_fmax": null, 32 | "add_blank": false, 33 | "n_speakers": 4, 34 | "cleaned_text": true 35 | }, 36 | "model": { 37 | "use_mel_posterior_encoder": true, 38 | "use_transformer_flows": true, 39 | "transformer_flow_type": "fft", 40 | "use_spk_conditioned_encoder": true, 41 | "use_noise_scaled_mas": true, 42 | "use_duration_discriminator": true, 43 | "duration_discriminator_type": "dur_disc_2", 44 | "inter_channels": 192, 45 | "hidden_channels": 192, 46 | "filter_channels": 768, 47 | "bert_emb_size" : 768, 48 | "n_heads": 2, 49 | "n_layers": 6, 50 | "kernel_size": 3, 51 | "p_dropout": 0.1, 52 | "resblock": "1", 53 | "resblock_kernel_sizes": [3,7,11], 54 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 55 | "upsample_rates": [12,10,2,2], 56 | "upsample_initial_channel": 512, 57 | "upsample_kernel_sizes": [24,20,4,4], 58 | "n_layers_q": 3, 59 | "use_spectral_norm": false, 60 | "use_sdp": true, 61 | "gin_channels": 256 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | import torch.utils.data 8 | 9 | import commons 10 | from mel_processing import (mel_spectrogram_torch, spec_to_mel_torch, 11 | spectrogram_torch) 12 | from text import cleaned_text_to_sequence, text_to_sequence 13 | from utils import load_filepaths_and_text, load_wav_to_torch 14 | 15 | """Multi speaker version""" 16 | 17 | 18 | class TextAudioSpeakerLoader(torch.utils.data.Dataset): 19 | """ 20 | 1) loads audio, speaker_id, text pairs 21 | 2) normalizes text and converts them to sequences of integers 22 | 3) computes spectrograms from audio files. 
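4) each line of the filelist is expected to unpack into
   (audiopath, sid, ph, text, bert, emo, style); see _filter() below.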
23 | """ 24 | 25 | def __init__(self, audiopaths_sid_text, hparams): 26 | self.hparams = hparams 27 | self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text) 28 | self.text_cleaners = hparams.text_cleaners 29 | self.max_wav_value = hparams.max_wav_value 30 | self.sampling_rate = hparams.sampling_rate 31 | self.filter_length = hparams.filter_length 32 | self.hop_length = hparams.hop_length 33 | self.win_length = hparams.win_length 34 | self.sampling_rate = hparams.sampling_rate 35 | 36 | self.use_mel_spec_posterior = getattr( 37 | hparams, "use_mel_posterior_encoder", False 38 | ) 39 | if self.use_mel_spec_posterior: 40 | self.n_mel_channels = getattr(hparams, "n_mel_channels", 80) 41 | self.cleaned_text = getattr(hparams, "cleaned_text", False) 42 | 43 | self.add_blank = hparams.add_blank 44 | self.min_text_len = getattr(hparams, "min_text_len", 1) 45 | self.max_text_len = getattr(hparams, "max_text_len", 999) 46 | self.min_audio_len = getattr(hparams, "min_audio_len", 8192) 47 | 48 | random.seed(1234) 49 | random.shuffle(self.audiopaths_sid_text) 50 | self._filter() 51 | 52 | self.count = 0 53 | 54 | def _filter(self): 55 | """ 56 | Filter text & store spec lengths 57 | """ 58 | # Store spectrogram lengths for Bucketing 59 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) 60 | # spec_length = wav_length // hop_length 61 | 62 | audiopaths_sid_text_new = [] 63 | lengths = [] 64 | for data in self.audiopaths_sid_text: 65 | audiopath, sid, ph, text, bert, emo, style = data 66 | if not os.path.isfile(audiopath): 67 | continue 68 | if self.min_text_len <= len(text) and len(text) <= self.max_text_len: 69 | audiopaths_sid_text_new.append([audiopath, sid, ph, text, bert, emo, style]) 70 | length = os.path.getsize(audiopath) // (2 * self.hop_length) 71 | if length < self.min_audio_len // self.hop_length: 72 | print("DATA PASS") 73 | continue 74 | lengths.append(length) 75 | self.audiopaths_sid_text = audiopaths_sid_text_new 76 | self.lengths = lengths 77 | print(f"INFO:{len(self.audiopaths_sid_text)} is used as Training Dataset.") 78 | 79 | def get_audio_text_speaker_pair(self, audiopath_sid_text): 80 | # separate filename, speaker_id and text 81 | audiopath, sid, ph, text, pl_bert, emo, style = ( 82 | audiopath_sid_text[0], 83 | audiopath_sid_text[1], 84 | audiopath_sid_text[2], 85 | audiopath_sid_text[3], 86 | audiopath_sid_text[4], 87 | audiopath_sid_text[5], 88 | audiopath_sid_text[6], 89 | ) 90 | ph = self.get_text(ph) 91 | spec, wav = self.get_audio(audiopath) 92 | bert = self.get_pl_bert(pl_bert) 93 | sid = self.get_sid(sid) 94 | 95 | # parameter checker 96 | assert len(ph) == bert.size(1) 97 | 98 | return (ph, spec, wav, sid, bert) 99 | 100 | def get_pl_bert(self, filename): 101 | path = os.path.join("pl_bert_embeddings", f"{filename}.PlBertJa") 102 | data = torch.load(path) 103 | if self.add_blank: 104 | L, T, H = data.shape 105 | new_data = torch.zeros(size=(L,2*T+1,H), dtype=data.dtype) 106 | for idx in range(T): 107 | target_idx = idx*2+1 108 | new_data[:, target_idx, :] = data[:, idx, :] 109 | data = new_data 110 | return data 111 | 112 | def get_audio(self, filename): 113 | # TODO : if linear spec exists convert to mel from existing linear spec 114 | audio, sampling_rate = load_wav_to_torch(filename) 115 | if sampling_rate != self.sampling_rate: 116 | raise ValueError( 117 | "{} {} SR doesn't match target {} SR".format( 118 | sampling_rate, self.sampling_rate 119 | ) 120 | ) 121 | # audio_norm = audio / self.max_wav_value 122 | 
audio_norm = audio.unsqueeze(0) 123 | spec_filename = filename.replace(".wav", ".spec.pt") 124 | if self.use_mel_spec_posterior: 125 | spec_filename = spec_filename.replace(".spec.pt", ".mel.pt") 126 | if os.path.exists(spec_filename): 127 | spec = torch.load(spec_filename) 128 | else: 129 | if self.use_mel_spec_posterior: 130 | """TODO : (need verification) 131 | if linear spec exists convert to 132 | mel from existing linear spec (uncomment below lines)""" 133 | # if os.path.exists(filename.replace(".wav", ".spec.pt")): 134 | # # spec, n_fft, num_mels, sampling_rate, fmin, fmax 135 | # spec = spec_to_mel_torch( 136 | # torch.load(filename.replace(".wav", ".spec.pt")), 137 | # self.filter_length, self.n_mel_channels, self.sampling_rate, 138 | # self.hparams.mel_fmin, self.hparams.mel_fmax) 139 | spec = mel_spectrogram_torch( 140 | audio_norm, 141 | self.filter_length, 142 | self.n_mel_channels, 143 | self.sampling_rate, 144 | self.hop_length, 145 | self.win_length, 146 | self.hparams.mel_fmin, 147 | self.hparams.mel_fmax, 148 | center=False, 149 | ) 150 | else: 151 | spec = spectrogram_torch( 152 | audio_norm, 153 | self.filter_length, 154 | self.sampling_rate, 155 | self.hop_length, 156 | self.win_length, 157 | center=False, 158 | ) 159 | spec = torch.squeeze(spec, 0) 160 | torch.save(spec, spec_filename) 161 | return spec, audio_norm 162 | 163 | def get_text(self, text): 164 | if self.cleaned_text: 165 | text_norm = cleaned_text_to_sequence(text) 166 | else: 167 | text_norm = text_to_sequence(text, self.text_cleaners) 168 | if self.add_blank: 169 | text_norm = commons.intersperse(text_norm, 0) 170 | text_norm = torch.LongTensor(text_norm) 171 | return text_norm 172 | 173 | def get_sid(self, sid): 174 | sid = torch.LongTensor([int(sid)]) 175 | return sid 176 | 177 | def __getitem__(self, index): 178 | return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index]) 179 | 180 | def __len__(self): 181 | return len(self.audiopaths_sid_text) 182 | 183 | 184 | class TextAudioSpeakerCollate: 185 | """Zero-pads model inputs and targets""" 186 | 187 | def __init__(self, return_ids=False): 188 | self.return_ids = return_ids 189 | 190 | def __call__(self, batch): 191 | """Collate's training batch from normalized text, audio and speaker identities 192 | PARAMS 193 | ------ 194 | batch: [text_normalized, spec_normalized, wav_normalized, sid] 195 | """ 196 | # Right zero-pad all one-hot text sequences to max input length 197 | _, ids_sorted_decreasing = torch.sort( 198 | torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True 199 | ) 200 | 201 | max_text_len = max([len(x[0]) for x in batch]) 202 | max_spec_len = max([x[1].size(1) for x in batch]) 203 | max_wav_len = max([x[2].size(1) for x in batch]) 204 | # sid = 1 205 | max_bert_len = max([x[4].size(1) for x in batch]) 206 | 207 | text_lengths = torch.LongTensor(len(batch)) 208 | spec_lengths = torch.LongTensor(len(batch)) 209 | wav_lengths = torch.LongTensor(len(batch)) 210 | sid = torch.LongTensor(len(batch)) 211 | bert_lengths = torch.LongTensor(len(batch)) 212 | 213 | text_padded = torch.LongTensor(len(batch), max_text_len) 214 | spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) 215 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) 216 | bert_padded = torch.FloatTensor(len(batch), 13, max_bert_len, 768) 217 | 218 | text_padded.zero_() 219 | spec_padded.zero_() 220 | wav_padded.zero_() 221 | bert_padded.zero_() 222 | for i in range(len(ids_sorted_decreasing)): 223 | row = 
batch[ids_sorted_decreasing[i]] 224 | 225 | text = row[0] 226 | text_padded[i, : text.size(0)] = text 227 | text_lengths[i] = text.size(0) 228 | 229 | spec = row[1] 230 | spec_padded[i, :, : spec.size(1)] = spec 231 | spec_lengths[i] = spec.size(1) 232 | 233 | wav = row[2] 234 | wav_padded[i, :, : wav.size(1)] = wav 235 | wav_lengths[i] = wav.size(1) 236 | 237 | sid[i] = row[3] 238 | 239 | bert = row[4] 240 | bert_padded[i, :, :bert.size(1),:] = bert 241 | bert_lengths[i] = bert.size(1) 242 | 243 | 244 | if self.return_ids: 245 | return ( 246 | text_padded, 247 | text_lengths, 248 | spec_padded, 249 | spec_lengths, 250 | wav_padded, 251 | wav_lengths, 252 | bert_padded, 253 | bert_lengths, 254 | sid, 255 | ids_sorted_decreasing, 256 | ) 257 | return ( 258 | text_padded, 259 | text_lengths, 260 | spec_padded, 261 | spec_lengths, 262 | wav_padded, 263 | wav_lengths, 264 | bert_padded, 265 | bert_lengths, 266 | sid, 267 | ) 268 | 269 | 270 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): 271 | """ 272 | Maintain similar input lengths in a batch. 273 | Length groups are specified by boundaries. 274 | Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. 275 | 276 | It removes samples which are not included in the boundaries. 277 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. 278 | """ 279 | 280 | def __init__( 281 | self, 282 | dataset, 283 | batch_size, 284 | boundaries, 285 | num_replicas=None, 286 | rank=None, 287 | shuffle=True, 288 | ): 289 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 290 | self.lengths = dataset.lengths 291 | self.batch_size = batch_size 292 | self.boundaries = boundaries 293 | 294 | self.buckets, self.num_samples_per_bucket = self._create_buckets() 295 | self.total_size = sum(self.num_samples_per_bucket) 296 | self.num_samples = self.total_size // self.num_replicas 297 | 298 | def _create_buckets(self): 299 | buckets = [[] for _ in range(len(self.boundaries) - 1)] 300 | for i in range(len(self.lengths)): 301 | length = self.lengths[i] 302 | idx_bucket = self._bisect(length) 303 | if idx_bucket != -1: 304 | buckets[idx_bucket].append(i) 305 | 306 | for i in range(len(buckets) - 1, 0, -1): 307 | if len(buckets[i]) == 0: 308 | buckets.pop(i) 309 | self.boundaries.pop(i + 1) 310 | i=0 311 | if len(buckets[i]) == 0: 312 | buckets.pop(i) 313 | self.boundaries.pop(i + 1) 314 | 315 | num_samples_per_bucket = [] 316 | for i in range(len(buckets)): 317 | len_bucket = len(buckets[i]) 318 | total_batch_size = self.num_replicas * self.batch_size 319 | rem = ( 320 | total_batch_size - (len_bucket % total_batch_size) 321 | ) % total_batch_size 322 | num_samples_per_bucket.append(len_bucket + rem) 323 | return buckets, num_samples_per_bucket 324 | 325 | def __iter__(self): 326 | # deterministically shuffle based on epoch 327 | g = torch.Generator() 328 | g.manual_seed(self.epoch) 329 | 330 | indices = [] 331 | if self.shuffle: 332 | for bucket in self.buckets: 333 | indices.append(torch.randperm(len(bucket), generator=g).tolist()) 334 | else: 335 | for bucket in self.buckets: 336 | indices.append(list(range(len(bucket)))) 337 | 338 | batches = [] 339 | for i in range(len(self.buckets)): 340 | bucket = self.buckets[i] 341 | len_bucket = len(bucket) 342 | ids_bucket = indices[i] 343 | num_samples_bucket = self.num_samples_per_bucket[i] 344 | 345 | # add extra samples to make it evenly divisible 346 | rem = 
num_samples_bucket - len_bucket 347 | ids_bucket = ( 348 | ids_bucket 349 | + ids_bucket * (rem // len_bucket) 350 | + ids_bucket[: (rem % len_bucket)] 351 | ) 352 | 353 | # subsample 354 | ids_bucket = ids_bucket[self.rank :: self.num_replicas] 355 | 356 | # batching 357 | for j in range(len(ids_bucket) // self.batch_size): 358 | batch = [ 359 | bucket[idx] 360 | for idx in ids_bucket[ 361 | j * self.batch_size : (j + 1) * self.batch_size 362 | ] 363 | ] 364 | batches.append(batch) 365 | 366 | if self.shuffle: 367 | batch_ids = torch.randperm(len(batches), generator=g).tolist() 368 | batches = [batches[i] for i in batch_ids] 369 | self.batches = batches 370 | 371 | assert len(self.batches) * self.batch_size == self.num_samples 372 | return iter(self.batches) 373 | 374 | def _bisect(self, x, lo=0, hi=None): 375 | if hi is None: 376 | hi = len(self.boundaries) - 1 377 | 378 | if hi > lo: 379 | mid = (hi + lo) // 2 380 | if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: 381 | return mid 382 | elif x <= self.boundaries[mid]: 383 | return self._bisect(x, lo, mid) 384 | else: 385 | return self._bisect(x, mid + 1, hi) 386 | else: 387 | return -1 388 | 389 | def __len__(self): 390 | return self.num_samples // self.batch_size 391 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | 2 | from models import SynthesizerTrn 3 | import argparse 4 | import utils 5 | from PL_BERT_ja.text.symbols import symbols 6 | import json 7 | from preprocess_ja import get_pl_bert_ja 8 | import torch 9 | import soundcard as sc 10 | import time 11 | import os 12 | import soundfile as sf 13 | from transformers import BertJapaneseTokenizer 14 | import torch 15 | from PL_BERT_ja.text_utils import TextCleaner 16 | from PL_BERT_ja.phonemize import phonemize 17 | import commons 18 | from text import cleaned_text_to_sequence, text_to_sequence 19 | 20 | def inference(model_ckpt_path, model_config_path, pl_bert_dir, is_save=True): 21 | with open(model_config_path, "r") as f: 22 | data = f.read() 23 | config = json.loads(data) 24 | hps = utils.HParams(**config) 25 | 26 | if hps.model.use_noise_scaled_mas is True : 27 | print("Using noise scaled MAS for VITS2") 28 | use_noise_scaled_mas = True 29 | mas_noise_scale_initial = 0.01 30 | noise_scale_delta = 2e-6 31 | 32 | net_g = SynthesizerTrn( 33 | len(symbols)+1, 34 | hps.data.n_mel_channels, 35 | hps.train.segment_size // hps.data.hop_length, 36 | n_speakers=hps.data.n_speakers, 37 | mas_noise_scale_initial=mas_noise_scale_initial, 38 | noise_scale_delta=noise_scale_delta, 39 | **hps.model, 40 | ) 41 | 42 | pl_bert_model, pl_bert_config, device = get_pl_bert_ja(dir=pl_bert_dir) 43 | pl_bert_cleaner = TextCleaner() 44 | pl_bert_tokenizer = BertJapaneseTokenizer.from_pretrained(pl_bert_config['dataset_params']['tokenizer']) 45 | 46 | net_g, _, _, _ = utils.load_checkpoint( model_ckpt_path, net_g, optimizer=None) 47 | 48 | # play audio by system default 49 | speaker = sc.get_speaker(sc.default_speaker().name) 50 | 51 | # parameter settings 52 | noise_scale = torch.tensor(0.66) # adjust z_p noise 53 | noise_scale_w = torch.tensor(0.8) # adjust SDP noise 54 | length_scale = torch.tensor(1.0) # adjust sound length scale (talk speed) 55 | 56 | if is_save is True: 57 | n_save = 0 58 | save_dir = os.path.join("./infer_logs/") 59 | os.makedirs(save_dir, exist_ok=True) 60 | 61 | net_g = net_g.to(device) 62 | pl_bert_model = pl_bert_model.to(device) 63 
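# NOTE: the dummy synthesis below acts as a warm-up pass on a fixed sentence;
# it pays the one-off CUDA kernel loading / cuDNN autotuning cost up front, so
# the generation time printed for the first real user input in the loop
# further down is not dominated by initialisation overhead.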
| 64 | ### Dummy Input ### 65 | with torch.inference_mode(): 66 | dummy_text = "色々疲れちまったけど、やっぱ音声合成してるときが一番ワクワクするんだよな。" 67 | 68 | # get bert features 69 | bert_features, phonemes = get_bert_features(dummy_text, pl_bert_model, pl_bert_tokenizer, pl_bert_config, pl_bert_cleaner, device, add_blank=hps.data.add_blank) 70 | x = get_text_ids(phonemes=phonemes, 71 | add_blank=hps.data.add_blank) 72 | x = x.unsqueeze(0) 73 | bert_features = bert_features.unsqueeze(0) 74 | x_lengths = torch.LongTensor([x.size(1)]) 75 | sid = torch.LongTensor([0]) 76 | net_g.infer(x .to(device), 77 | x_lengths .to(device), 78 | bert_features .to(device), 79 | x_lengths .to(device), 80 | sid .to(device), 81 | noise_scale=noise_scale.to(device), 82 | noise_scale_w=noise_scale_w.to(device), 83 | length_scale=length_scale.to(device), 84 | max_len=1000) 85 | 86 | while True: 87 | # get text 88 | text = input("Enter text. ==> ") 89 | if text=="": 90 | print("Empty input is detected... Exit...") 91 | break 92 | 93 | # measure the execution time 94 | torch.cuda.synchronize() 95 | start = time.time() 96 | 97 | # required_grad is False 98 | with torch.inference_mode(): 99 | bert_features, phonemes = get_bert_features(text, pl_bert_model, pl_bert_tokenizer, pl_bert_config, pl_bert_cleaner, device, add_blank=hps.data.add_blank) 100 | x = get_text_ids(phonemes=phonemes, 101 | add_blank=hps.data.add_blank).unsqueeze(0) 102 | bert_features = bert_features.unsqueeze(0) 103 | x_lengths = torch.LongTensor([x.size(1)]) 104 | sid = torch.LongTensor([0]) 105 | y_hat, _, _, _ = net_g.infer(x .to(device), 106 | x_lengths .to(device), 107 | bert_features .to(device), 108 | x_lengths .to(device), 109 | sid .to(device), 110 | noise_scale=noise_scale.to(device), 111 | noise_scale_w=noise_scale_w.to(device), 112 | length_scale=length_scale.to(device), 113 | max_len=1000) 114 | y_hat = y_hat.permute(0,2,1)[0, :, :].cpu().float().numpy().copy() 115 | 116 | # measure the execution time 117 | torch.cuda.synchronize() 118 | elapsed_time = time.time() - start 119 | print(f"Gen Time : {elapsed_time}") 120 | 121 | # play audio 122 | speaker.play(y_hat, hps.data.sampling_rate) 123 | 124 | # save audio 125 | if is_save is True: 126 | n_save += 1 127 | data = y_hat 128 | try: 129 | save_path = os.path.join(save_dir, str(n_save).zfill(3)+f"_{text}.wav") 130 | sf.write( 131 | file=save_path, 132 | data=data, 133 | samplerate=hps.data.sampling_rate, 134 | format="WAV") 135 | except: 136 | save_path = os.path.join(save_dir, str(n_save).zfill(3)+f"_{text[:10]}〜.wav") 137 | sf.write( 138 | file=save_path, 139 | data=data, 140 | samplerate=hps.data.sampling_rate, 141 | format="WAV") 142 | 143 | print(f"Audio is saved at : {save_path}") 144 | 145 | return 0 146 | 147 | def get_text_ids(phonemes, add_blank): 148 | 149 | text_norm = cleaned_text_to_sequence(phonemes) 150 | 151 | if add_blank: 152 | text_norm = commons.intersperse(text_norm, 0) 153 | 154 | text_norm = torch.LongTensor(text_norm) 155 | return text_norm 156 | 157 | 158 | def get_bert_features(text, pl_bert_model, pl_bert_tokenizer, pl_bert_config, pl_bert_cleaner, device, add_blank): 159 | text = text.replace("\n", "") 160 | hidden_size = pl_bert_config["model_params"]["hidden_size"] 161 | n_layers = pl_bert_config["model_params"]["num_hidden_layers"] + 1 162 | phonemes = ''.join(phonemize(text,pl_bert_tokenizer)["phonemes"]) 163 | input_ids = pl_bert_cleaner(phonemes) 164 | with torch.inference_mode(): 165 | hidden_stats = pl_bert_model(torch.tensor(input_ids, dtype=torch.int64, 
device=device).unsqueeze(0))[-1]["hidden_states"] 166 | save_tensor = torch.zeros(size=(n_layers, len(input_ids), hidden_size)) 167 | for idx, hidden_stat in enumerate(hidden_stats): 168 | save_tensor[idx, :, :] = hidden_stat 169 | 170 | 171 | if add_blank is True: 172 | L, T, H = save_tensor.shape 173 | new_data = torch.zeros(size=(L,2*T+1,H), dtype=save_tensor.dtype) 174 | for idx in range(T): 175 | target_idx = idx*2+1 176 | new_data[:, target_idx, :] = save_tensor[:, idx, :] 177 | save_tensor = new_data 178 | 179 | return save_tensor, phonemes 180 | 181 | def text2input_ids(): 182 | return 0 183 | 184 | if __name__ == "__main__": 185 | parser = argparse.ArgumentParser() 186 | parser.add_argument("--model_ckpt_path", default="./logs/AddBlankTrue/G_54000.pth") 187 | parser.add_argument("--model_cnfg_path", default="./logs/AddBlankTrue/config.json") 188 | parser.add_argument("--pl_bert_dir", default="./plb-ja_10000000-steps/") 189 | parser.add_argument("--is_save", default=False) 190 | 191 | args = parser.parse_args() 192 | 193 | inference(args.model_ckpt_path, args.model_cnfg_path, args.pl_bert_dir, args.is_save) -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import commons 5 | 6 | 7 | def feature_loss(fmap_r, fmap_g): 8 | loss = 0 9 | for dr, dg in zip(fmap_r, fmap_g): 10 | for rl, gl in zip(dr, dg): 11 | rl = rl.float().detach() 12 | gl = gl.float() 13 | loss += torch.mean(torch.abs(rl - gl)) 14 | 15 | return loss * 2 16 | 17 | 18 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 19 | loss = 0 20 | r_losses = [] 21 | g_losses = [] 22 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 23 | dr = dr.float() 24 | dg = dg.float() 25 | r_loss = torch.mean((1 - dr) ** 2) 26 | g_loss = torch.mean(dg**2) 27 | loss += r_loss + g_loss 28 | r_losses.append(r_loss.item()) 29 | g_losses.append(g_loss.item()) 30 | 31 | return loss, r_losses, g_losses 32 | 33 | 34 | def generator_loss(disc_outputs): 35 | loss = 0 36 | gen_losses = [] 37 | for dg in disc_outputs: 38 | dg = dg.float() 39 | l = torch.mean((1 - dg) ** 2) 40 | gen_losses.append(l) 41 | loss += l 42 | 43 | return loss, gen_losses 44 | 45 | 46 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): 47 | """ 48 | z_p, logs_q: [b, h, t_t] 49 | m_p, logs_p: [b, h, t_t] 50 | """ 51 | z_p = z_p.float() 52 | logs_q = logs_q.float() 53 | m_p = m_p.float() 54 | logs_p = logs_p.float() 55 | z_mask = z_mask.float() 56 | 57 | kl = logs_p - logs_q - 0.5 58 | kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) 59 | kl = torch.sum(kl * z_mask) 60 | l = kl / torch.sum(z_mask) 61 | return l 62 | -------------------------------------------------------------------------------- /mel_processing.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | # warnings.simplefilter(action='ignore', category=FutureWarning) 4 | warnings.filterwarnings(action="ignore") 5 | 6 | import math 7 | import os 8 | import random 9 | 10 | import librosa 11 | import librosa.util as librosa_util 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.utils.data 16 | from librosa.filters import mel as librosa_mel_fn 17 | from librosa.util import normalize, pad_center, tiny 18 | from packaging import version 19 | from scipy.io.wavfile import read 20 | from scipy.signal 
import get_window 21 | from torch import nn 22 | 23 | MAX_WAV_VALUE = 32768.0 24 | 25 | 26 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 27 | """ 28 | PARAMS 29 | ------ 30 | C: compression factor 31 | """ 32 | return torch.log(torch.clamp(x, min=clip_val) * C) 33 | 34 | 35 | def dynamic_range_decompression_torch(x, C=1): 36 | """ 37 | PARAMS 38 | ------ 39 | C: compression factor used to compress 40 | """ 41 | return torch.exp(x) / C 42 | 43 | 44 | def spectral_normalize_torch(magnitudes): 45 | output = dynamic_range_compression_torch(magnitudes) 46 | return output 47 | 48 | 49 | def spectral_de_normalize_torch(magnitudes): 50 | output = dynamic_range_decompression_torch(magnitudes) 51 | return output 52 | 53 | 54 | mel_basis = {} 55 | hann_window = {} 56 | 57 | 58 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 59 | if torch.min(y) < -1.0: 60 | print("min value is ", torch.min(y)) 61 | if torch.max(y) > 1.0: 62 | print("max value is ", torch.max(y)) 63 | 64 | global hann_window 65 | dtype_device = str(y.dtype) + "_" + str(y.device) 66 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 67 | if wnsize_dtype_device not in hann_window: 68 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 69 | dtype=y.dtype, device=y.device 70 | ) 71 | 72 | y = torch.nn.functional.pad( 73 | y.unsqueeze(1), 74 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 75 | mode="reflect", 76 | ) 77 | y = y.squeeze(1) 78 | 79 | if version.parse(torch.__version__) >= version.parse("2"): 80 | spec = torch.stft( 81 | y, 82 | n_fft, 83 | hop_length=hop_size, 84 | win_length=win_size, 85 | window=hann_window[wnsize_dtype_device], 86 | center=center, 87 | pad_mode="reflect", 88 | normalized=False, 89 | onesided=True, 90 | return_complex=False, 91 | ) 92 | else: 93 | spec = torch.stft( 94 | y, 95 | n_fft, 96 | hop_length=hop_size, 97 | win_length=win_size, 98 | window=hann_window[wnsize_dtype_device], 99 | center=center, 100 | pad_mode="reflect", 101 | normalized=False, 102 | onesided=True, 103 | ) 104 | 105 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 106 | return spec 107 | 108 | 109 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 110 | global mel_basis 111 | dtype_device = str(spec.dtype) + "_" + str(spec.device) 112 | fmax_dtype_device = str(fmax) + "_" + dtype_device 113 | if fmax_dtype_device not in mel_basis: 114 | mel = librosa_mel_fn( 115 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax 116 | ) 117 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 118 | dtype=spec.dtype, device=spec.device 119 | ) 120 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 121 | spec = spectral_normalize_torch(spec) 122 | return spec 123 | 124 | 125 | def mel_spectrogram_torch( 126 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False 127 | ): 128 | if torch.min(y) < -1.0: 129 | print("min value is ", torch.min(y)) 130 | if torch.max(y) > 1.0: 131 | print("max value is ", torch.max(y)) 132 | 133 | global mel_basis, hann_window 134 | dtype_device = str(y.dtype) + "_" + str(y.device) 135 | fmax_dtype_device = str(fmax) + "_" + dtype_device 136 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 137 | if fmax_dtype_device not in mel_basis: 138 | mel = librosa_mel_fn( 139 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax 140 | ) 141 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 142 | dtype=y.dtype, device=y.device 143 | 
) 144 | if wnsize_dtype_device not in hann_window: 145 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 146 | dtype=y.dtype, device=y.device 147 | ) 148 | 149 | y = torch.nn.functional.pad( 150 | y.unsqueeze(1), 151 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 152 | mode="reflect", 153 | ) 154 | y = y.squeeze(1) 155 | 156 | if version.parse(torch.__version__) >= version.parse("2"): 157 | spec = torch.stft( 158 | y, 159 | n_fft, 160 | hop_length=hop_size, 161 | win_length=win_size, 162 | window=hann_window[wnsize_dtype_device], 163 | center=center, 164 | pad_mode="reflect", 165 | normalized=False, 166 | onesided=True, 167 | return_complex=False, 168 | ) 169 | else: 170 | spec = torch.stft( 171 | y, 172 | n_fft, 173 | hop_length=hop_size, 174 | win_length=win_size, 175 | window=hann_window[wnsize_dtype_device], 176 | center=center, 177 | pad_mode="reflect", 178 | normalized=False, 179 | onesided=True, 180 | ) 181 | 182 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 183 | 184 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 185 | spec = spectral_normalize_torch(spec) 186 | 187 | return spec 188 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import scipy 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 10 | from torch.nn.utils import weight_norm, remove_weight_norm 11 | 12 | import commons 13 | from commons import init_weights, get_padding 14 | from transforms import piecewise_rational_quadratic_transform 15 | 16 | 17 | LRELU_SLOPE = 0.1 18 | 19 | 20 | class LayerNorm(nn.Module): 21 | def __init__(self, channels, eps=1e-5): 22 | super().__init__() 23 | self.channels = channels 24 | self.eps = eps 25 | 26 | self.gamma = nn.Parameter(torch.ones(channels)) 27 | self.beta = nn.Parameter(torch.zeros(channels)) 28 | 29 | def forward(self, x): 30 | x = x.transpose(1, -1) 31 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 32 | return x.transpose(1, -1) 33 | 34 | 35 | class ConvReluNorm(nn.Module): 36 | def __init__( 37 | self, 38 | in_channels, 39 | hidden_channels, 40 | out_channels, 41 | kernel_size, 42 | n_layers, 43 | p_dropout, 44 | ): 45 | super().__init__() 46 | self.in_channels = in_channels 47 | self.hidden_channels = hidden_channels 48 | self.out_channels = out_channels 49 | self.kernel_size = kernel_size 50 | self.n_layers = n_layers 51 | self.p_dropout = p_dropout 52 | assert n_layers > 1, "Number of layers should be larger than 0." 
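# NOTE: despite the message text, the assert above requires n_layers > 1,
# i.e. at least two conv blocks. The pre-net built below stacks
# [Conv1d -> LayerNorm -> ReLU -> Dropout] n_layers times, and the final 1x1
# projection is zero-initialised, so at the start of training forward()
# reduces to x_org (times the mask) through the residual connection.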
53 | 54 | self.conv_layers = nn.ModuleList() 55 | self.norm_layers = nn.ModuleList() 56 | self.conv_layers.append( 57 | nn.Conv1d( 58 | in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 59 | ) 60 | ) 61 | self.norm_layers.append(LayerNorm(hidden_channels)) 62 | self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) 63 | for _ in range(n_layers - 1): 64 | self.conv_layers.append( 65 | nn.Conv1d( 66 | hidden_channels, 67 | hidden_channels, 68 | kernel_size, 69 | padding=kernel_size // 2, 70 | ) 71 | ) 72 | self.norm_layers.append(LayerNorm(hidden_channels)) 73 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 74 | self.proj.weight.data.zero_() 75 | self.proj.bias.data.zero_() 76 | 77 | def forward(self, x, x_mask): 78 | x_org = x 79 | for i in range(self.n_layers): 80 | x = self.conv_layers[i](x * x_mask) 81 | x = self.norm_layers[i](x) 82 | x = self.relu_drop(x) 83 | x = x_org + self.proj(x) 84 | return x * x_mask 85 | 86 | 87 | class DDSConv(nn.Module): 88 | """ 89 | Dialted and Depth-Separable Convolution 90 | """ 91 | 92 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): 93 | super().__init__() 94 | self.channels = channels 95 | self.kernel_size = kernel_size 96 | self.n_layers = n_layers 97 | self.p_dropout = p_dropout 98 | 99 | self.drop = nn.Dropout(p_dropout) 100 | self.convs_sep = nn.ModuleList() 101 | self.convs_1x1 = nn.ModuleList() 102 | self.norms_1 = nn.ModuleList() 103 | self.norms_2 = nn.ModuleList() 104 | for i in range(n_layers): 105 | dilation = kernel_size**i 106 | padding = (kernel_size * dilation - dilation) // 2 107 | self.convs_sep.append( 108 | nn.Conv1d( 109 | channels, 110 | channels, 111 | kernel_size, 112 | groups=channels, 113 | dilation=dilation, 114 | padding=padding, 115 | ) 116 | ) 117 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 118 | self.norms_1.append(LayerNorm(channels)) 119 | self.norms_2.append(LayerNorm(channels)) 120 | 121 | def forward(self, x, x_mask, g=None): 122 | if g is not None: 123 | x = x + g 124 | for i in range(self.n_layers): 125 | y = self.convs_sep[i](x * x_mask) 126 | y = self.norms_1[i](y) 127 | y = F.gelu(y) 128 | y = self.convs_1x1[i](y) 129 | y = self.norms_2[i](y) 130 | y = F.gelu(y) 131 | y = self.drop(y) 132 | x = x + y 133 | return x * x_mask 134 | 135 | 136 | class WN(torch.nn.Module): 137 | def __init__( 138 | self, 139 | hidden_channels, 140 | kernel_size, 141 | dilation_rate, 142 | n_layers, 143 | gin_channels=0, 144 | p_dropout=0, 145 | ): 146 | super(WN, self).__init__() 147 | assert kernel_size % 2 == 1 148 | self.hidden_channels = hidden_channels 149 | self.kernel_size = (kernel_size,) 150 | self.dilation_rate = dilation_rate 151 | self.n_layers = n_layers 152 | self.gin_channels = gin_channels 153 | self.p_dropout = p_dropout 154 | 155 | self.in_layers = torch.nn.ModuleList() 156 | self.res_skip_layers = torch.nn.ModuleList() 157 | self.drop = nn.Dropout(p_dropout) 158 | 159 | if gin_channels != 0: 160 | cond_layer = torch.nn.Conv1d( 161 | gin_channels, 2 * hidden_channels * n_layers, 1 162 | ) 163 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") 164 | 165 | for i in range(n_layers): 166 | dilation = dilation_rate**i 167 | padding = int((kernel_size * dilation - dilation) / 2) 168 | in_layer = torch.nn.Conv1d( 169 | hidden_channels, 170 | 2 * hidden_channels, 171 | kernel_size, 172 | dilation=dilation, 173 | padding=padding, 174 | ) 175 | in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") 176 | 
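# Each weight-normalized in_layer produces 2 * hidden_channels so forward() can split it into
# tanh and sigmoid halves inside commons.fused_add_tanh_sigmoid_multiply; the matching
# res_skip_layers are built just below, and remove_weight_norm() strips the reparameterization
# again for inference.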
self.in_layers.append(in_layer) 177 | 178 | # last one is not necessary 179 | if i < n_layers - 1: 180 | res_skip_channels = 2 * hidden_channels 181 | else: 182 | res_skip_channels = hidden_channels 183 | 184 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 185 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") 186 | self.res_skip_layers.append(res_skip_layer) 187 | 188 | def forward(self, x, x_mask, g=None, **kwargs): 189 | output = torch.zeros_like(x) 190 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 191 | 192 | if g is not None: 193 | g = self.cond_layer(g) 194 | 195 | for i in range(self.n_layers): 196 | x_in = self.in_layers[i](x) 197 | if g is not None: 198 | cond_offset = i * 2 * self.hidden_channels 199 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] 200 | else: 201 | g_l = torch.zeros_like(x_in) 202 | 203 | acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) 204 | acts = self.drop(acts) 205 | 206 | res_skip_acts = self.res_skip_layers[i](acts) 207 | if i < self.n_layers - 1: 208 | res_acts = res_skip_acts[:, : self.hidden_channels, :] 209 | x = (x + res_acts) * x_mask 210 | output = output + res_skip_acts[:, self.hidden_channels :, :] 211 | else: 212 | output = output + res_skip_acts 213 | return output * x_mask 214 | 215 | def remove_weight_norm(self): 216 | if self.gin_channels != 0: 217 | torch.nn.utils.remove_weight_norm(self.cond_layer) 218 | for l in self.in_layers: 219 | torch.nn.utils.remove_weight_norm(l) 220 | for l in self.res_skip_layers: 221 | torch.nn.utils.remove_weight_norm(l) 222 | 223 | 224 | class ResBlock1(torch.nn.Module): 225 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 226 | super(ResBlock1, self).__init__() 227 | self.convs1 = nn.ModuleList( 228 | [ 229 | weight_norm( 230 | Conv1d( 231 | channels, 232 | channels, 233 | kernel_size, 234 | 1, 235 | dilation=dilation[0], 236 | padding=get_padding(kernel_size, dilation[0]), 237 | ) 238 | ), 239 | weight_norm( 240 | Conv1d( 241 | channels, 242 | channels, 243 | kernel_size, 244 | 1, 245 | dilation=dilation[1], 246 | padding=get_padding(kernel_size, dilation[1]), 247 | ) 248 | ), 249 | weight_norm( 250 | Conv1d( 251 | channels, 252 | channels, 253 | kernel_size, 254 | 1, 255 | dilation=dilation[2], 256 | padding=get_padding(kernel_size, dilation[2]), 257 | ) 258 | ), 259 | ] 260 | ) 261 | self.convs1.apply(init_weights) 262 | 263 | self.convs2 = nn.ModuleList( 264 | [ 265 | weight_norm( 266 | Conv1d( 267 | channels, 268 | channels, 269 | kernel_size, 270 | 1, 271 | dilation=1, 272 | padding=get_padding(kernel_size, 1), 273 | ) 274 | ), 275 | weight_norm( 276 | Conv1d( 277 | channels, 278 | channels, 279 | kernel_size, 280 | 1, 281 | dilation=1, 282 | padding=get_padding(kernel_size, 1), 283 | ) 284 | ), 285 | weight_norm( 286 | Conv1d( 287 | channels, 288 | channels, 289 | kernel_size, 290 | 1, 291 | dilation=1, 292 | padding=get_padding(kernel_size, 1), 293 | ) 294 | ), 295 | ] 296 | ) 297 | self.convs2.apply(init_weights) 298 | 299 | def forward(self, x, x_mask=None): 300 | for c1, c2 in zip(self.convs1, self.convs2): 301 | xt = F.leaky_relu(x, LRELU_SLOPE) 302 | if x_mask is not None: 303 | xt = xt * x_mask 304 | xt = c1(xt) 305 | xt = F.leaky_relu(xt, LRELU_SLOPE) 306 | if x_mask is not None: 307 | xt = xt * x_mask 308 | xt = c2(xt) 309 | x = xt + x 310 | if x_mask is not None: 311 | x = x * x_mask 312 | return x 313 | 314 | def remove_weight_norm(self): 315 | for l in 
self.convs1: 316 | remove_weight_norm(l) 317 | for l in self.convs2: 318 | remove_weight_norm(l) 319 | 320 | 321 | class ResBlock2(torch.nn.Module): 322 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 323 | super(ResBlock2, self).__init__() 324 | self.convs = nn.ModuleList( 325 | [ 326 | weight_norm( 327 | Conv1d( 328 | channels, 329 | channels, 330 | kernel_size, 331 | 1, 332 | dilation=dilation[0], 333 | padding=get_padding(kernel_size, dilation[0]), 334 | ) 335 | ), 336 | weight_norm( 337 | Conv1d( 338 | channels, 339 | channels, 340 | kernel_size, 341 | 1, 342 | dilation=dilation[1], 343 | padding=get_padding(kernel_size, dilation[1]), 344 | ) 345 | ), 346 | ] 347 | ) 348 | self.convs.apply(init_weights) 349 | 350 | def forward(self, x, x_mask=None): 351 | for c in self.convs: 352 | xt = F.leaky_relu(x, LRELU_SLOPE) 353 | if x_mask is not None: 354 | xt = xt * x_mask 355 | xt = c(xt) 356 | x = xt + x 357 | if x_mask is not None: 358 | x = x * x_mask 359 | return x 360 | 361 | def remove_weight_norm(self): 362 | for l in self.convs: 363 | remove_weight_norm(l) 364 | 365 | 366 | class Log(nn.Module): 367 | def forward(self, x, x_mask, reverse=False, **kwargs): 368 | if not reverse: 369 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 370 | logdet = torch.sum(-y, [1, 2]) 371 | return y, logdet 372 | else: 373 | x = torch.exp(x) * x_mask 374 | return x 375 | 376 | 377 | class Flip(nn.Module): 378 | def forward(self, x, *args, reverse=False, **kwargs): 379 | x = torch.flip(x, [1]) 380 | if not reverse: 381 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 382 | return x, logdet 383 | else: 384 | return x 385 | 386 | 387 | class ElementwiseAffine(nn.Module): 388 | def __init__(self, channels): 389 | super().__init__() 390 | self.channels = channels 391 | self.m = nn.Parameter(torch.zeros(channels, 1)) 392 | self.logs = nn.Parameter(torch.zeros(channels, 1)) 393 | 394 | def forward(self, x, x_mask, reverse=False, **kwargs): 395 | if not reverse: 396 | y = self.m + torch.exp(self.logs) * x 397 | y = y * x_mask 398 | logdet = torch.sum(self.logs * x_mask, [1, 2]) 399 | return y, logdet 400 | else: 401 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 402 | return x 403 | 404 | 405 | class ResidualCouplingLayer(nn.Module): 406 | def __init__( 407 | self, 408 | channels, 409 | hidden_channels, 410 | kernel_size, 411 | dilation_rate, 412 | n_layers, 413 | p_dropout=0, 414 | gin_channels=0, 415 | mean_only=False, 416 | ): 417 | assert channels % 2 == 0, "channels should be divisible by 2" 418 | super().__init__() 419 | self.channels = channels 420 | self.hidden_channels = hidden_channels 421 | self.kernel_size = kernel_size 422 | self.dilation_rate = dilation_rate 423 | self.n_layers = n_layers 424 | self.half_channels = channels // 2 425 | self.mean_only = mean_only 426 | 427 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 428 | self.enc = WN( 429 | hidden_channels, 430 | kernel_size, 431 | dilation_rate, 432 | n_layers, 433 | p_dropout=p_dropout, 434 | gin_channels=gin_channels, 435 | ) 436 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 437 | self.post.weight.data.zero_() 438 | self.post.bias.data.zero_() 439 | 440 | def forward(self, x, x_mask, g=None, reverse=False): 441 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 442 | h = self.pre(x0) * x_mask 443 | h = self.enc(h, x_mask, g=g) 444 | stats = self.post(h) * x_mask 445 | if not self.mean_only: 446 | m, logs = torch.split(stats, 
[self.half_channels] * 2, 1) 447 | else: 448 | m = stats 449 | logs = torch.zeros_like(m) 450 | 451 | if not reverse: 452 | x1 = m + x1 * torch.exp(logs) * x_mask 453 | x = torch.cat([x0, x1], 1) 454 | logdet = torch.sum(logs, [1, 2]) 455 | return x, logdet 456 | else: 457 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 458 | x = torch.cat([x0, x1], 1) 459 | return x 460 | 461 | 462 | class ConvFlow(nn.Module): 463 | def __init__( 464 | self, 465 | in_channels, 466 | filter_channels, 467 | kernel_size, 468 | n_layers, 469 | num_bins=10, 470 | tail_bound=5.0, 471 | ): 472 | super().__init__() 473 | self.in_channels = in_channels 474 | self.filter_channels = filter_channels 475 | self.kernel_size = kernel_size 476 | self.n_layers = n_layers 477 | self.num_bins = num_bins 478 | self.tail_bound = tail_bound 479 | self.half_channels = in_channels // 2 480 | 481 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 482 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) 483 | self.proj = nn.Conv1d( 484 | filter_channels, self.half_channels * (num_bins * 3 - 1), 1 485 | ) 486 | self.proj.weight.data.zero_() 487 | self.proj.bias.data.zero_() 488 | 489 | def forward(self, x, x_mask, g=None, reverse=False): 490 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 491 | h = self.pre(x0) 492 | h = self.convs(h, x_mask, g=g) 493 | h = self.proj(h) * x_mask 494 | 495 | b, c, t = x0.shape 496 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 497 | 498 | unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) 499 | unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt( 500 | self.filter_channels 501 | ) 502 | unnormalized_derivatives = h[..., 2 * self.num_bins :] 503 | 504 | x1, logabsdet = piecewise_rational_quadratic_transform( 505 | x1, 506 | unnormalized_widths, 507 | unnormalized_heights, 508 | unnormalized_derivatives, 509 | inverse=reverse, 510 | tails="linear", 511 | tail_bound=self.tail_bound, 512 | ) 513 | 514 | x = torch.cat([x0, x1], 1) * x_mask 515 | logdet = torch.sum(logabsdet * x_mask, [1, 2]) 516 | if not reverse: 517 | return x, logdet 518 | else: 519 | return x 520 | -------------------------------------------------------------------------------- /monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .monotonic_align.core import maximum_path_c 4 | 5 | 6 | def maximum_path(neg_cent, mask): 7 | """Cython optimized version. 
8 | neg_cent: [b, t_t, t_s] 9 | mask: [b, t_t, t_s] 10 | """ 11 | device = neg_cent.device 12 | dtype = neg_cent.dtype 13 | neg_cent = neg_cent.data.cpu().numpy().astype(np.float32) 14 | path = np.zeros(neg_cent.shape, dtype=np.int32) 15 | 16 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32) 17 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32) 18 | maximum_path_c(path, neg_cent, t_t_max, t_s_max) 19 | return torch.from_numpy(path).to(device=device, dtype=dtype) 20 | -------------------------------------------------------------------------------- /monotonic_align/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnetonne814/PL-Bert-VITS2/a14feb38af363bf6e3c0797a0c9cb7e79557a1e7/monotonic_align/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /monotonic_align/build/lib.linux-x86_64-cpython-38/monotonic_align/core.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnetonne814/PL-Bert-VITS2/a14feb38af363bf6e3c0797a0c9cb7e79557a1e7/monotonic_align/build/lib.linux-x86_64-cpython-38/monotonic_align/core.cpython-38-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /monotonic_align/build/temp.linux-x86_64-cpython-38/core.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnetonne814/PL-Bert-VITS2/a14feb38af363bf6e3c0797a0c9cb7e79557a1e7/monotonic_align/build/temp.linux-x86_64-cpython-38/core.o -------------------------------------------------------------------------------- /monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | cimport cython 2 | from cython.parallel import prange 3 | 4 | 5 | @cython.boundscheck(False) 6 | @cython.wraparound(False) 7 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: 8 | cdef int x 9 | cdef int y 10 | cdef float v_prev 11 | cdef float v_cur 12 | cdef float tmp 13 | cdef int index = t_x - 1 14 | 15 | for y in range(t_y): 16 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 17 | if x == y: 18 | v_cur = max_neg_val 19 | else: 20 | v_cur = value[y-1, x] 21 | if x == 0: 22 | if y == 0: 23 | v_prev = 0. 
24 | else: 25 | v_prev = max_neg_val 26 | else: 27 | v_prev = value[y-1, x-1] 28 | value[y, x] += max(v_prev, v_cur) 29 | 30 | for y in range(t_y - 1, -1, -1): 31 | path[y, index] = 1 32 | if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]): 33 | index = index - 1 34 | 35 | 36 | @cython.boundscheck(False) 37 | @cython.wraparound(False) 38 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil: 39 | cdef int b = paths.shape[0] 40 | cdef int i 41 | for i in prange(b, nogil=True): 42 | maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) 43 | -------------------------------------------------------------------------------- /monotonic_align/monotonic_align/core.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnetonne814/PL-Bert-VITS2/a14feb38af363bf6e3c0797a0c9cb7e79557a1e7/monotonic_align/monotonic_align/core.cpython-38-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /monotonic_align/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | import numpy 4 | 5 | setup( 6 | name="monotonic_align", 7 | ext_modules=cythonize("core.pyx"), 8 | include_dirs=[numpy.get_include()], 9 | ) 10 | -------------------------------------------------------------------------------- /preprocess_ja.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import polars 4 | import random 5 | from PL_BERT_ja.text_utils import TextCleaner 6 | from PL_BERT_ja.phonemize import phonemize 7 | import torch 8 | from tqdm import tqdm 9 | 10 | from PL_BERT_ja.model import MultiTaskModel 11 | from transformers import AlbertConfig, AlbertModel 12 | from transformers import BertJapaneseTokenizer 13 | import yaml, torch 14 | 15 | def preprocess(dataset_dir, pl_bert_dir): 16 | 17 | n_val_test_file = 10 18 | filelist_dir = "./filelists/" 19 | dataset_name = "jvnv_ver1" 20 | os.makedirs(filelist_dir, exist_ok=True) 21 | split_symbol = "||||" 22 | 23 | transcript_csv_df = polars.read_csv(os.path.join(dataset_dir, "jvnv_v1", "transcription.csv"),has_header=False)[:, 0] 24 | emo_list = os.listdir(os.path.join(dataset_dir,"jvnv_v1", "F1")) 25 | style_list = os.listdir(os.path.join(dataset_dir,"jvnv_v1", "F1", "anger")) 26 | 27 | pl_bert_savedir = "./pl_bert_embeddings" 28 | os.makedirs(pl_bert_savedir, exist_ok=True) 29 | pl_bert_model, pl_bert_config, device = get_pl_bert_ja(dir=pl_bert_dir) 30 | pl_bert_cleaner = TextCleaner() 31 | pl_bert_tokenizer = BertJapaneseTokenizer.from_pretrained(pl_bert_config['dataset_params']['tokenizer']) 32 | 33 | hidden_size = pl_bert_config["model_params"]["hidden_size"] 34 | n_layers = pl_bert_config["model_params"]["num_hidden_layers"] + 1 35 | 36 | filelists = list() 37 | spk_g = ["F", "M"] 38 | for line in tqdm(transcript_csv_df): 39 | index_name, emo_prefix, text = line.split("|") 40 | emotion, style, file_idx = index_name.split("_") 41 | text = text.replace("\n", "") 42 | 43 | phonemes = ''.join(phonemize(text,pl_bert_tokenizer)["phonemes"]) 44 | input_ids = pl_bert_cleaner(phonemes) 45 | with torch.inference_mode(): 46 | hidden_stats = pl_bert_model(torch.tensor(input_ids, dtype=torch.int64, device=device).unsqueeze(0))[-1]["hidden_states"] 47 | save_tensor = torch.zeros(size=(n_layers, 
len(input_ids), hidden_size), device=device) 48 | for idx, hidden_stat in enumerate(hidden_stats): 49 | save_tensor[idx, :, :] = hidden_stat 50 | torch.save(save_tensor.to('cpu').detach(), os.path.join(pl_bert_savedir, f"{index_name}.PlBertJa")) 51 | 52 | for g_idx in range(2): 53 | for spk_idx in range(2): 54 | spk_ID = str(g_idx + spk_idx*2) 55 | spk = spk_g[g_idx] + str(spk_idx+1) 56 | wav_path = os.path.join(dataset_dir, "jvnv_v1", spk, emotion, style, f"{spk}_{emotion}_{style}_{file_idx}.wav") 57 | filelists.append(f"{wav_path}{split_symbol}{spk_ID}{split_symbol}{phonemes}{split_symbol}{text}{split_symbol}{index_name}{split_symbol}emo:{str(emo_list.index(emotion))}{split_symbol}style:{str(style_list.index(style))}\n") 58 | 59 | val_list = list() 60 | test_list = list() 61 | for idx in range(n_val_test_file*2): 62 | target_idx = random.randint(0, len(filelists)) 63 | target_line = filelists.pop(target_idx) 64 | if idx % 2 == 1: 65 | val_list.append(target_line) 66 | else: 67 | test_list.append(target_line) 68 | 69 | write_txt(filelists, os.path.join(filelist_dir, f"{dataset_name}_train.txt")) 70 | write_txt(val_list, os.path.join(filelist_dir, f"{dataset_name}_val.txt")) 71 | write_txt(test_list, os.path.join(filelist_dir, f"{dataset_name}_test.txt")) 72 | 73 | return 0 74 | 75 | def write_txt(lists, path): 76 | with open(path, mode="w", encoding="utf-8") as f: 77 | f.writelines(lists) 78 | 79 | def get_pl_bert_ja(dir): 80 | 81 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 82 | config_path=os.path.join(dir, "config.yml") 83 | config = yaml.safe_load(open(config_path)) 84 | 85 | albert_base_configuration = AlbertConfig(**config['model_params']) 86 | bert_ = AlbertModel(albert_base_configuration).to(device) 87 | #num_vocab = max([m['token'] for m in token_maps.values()]) + 1 # 30923 + 1 88 | bert = MultiTaskModel( 89 | bert_, 90 | num_vocab=30923 + 1, 91 | num_tokens=config['model_params']['vocab_size'], 92 | hidden_size=config['model_params']['hidden_size'] 93 | ) 94 | 95 | model_ckpt_path = os.path.join(dir,"10000000.pth.tar") 96 | checkpoint = torch.load(model_ckpt_path) 97 | bert.load_state_dict(checkpoint['model'], strict=False) 98 | 99 | bert.to(device) 100 | return bert, config, device 101 | 102 | 103 | if __name__ == "__main__": 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument("--jvnv_dir", default="./jvnv_ver1/") 106 | parser.add_argument("--pl_bert_dir", default="./plb-ja_10000000-steps/") 107 | 108 | args = parser.parse_args() 109 | 110 | preprocess(args.jvnv_dir, args.pl_bert_dir) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==3.0.2 2 | librosa==0.10.1 3 | matplotlib==3.7.2 4 | numpy==1.24.4 5 | phonemizer==3.2.1 6 | Unidecode==1.3.6 7 | tensorboard==2.14.0 8 | onnx==1.14.1 9 | onnxruntime==1.15.1 10 | gradio 11 | pyopenjtalk-prebuild 12 | scipy 13 | transformers 14 | cmudict 15 | pinyin 16 | fugashi 17 | ipadic 18 | soundcard -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, 
copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from text import cleaners 3 | from PL_BERT_ja.text_utils import symbols 4 | 5 | 6 | # Mappings from symbol to numeric ID and vice versa: 7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 9 | 10 | 11 | def text_to_sequence(text, cleaner_names): 12 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 13 | Args: 14 | text: string to convert to a sequence 15 | cleaner_names: names of the cleaner functions to run the text through 16 | Returns: 17 | List of integers corresponding to the symbols in the text 18 | """ 19 | sequence = [] 20 | 21 | clean_text = _clean_text(text, cleaner_names) 22 | for symbol in clean_text: 23 | if symbol in _symbol_to_id.keys(): 24 | symbol_id = _symbol_to_id[symbol] 25 | sequence += [symbol_id] 26 | else: 27 | continue 28 | return sequence 29 | 30 | 31 | def cleaned_text_to_sequence(cleaned_text): 32 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
33 | Args: 34 | text: string to convert to a sequence 35 | Returns: 36 | List of integers corresponding to the symbols in the text 37 | """ 38 | sequence = [] 39 | 40 | for symbol in cleaned_text: 41 | if symbol in _symbol_to_id.keys(): 42 | symbol_id = _symbol_to_id[symbol] 43 | sequence += [symbol_id] 44 | else: 45 | continue 46 | return sequence 47 | 48 | 49 | def sequence_to_text(sequence): 50 | """Converts a sequence of IDs back to a string""" 51 | result = "" 52 | for symbol_id in sequence: 53 | s = _id_to_symbol[symbol_id] 54 | result += s 55 | return result 56 | 57 | 58 | def _clean_text(text, cleaner_names): 59 | for name in cleaner_names: 60 | cleaner = getattr(cleaners, name) 61 | if not cleaner: 62 | raise Exception("Unknown cleaner: %s" % name) 63 | text = cleaner(text) 64 | return text 65 | -------------------------------------------------------------------------------- /text/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnetonne814/PL-Bert-VITS2/a14feb38af363bf6e3c0797a0c9cb7e79557a1e7/text/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /text/__pycache__/cleaners.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnetonne814/PL-Bert-VITS2/a14feb38af363bf6e3c0797a0c9cb7e79557a1e7/text/__pycache__/cleaners.cpython-38.pyc -------------------------------------------------------------------------------- /text/__pycache__/symbols.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnetonne814/PL-Bert-VITS2/a14feb38af363bf6e3c0797a0c9cb7e79557a1e7/text/__pycache__/symbols.cpython-38.pyc -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | """ 14 | 15 | import re 16 | from unidecode import unidecode 17 | from phonemizer import phonemize 18 | from phonemizer.backend import EspeakBackend 19 | backend = EspeakBackend("en-us", preserve_punctuation=True, with_stress=True) 20 | 21 | 22 | # Regular expression matching whitespace: 23 | _whitespace_re = re.compile(r"\s+") 24 | 25 | # List of (regular expression, replacement) pairs for abbreviations: 26 | _abbreviations = [ 27 | (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) 28 | for x in [ 29 | ("mrs", "misess"), 30 | ("mr", "mister"), 31 | ("dr", "doctor"), 32 | ("st", "saint"), 33 | ("co", "company"), 34 | ("jr", "junior"), 35 | ("maj", "major"), 36 | ("gen", "general"), 37 | ("drs", "doctors"), 38 | ("rev", "reverend"), 39 | ("lt", "lieutenant"), 40 | ("hon", "honorable"), 41 | ("sgt", "sergeant"), 42 | ("capt", "captain"), 43 | ("esq", "esquire"), 44 | ("ltd", "limited"), 45 | ("col", "colonel"), 46 | ("ft", "fort"), 47 | ] 48 | ] 49 | 50 | 51 | def expand_abbreviations(text): 52 | for regex, replacement in _abbreviations: 53 | text = re.sub(regex, replacement, text) 54 | return text 55 | 56 | 57 | def expand_numbers(text): 58 | return normalize_numbers(text) 59 | 60 | 61 | def lowercase(text): 62 | return text.lower() 63 | 64 | 65 | def collapse_whitespace(text): 66 | return re.sub(_whitespace_re, " ", text) 67 | 68 | 69 | def convert_to_ascii(text): 70 | return unidecode(text) 71 | 72 | 73 | def basic_cleaners(text): 74 | """Basic pipeline that lowercases and collapses whitespace without transliteration.""" 75 | text = lowercase(text) 76 | text = collapse_whitespace(text) 77 | return text 78 | 79 | 80 | def transliteration_cleaners(text): 81 | """Pipeline for non-English text that transliterates to ASCII.""" 82 | text = convert_to_ascii(text) 83 | text = lowercase(text) 84 | text = collapse_whitespace(text) 85 | return text 86 | 87 | 88 | def english_cleaners(text): 89 | """Pipeline for English text, including abbreviation expansion.""" 90 | text = convert_to_ascii(text) 91 | text = lowercase(text) 92 | text = expand_abbreviations(text) 93 | phonemes = phonemize(text, language="en-us", backend="espeak", strip=True) 94 | phonemes = collapse_whitespace(phonemes) 95 | return phonemes 96 | 97 | 98 | def english_cleaners2(text): 99 | """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" 100 | text = convert_to_ascii(text) 101 | text = lowercase(text) 102 | text = expand_abbreviations(text) 103 | phonemes = phonemize( 104 | text, 105 | language="en-us", 106 | backend="espeak", 107 | strip=True, 108 | preserve_punctuation=True, 109 | with_stress=True, 110 | ) 111 | phonemes = collapse_whitespace(phonemes) 112 | return phonemes 113 | 114 | 115 | def english_cleaners3(text): 116 | """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" 117 | text = convert_to_ascii(text) 118 | text = lowercase(text) 119 | text = expand_abbreviations(text) 120 | phonemes = backend.phonemize([text], strip=True)[0] 121 | phonemes = collapse_whitespace(phonemes) 122 | return phonemes 123 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Defines the set of symbols used in text input to the model. 
5 | """ 6 | _pad = "_" 7 | _punctuation = ';:,.!?¡¿—…"«»“” ' 8 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 9 | _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 10 | 11 | 12 | # Export all symbols: 13 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) 14 | 15 | # Special symbol ids 16 | SPACE_ID = symbols.index(" ") 17 | -------------------------------------------------------------------------------- /train_ms.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import math 5 | import os 6 | 7 | import logging 8 | 9 | numba_logger = logging.getLogger('numba') 10 | numba_logger.setLevel(logging.WARNING) 11 | 12 | import torch 13 | import torch.distributed as dist 14 | # from tensorboardX import SummaryWriter 15 | import torch.multiprocessing as mp 16 | import tqdm 17 | from torch import nn, optim 18 | from torch.cuda.amp import GradScaler, autocast 19 | from torch.nn import functional as F 20 | from torch.nn.parallel import DistributedDataParallel as DDP 21 | from torch.utils.data import DataLoader 22 | from torch.utils.tensorboard import SummaryWriter 23 | 24 | import commons 25 | import models 26 | import utils 27 | from data_utils import (DistributedBucketSampler, TextAudioSpeakerCollate, 28 | TextAudioSpeakerLoader) 29 | from losses import discriminator_loss, feature_loss, generator_loss, kl_loss 30 | from mel_processing import mel_spectrogram_torch, spec_to_mel_torch 31 | from models import (AVAILABLE_DURATION_DISCRIMINATOR_TYPES, 32 | AVAILABLE_FLOW_TYPES, 33 | DurationDiscriminatorV1, DurationDiscriminatorV2, 34 | MultiPeriodDiscriminator, SynthesizerTrn) 35 | from PL_BERT_ja.text.symbols import symbols 36 | 37 | torch.backends.cudnn.benchmark = True 38 | global_step = 0 39 | 40 | def main(): 41 | """Assume Single Node Multi GPUs Training Only""" 42 | assert torch.cuda.is_available(), "CPU training is not allowed." 
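# Typical single-node launch (a sketch; assumes utils.get_hparams() exposes the standard
# VITS -c/--config and -m/--model arguments, e.g.):
#   python train_ms.py -c configs/jvnv_base.json -m jvnv_base
# main() reads the hyperparameters below and spawns one training process per visible GPU
# with mp.spawn.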
43 | 44 | n_gpus = torch.cuda.device_count() 45 | os.environ["MASTER_ADDR"] = "localhost" 46 | os.environ["MASTER_PORT"] = "6060" 47 | 48 | hps = utils.get_hparams() 49 | mp.spawn( 50 | run, 51 | nprocs=n_gpus, 52 | args=( 53 | n_gpus, 54 | hps, 55 | ), 56 | ) 57 | 58 | 59 | def run(rank, n_gpus, hps): 60 | net_dur_disc = None 61 | global global_step 62 | if rank == 0: 63 | logger = utils.get_logger(hps.model_dir) 64 | logger.info(hps) 65 | utils.check_git_hash(hps.model_dir) 66 | writer = SummaryWriter(log_dir=hps.model_dir) 67 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) 68 | 69 | dist.init_process_group( 70 | backend="nccl", init_method="env://", world_size=n_gpus, rank=rank 71 | ) 72 | torch.manual_seed(hps.train.seed) 73 | torch.cuda.set_device(rank) 74 | 75 | if ( 76 | "use_mel_posterior_encoder" in hps.model.keys() 77 | and hps.model.use_mel_posterior_encoder == True 78 | ): 79 | print("Using mel posterior encoder for VITS2") 80 | posterior_channels = 128 # vits2 81 | hps.data.use_mel_posterior_encoder = True 82 | else: 83 | print("Using lin posterior encoder for VITS1") 84 | posterior_channels = hps.data.filter_length // 2 + 1 85 | hps.data.use_mel_posterior_encoder = False 86 | 87 | train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data) 88 | train_sampler = DistributedBucketSampler( 89 | train_dataset, 90 | hps.train.batch_size, 91 | [32, 300, 500, 700, 900, 1100, 1300, 1500, 3000], 92 | num_replicas=n_gpus, 93 | rank=rank, 94 | shuffle=True, 95 | ) 96 | collate_fn = TextAudioSpeakerCollate() 97 | train_loader = DataLoader( 98 | train_dataset, 99 | num_workers=8, 100 | shuffle=False, 101 | pin_memory=True, 102 | collate_fn=collate_fn, 103 | batch_sampler=train_sampler, 104 | ) 105 | 106 | if rank == 0: 107 | eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data) 108 | eval_loader = DataLoader( 109 | eval_dataset, 110 | num_workers=8, 111 | shuffle=False, 112 | batch_size=hps.train.batch_size, 113 | pin_memory=True, 114 | drop_last=False, 115 | collate_fn=collate_fn, 116 | ) 117 | # some of these flags are not being used in the code and directly set in hps json file. 118 | # they are kept here for reference and prototyping. 
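# Like the use_mel_posterior_encoder check above, each block below falls back to the VITS1
# behaviour when its flag is missing from hps.model: use_transformer_flows (+ transformer_flow_type),
# use_spk_conditioned_encoder, use_noise_scaled_mas, use_duration_discriminator
# (+ duration_discriminator_type, one of "dur_disc_1" / "dur_disc_2").
# Illustrative config fragment (a sketch, not the shipped values; see the training JSON config):
#   "model": {
#     "use_mel_posterior_encoder": true,
#     "use_transformer_flows": true,
#     "transformer_flow_type": "<one of models.AVAILABLE_FLOW_TYPES>",
#     "use_noise_scaled_mas": true,
#     "use_duration_discriminator": true,
#     "duration_discriminator_type": "dur_disc_1"
#   }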
119 | if ( 120 | "use_transformer_flows" in hps.model.keys() 121 | and hps.model.use_transformer_flows == True 122 | ): 123 | use_transformer_flows = True 124 | transformer_flow_type = hps.model.transformer_flow_type 125 | print(f"Using transformer flows {transformer_flow_type} for VITS2") 126 | assert ( 127 | transformer_flow_type in AVAILABLE_FLOW_TYPES 128 | ), f"transformer_flow_type must be one of {AVAILABLE_FLOW_TYPES}" 129 | else: 130 | print("Using normal flows for VITS1") 131 | use_transformer_flows = False 132 | 133 | if ( 134 | "use_spk_conditioned_encoder" in hps.model.keys() 135 | and hps.model.use_spk_conditioned_encoder == True 136 | ): 137 | if hps.data.n_speakers == 0: 138 | raise ValueError( 139 | "n_speakers must be > 0 when using spk conditioned encoder to train multi-speaker model" 140 | ) 141 | use_spk_conditioned_encoder = True 142 | else: 143 | print("Using normal encoder for VITS1") 144 | use_spk_conditioned_encoder = False 145 | 146 | if ( 147 | "use_noise_scaled_mas" in hps.model.keys() 148 | and hps.model.use_noise_scaled_mas == True 149 | ): 150 | print("Using noise scaled MAS for VITS2") 151 | use_noise_scaled_mas = True 152 | mas_noise_scale_initial = 0.01 153 | noise_scale_delta = 2e-6 154 | else: 155 | print("Using normal MAS for VITS1") 156 | use_noise_scaled_mas = False 157 | mas_noise_scale_initial = 0.0 158 | noise_scale_delta = 0.0 159 | 160 | if ( 161 | "use_duration_discriminator" in hps.model.keys() 162 | and hps.model.use_duration_discriminator == True 163 | ): 164 | # print("Using duration discriminator for VITS2") 165 | use_duration_discriminator = True 166 | 167 | # comment - choihkk 168 | # add duration discriminator type here 169 | # I think it would be a good idea to come up with a method to input this part accurately, like a hydra 170 | duration_discriminator_type = getattr( 171 | hps.model, "duration_discriminator_type", "dur_disc_1" 172 | ) 173 | print(f"Using duration_discriminator {duration_discriminator_type} for VITS2") 174 | assert ( 175 | duration_discriminator_type in AVAILABLE_DURATION_DISCRIMINATOR_TYPES 176 | ), f"duration_discriminator_type must be one of {AVAILABLE_DURATION_DISCRIMINATOR_TYPES}" 177 | # duration_discriminator_type = AVAILABLE_DURATION_DISCRIMINATOR_TYPES # ここ修正 178 | if duration_discriminator_type == "dur_disc_1": 179 | net_dur_disc = DurationDiscriminatorV1( 180 | hps.model.hidden_channels, 181 | hps.model.hidden_channels, 182 | 3, 183 | 0.1, 184 | gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0, 185 | ).cuda(rank) 186 | elif duration_discriminator_type == "dur_disc_2": 187 | net_dur_disc = DurationDiscriminatorV2( 188 | hps.model.hidden_channels, 189 | hps.model.hidden_channels, 190 | 3, 191 | 0.1, 192 | gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0, 193 | ).cuda(rank) 194 | else: 195 | print("NOT using any duration discriminator like VITS1") 196 | net_dur_disc = None 197 | use_duration_discriminator = False 198 | 199 | net_g = SynthesizerTrn( 200 | len(symbols)+1, 201 | posterior_channels, 202 | hps.train.segment_size // hps.data.hop_length, 203 | n_speakers=hps.data.n_speakers, 204 | mas_noise_scale_initial=mas_noise_scale_initial, 205 | noise_scale_delta=noise_scale_delta, 206 | **hps.model, 207 | ).cuda(rank) 208 | net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) 209 | optim_g = torch.optim.AdamW( 210 | net_g.parameters(), 211 | hps.train.learning_rate, 212 | betas=hps.train.betas, 213 | eps=hps.train.eps, 214 | ) 215 | optim_d = 
torch.optim.AdamW( 216 | net_d.parameters(), 217 | hps.train.learning_rate, 218 | betas=hps.train.betas, 219 | eps=hps.train.eps, 220 | ) 221 | if net_dur_disc is not None: 222 | optim_dur_disc = torch.optim.AdamW( 223 | net_dur_disc.parameters(), 224 | hps.train.learning_rate, 225 | betas=hps.train.betas, 226 | eps=hps.train.eps, 227 | ) 228 | else: 229 | optim_dur_disc = None 230 | 231 | # comment - choihkk 232 | # if we comment out unused parameter like DurationDiscriminator's self.pre_out_norm1,2 self.norm_1,2 233 | # and ResidualCouplingTransformersLayer's self.post_transformer 234 | # we don't have to set find_unused_parameters=True 235 | # but I will not proceed with commenting out for compatibility with the latest work for others 236 | net_g = DDP(net_g, device_ids=[rank]) 237 | net_d = DDP(net_d, device_ids=[rank]) 238 | if net_dur_disc is not None: 239 | net_dur_disc = DDP( 240 | net_dur_disc, device_ids=[rank]) 241 | 242 | try: 243 | _, _, _, epoch_str = utils.load_checkpoint( 244 | utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g 245 | ) 246 | _, _, _, epoch_str = utils.load_checkpoint( 247 | utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d 248 | ) 249 | if net_dur_disc is not None: 250 | _, _, _, epoch_str = utils.load_checkpoint( 251 | utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"), 252 | net_dur_disc, 253 | optim_dur_disc, 254 | ) 255 | global_step = (epoch_str - 1) * len(train_loader) 256 | 257 | input = input("Initialize Global Steps and Epochs ??? y/n") 258 | if input == "y": 259 | epoch_str = 1 260 | global_step = 0 261 | 262 | except: 263 | epoch_str = 1 264 | global_step = 0 265 | 266 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR( 267 | optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 268 | ) 269 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR( 270 | optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 271 | ) 272 | if net_dur_disc is not None: 273 | scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR( 274 | optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2 275 | ) 276 | else: 277 | scheduler_dur_disc = None 278 | 279 | scaler = GradScaler(enabled=hps.train.fp16_run) 280 | 281 | for epoch in range(epoch_str, hps.train.epochs + 1): 282 | if rank == 0: 283 | train_and_evaluate( 284 | rank, 285 | epoch, 286 | hps, 287 | [net_g, net_d, net_dur_disc], 288 | [optim_g, optim_d, optim_dur_disc], 289 | [scheduler_g, scheduler_d, scheduler_dur_disc], 290 | scaler, 291 | [train_loader, eval_loader], 292 | logger, 293 | [writer, writer_eval], 294 | ) 295 | else: 296 | train_and_evaluate( 297 | rank, 298 | epoch, 299 | hps, 300 | [net_g, net_d, net_dur_disc], 301 | [optim_g, optim_d, optim_dur_disc], 302 | [scheduler_g, scheduler_d, scheduler_dur_disc], 303 | scaler, 304 | [train_loader, None], 305 | None, 306 | None, 307 | ) 308 | scheduler_g.step() 309 | scheduler_d.step() 310 | if net_dur_disc is not None: 311 | scheduler_dur_disc.step() 312 | 313 | 314 | def train_and_evaluate( 315 | rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers 316 | ): 317 | net_g, net_d, net_dur_disc = nets 318 | optim_g, optim_d, optim_dur_disc = optims 319 | scheduler_g, scheduler_d, scheduler_dur_disc = schedulers 320 | train_loader, eval_loader = loaders 321 | if writers is not None: 322 | writer, writer_eval = writers 323 | 324 | train_loader.batch_sampler.set_epoch(epoch) 325 | global global_step 326 | 327 | net_g.train() 328 | net_d.train() 329 | if 
net_dur_disc is not None: 330 | net_dur_disc.train() 331 | 332 | if rank == 0: 333 | loader = tqdm.tqdm(train_loader, desc="Loading train data") 334 | else: 335 | loader = train_loader 336 | 337 | for batch_idx, ( 338 | x, 339 | x_lengths, 340 | spec, 341 | spec_lengths, 342 | y, 343 | y_lengths, 344 | bert, 345 | bert_lengths, 346 | speakers, 347 | ) in enumerate(loader): 348 | if net_g.module.use_noise_scaled_mas: 349 | current_mas_noise_scale = ( 350 | net_g.module.mas_noise_scale_initial 351 | - net_g.module.noise_scale_delta * global_step 352 | ) 353 | net_g.module.current_mas_noise_scale = max(current_mas_noise_scale, 0.0) 354 | x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda( 355 | rank, non_blocking=True 356 | ) 357 | spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda( 358 | rank, non_blocking=True 359 | ) 360 | y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda( 361 | rank, non_blocking=True 362 | ) 363 | bert, bert_lengths = bert.cuda(rank, non_blocking=True), bert_lengths.cuda( 364 | rank, non_blocking=True 365 | ) 366 | speakers = speakers.cuda(rank, non_blocking=True) 367 | 368 | with autocast(enabled=hps.train.fp16_run): 369 | ( 370 | y_hat, 371 | l_length, 372 | attn, 373 | ids_slice, 374 | x_mask, 375 | z_mask, 376 | (z, z_p, m_p, logs_p, m_q, logs_q), 377 | (hidden_x, logw, logw_), 378 | ) = net_g(x, x_lengths, spec, spec_lengths, bert, bert_lengths, speakers) 379 | 380 | if ( 381 | hps.model.use_mel_posterior_encoder 382 | or hps.data.use_mel_posterior_encoder 383 | ): 384 | mel = spec 385 | else: 386 | # comment - choihkk 387 | # for numerical stable when using fp16 and torch>=2.0.0, 388 | # spec.float() could be help in the training stage 389 | # https://github.com/jaywalnut310/vits/issues/15 390 | mel = spec_to_mel_torch( 391 | spec.float(), 392 | hps.data.filter_length, 393 | hps.data.n_mel_channels, 394 | hps.data.sampling_rate, 395 | hps.data.mel_fmin, 396 | hps.data.mel_fmax, 397 | ) 398 | y_mel = commons.slice_segments( 399 | mel, ids_slice, hps.train.segment_size // hps.data.hop_length 400 | ) 401 | y_hat_mel = mel_spectrogram_torch( 402 | y_hat.squeeze(1), 403 | hps.data.filter_length, 404 | hps.data.n_mel_channels, 405 | hps.data.sampling_rate, 406 | hps.data.hop_length, 407 | hps.data.win_length, 408 | hps.data.mel_fmin, 409 | hps.data.mel_fmax, 410 | ) 411 | 412 | y = commons.slice_segments( 413 | y, ids_slice * hps.data.hop_length, hps.train.segment_size 414 | ) # slice 415 | 416 | # Discriminator 417 | y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) 418 | with autocast(enabled=False): 419 | loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( 420 | y_d_hat_r, y_d_hat_g 421 | ) 422 | loss_disc_all = loss_disc 423 | 424 | # Duration Discriminator 425 | if net_dur_disc is not None: 426 | y_dur_hat_r, y_dur_hat_g = net_dur_disc( 427 | hidden_x.detach(), x_mask.detach(), logw_.detach(), logw.detach() 428 | ) 429 | with autocast(enabled=False): 430 | # TODO: I think need to mean using the mask, but for now, just mean all 431 | ( 432 | loss_dur_disc, 433 | losses_dur_disc_r, 434 | losses_dur_disc_g, 435 | ) = discriminator_loss(y_dur_hat_r, y_dur_hat_g) 436 | loss_dur_disc_all = loss_dur_disc 437 | optim_dur_disc.zero_grad() 438 | scaler.scale(loss_dur_disc_all).backward() 439 | scaler.unscale_(optim_dur_disc) 440 | grad_norm_dur_disc = commons.clip_grad_value_( 441 | net_dur_disc.parameters(), None 442 | ) 443 | scaler.step(optim_dur_disc) 444 | 445 | optim_d.zero_grad() 446 | 
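# Mixed-precision discriminator step: the shared GradScaler scales loss_disc_all before
# backward(), unscales the gradients so commons.clip_grad_value_ sees their true magnitudes,
# and scaler.step(optim_d) skips the update if any gradient became inf/NaN under fp16.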
scaler.scale(loss_disc_all).backward() 447 | scaler.unscale_(optim_d) 448 | grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) 449 | scaler.step(optim_d) 450 | 451 | with autocast(enabled=hps.train.fp16_run): 452 | # Generator 453 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) 454 | if net_dur_disc is not None: 455 | y_dur_hat_r, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw_, logw) 456 | with autocast(enabled=False): 457 | loss_dur = torch.sum(l_length.float()) 458 | loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel 459 | loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl 460 | 461 | loss_fm = feature_loss(fmap_r, fmap_g) 462 | loss_gen, losses_gen = generator_loss(y_d_hat_g) 463 | loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl 464 | if net_dur_disc is not None: 465 | loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g) 466 | loss_gen_all += loss_dur_gen 467 | 468 | optim_g.zero_grad() 469 | scaler.scale(loss_gen_all).backward() 470 | scaler.unscale_(optim_g) 471 | grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) 472 | scaler.step(optim_g) 473 | scaler.update() 474 | 475 | if rank == 0: 476 | if global_step % hps.train.log_interval == 0: 477 | lr = optim_g.param_groups[0]["lr"] 478 | losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl] 479 | logger.info( 480 | "Train Epoch: {} [{:.0f}%]".format( 481 | epoch, 100.0 * batch_idx / len(train_loader) 482 | ) 483 | ) 484 | logger.info([x.item() for x in losses] + [global_step, lr]) 485 | 486 | scalar_dict = { 487 | "loss/g/total": loss_gen_all, 488 | "loss/d/total": loss_disc_all, 489 | "learning_rate": lr, 490 | "grad_norm_d": grad_norm_d, 491 | "grad_norm_g": grad_norm_g, 492 | } 493 | if net_dur_disc is not None: 494 | scalar_dict.update( 495 | { 496 | "loss/dur_disc/total": loss_dur_disc_all, 497 | "grad_norm_dur_disc": grad_norm_dur_disc, 498 | } 499 | ) 500 | scalar_dict.update( 501 | { 502 | "loss/g/fm": loss_fm, 503 | "loss/g/mel": loss_mel, 504 | "loss/g/dur": loss_dur, 505 | "loss/g/kl": loss_kl, 506 | } 507 | ) 508 | 509 | scalar_dict.update( 510 | {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)} 511 | ) 512 | scalar_dict.update( 513 | {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)} 514 | ) 515 | scalar_dict.update( 516 | {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)} 517 | ) 518 | 519 | # if net_dur_disc is not None: 520 | # scalar_dict.update({"loss/dur_disc_r" : f"{losses_dur_disc_r}"}) 521 | # scalar_dict.update({"loss/dur_disc_g" : f"{losses_dur_disc_g}"}) 522 | # scalar_dict.update({"loss/dur_gen" : f"{loss_dur_gen}"}) 523 | 524 | image_dict = { 525 | "slice/mel_org": utils.plot_spectrogram_to_numpy( 526 | y_mel[0].data.cpu().numpy() 527 | ), 528 | "slice/mel_gen": utils.plot_spectrogram_to_numpy( 529 | y_hat_mel[0].data.cpu().numpy() 530 | ), 531 | "train/mel": utils.plot_spectrogram_to_numpy( 532 | mel[0].data.cpu().numpy() 533 | ), 534 | "train/attn": utils.plot_alignment_to_numpy( 535 | attn[0, 0].data.cpu().numpy() 536 | ), 537 | } 538 | utils.summarize( 539 | writer=writer, 540 | global_step=global_step, 541 | images=image_dict, 542 | scalars=scalar_dict, 543 | ) 544 | 545 | if global_step % hps.train.eval_interval == 0: 546 | evaluate(hps, net_g, eval_loader, writer_eval) 547 | utils.save_checkpoint( 548 | net_g, 549 | optim_g, 550 | hps.train.learning_rate, 551 | epoch, 552 | os.path.join(hps.model_dir, "G_{}.pth".format(global_step)), 553 | ) 554 | 
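# The matching D_ and DUR_ checkpoints are written next; utils.remove_old_checkpoints then
# keeps only the three most recent files per prefix, so old G_/D_/DUR_ snapshots are deleted
# as training progresses.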
utils.save_checkpoint( 555 | net_d, 556 | optim_d, 557 | hps.train.learning_rate, 558 | epoch, 559 | os.path.join(hps.model_dir, "D_{}.pth".format(global_step)), 560 | ) 561 | if net_dur_disc is not None: 562 | utils.save_checkpoint( 563 | net_dur_disc, 564 | optim_dur_disc, 565 | hps.train.learning_rate, 566 | epoch, 567 | os.path.join(hps.model_dir, "DUR_{}.pth".format(global_step)), 568 | ) 569 | utils.remove_old_checkpoints(hps.model_dir, prefixes=["G_*.pth", "D_*.pth", "DUR_*.pth"]) 570 | global_step += 1 571 | 572 | if rank == 0: 573 | logger.info("====> Epoch: {}".format(epoch)) 574 | 575 | 576 | def evaluate(hps, generator, eval_loader, writer_eval): 577 | generator.eval() 578 | with torch.no_grad(): 579 | for batch_idx, ( 580 | x, 581 | x_lengths, 582 | spec, 583 | spec_lengths, 584 | y, 585 | y_lengths, 586 | bert, bert_lengths, 587 | speakers, 588 | ) in enumerate(eval_loader): 589 | x, x_lengths = x.cuda(0), x_lengths.cuda(0) 590 | spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0) 591 | y, y_lengths = y.cuda(0), y_lengths.cuda(0) 592 | speakers = speakers.cuda(0) 593 | bert, bert_lengths = bert.cuda(0), bert_lengths.cuda(0) 594 | 595 | # remove else 596 | x = x[:1] 597 | x_lengths = x_lengths[:1] 598 | spec = spec[:1] 599 | spec_lengths = spec_lengths[:1] 600 | y = y[:1] 601 | y_lengths = y_lengths[:1] 602 | bert = bert[:1] 603 | bert_lengths = bert_lengths[:1] 604 | speakers = speakers[:1] 605 | break 606 | y_hat, attn, mask, *_ = generator.module.infer( 607 | x, x_lengths, bert, bert_lengths, speakers, max_len=1000 608 | ) 609 | y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length 610 | 611 | if hps.model.use_mel_posterior_encoder or hps.data.use_mel_posterior_encoder: 612 | mel = spec 613 | else: 614 | mel = spec_to_mel_torch( 615 | spec, 616 | hps.data.filter_length, 617 | hps.data.n_mel_channels, 618 | hps.data.sampling_rate, 619 | hps.data.mel_fmin, 620 | hps.data.mel_fmax, 621 | ) 622 | y_hat_mel = mel_spectrogram_torch( 623 | y_hat.squeeze(1).float(), 624 | hps.data.filter_length, 625 | hps.data.n_mel_channels, 626 | hps.data.sampling_rate, 627 | hps.data.hop_length, 628 | hps.data.win_length, 629 | hps.data.mel_fmin, 630 | hps.data.mel_fmax, 631 | ) 632 | image_dict = { 633 | "valid/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()) 634 | } 635 | audio_dict = {"valid/gen/audio": y_hat[0, :, : y_hat_lengths[0]]} 636 | if global_step == 0: 637 | image_dict.update( 638 | {"valid/gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())} 639 | ) 640 | audio_dict.update({"valid/gt/audio": y[0, :, : y_lengths[0]]}) 641 | 642 | utils.summarize( 643 | writer=writer_eval, 644 | global_step=global_step, 645 | images=image_dict, 646 | audios=audio_dict, 647 | audio_sampling_rate=hps.data.sampling_rate, 648 | ) 649 | generator.train() 650 | 651 | 652 | if __name__ == "__main__": 653 | main() 654 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform( 13 | inputs, 14 | unnormalized_widths, 15 | unnormalized_heights, 16 | unnormalized_derivatives, 17 | inverse=False, 18 | tails=None, 19 | tail_bound=1.0, 20 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 21 | 
min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 22 | min_derivative=DEFAULT_MIN_DERIVATIVE, 23 | ): 24 | if tails is None: 25 | spline_fn = rational_quadratic_spline 26 | spline_kwargs = {} 27 | else: 28 | spline_fn = unconstrained_rational_quadratic_spline 29 | spline_kwargs = {"tails": tails, "tail_bound": tail_bound} 30 | 31 | outputs, logabsdet = spline_fn( 32 | inputs=inputs, 33 | unnormalized_widths=unnormalized_widths, 34 | unnormalized_heights=unnormalized_heights, 35 | unnormalized_derivatives=unnormalized_derivatives, 36 | inverse=inverse, 37 | min_bin_width=min_bin_width, 38 | min_bin_height=min_bin_height, 39 | min_derivative=min_derivative, 40 | **spline_kwargs 41 | ) 42 | return outputs, logabsdet 43 | 44 | 45 | def searchsorted(bin_locations, inputs, eps=1e-6): 46 | bin_locations[..., -1] += eps 47 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 48 | 49 | 50 | def unconstrained_rational_quadratic_spline( 51 | inputs, 52 | unnormalized_widths, 53 | unnormalized_heights, 54 | unnormalized_derivatives, 55 | inverse=False, 56 | tails="linear", 57 | tail_bound=1.0, 58 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 59 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 60 | min_derivative=DEFAULT_MIN_DERIVATIVE, 61 | ): 62 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 63 | outside_interval_mask = ~inside_interval_mask 64 | 65 | outputs = torch.zeros_like(inputs) 66 | logabsdet = torch.zeros_like(inputs) 67 | 68 | if tails == "linear": 69 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 70 | constant = np.log(np.exp(1 - min_derivative) - 1) 71 | unnormalized_derivatives[..., 0] = constant 72 | unnormalized_derivatives[..., -1] = constant 73 | 74 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 75 | logabsdet[outside_interval_mask] = 0 76 | else: 77 | raise RuntimeError("{} tails are not implemented.".format(tails)) 78 | 79 | ( 80 | outputs[inside_interval_mask], 81 | logabsdet[inside_interval_mask], 82 | ) = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, 89 | right=tail_bound, 90 | bottom=-tail_bound, 91 | top=tail_bound, 92 | min_bin_width=min_bin_width, 93 | min_bin_height=min_bin_height, 94 | min_derivative=min_derivative, 95 | ) 96 | 97 | return outputs, logabsdet 98 | 99 | 100 | def rational_quadratic_spline( 101 | inputs, 102 | unnormalized_widths, 103 | unnormalized_heights, 104 | unnormalized_derivatives, 105 | inverse=False, 106 | left=0.0, 107 | right=1.0, 108 | bottom=0.0, 109 | top=1.0, 110 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 111 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 112 | min_derivative=DEFAULT_MIN_DERIVATIVE, 113 | ): 114 | if torch.min(inputs) < left or torch.max(inputs) > right: 115 | raise ValueError("Input to a transform is not within its domain") 116 | 117 | num_bins = unnormalized_widths.shape[-1] 118 | 119 | if min_bin_width * num_bins > 1.0: 120 | raise ValueError("Minimal bin width too large for the number of bins") 121 | if min_bin_height * num_bins > 1.0: 122 | raise ValueError("Minimal bin height too large for the number of bins") 123 | 124 | widths = F.softmax(unnormalized_widths, dim=-1) 125 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 126 | cumwidths = torch.cumsum(widths, dim=-1) 
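# Bin construction for the rational-quadratic spline: softmax widths sum to 1 and are re-scaled
# above to min_bin_width + (1 - min_bin_width * num_bins) * w, so every bin keeps a minimum
# width while the total stays 1 (with num_bins=10 and min_bin_width=1e-3 each bin is at least
# 1e-3 wide); the cumulative sum is then padded with a leading 0 and mapped onto [left, right].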
127 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) 128 | cumwidths = (right - left) * cumwidths + left 129 | cumwidths[..., 0] = left 130 | cumwidths[..., -1] = right 131 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 132 | 133 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 134 | 135 | heights = F.softmax(unnormalized_heights, dim=-1) 136 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 137 | cumheights = torch.cumsum(heights, dim=-1) 138 | cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) 139 | cumheights = (top - bottom) * cumheights + bottom 140 | cumheights[..., 0] = bottom 141 | cumheights[..., -1] = top 142 | heights = cumheights[..., 1:] - cumheights[..., :-1] 143 | 144 | if inverse: 145 | bin_idx = searchsorted(cumheights, inputs)[..., None] 146 | else: 147 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 148 | 149 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 150 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 151 | 152 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 153 | delta = heights / widths 154 | input_delta = delta.gather(-1, bin_idx)[..., 0] 155 | 156 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 157 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 158 | 159 | input_heights = heights.gather(-1, bin_idx)[..., 0] 160 | 161 | if inverse: 162 | a = (inputs - input_cumheights) * ( 163 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 164 | ) + input_heights * (input_delta - input_derivatives) 165 | b = input_heights * input_derivatives - (inputs - input_cumheights) * ( 166 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 167 | ) 168 | c = -input_delta * (inputs - input_cumheights) 169 | 170 | discriminant = b.pow(2) - 4 * a * c 171 | assert (discriminant >= 0).all() 172 | 173 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 174 | outputs = root * input_bin_widths + input_cumwidths 175 | 176 | theta_one_minus_theta = root * (1 - root) 177 | denominator = input_delta + ( 178 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 179 | * theta_one_minus_theta 180 | ) 181 | derivative_numerator = input_delta.pow(2) * ( 182 | input_derivatives_plus_one * root.pow(2) 183 | + 2 * input_delta * theta_one_minus_theta 184 | + input_derivatives * (1 - root).pow(2) 185 | ) 186 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 187 | 188 | return outputs, -logabsdet 189 | else: 190 | theta = (inputs - input_cumwidths) / input_bin_widths 191 | theta_one_minus_theta = theta * (1 - theta) 192 | 193 | numerator = input_heights * ( 194 | input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta 195 | ) 196 | denominator = input_delta + ( 197 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 198 | * theta_one_minus_theta 199 | ) 200 | outputs = input_cumheights + numerator / denominator 201 | 202 | derivative_numerator = input_delta.pow(2) * ( 203 | input_derivatives_plus_one * theta.pow(2) 204 | + 2 * input_delta * theta_one_minus_theta 205 | + input_derivatives * (1 - theta).pow(2) 206 | ) 207 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 208 | 209 | return outputs, logabsdet 210 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 
| import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | from scipy.io.wavfile import read 10 | import torch 11 | 12 | MATPLOTLIB_FLAG = False 13 | 14 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 15 | logger = logging 16 | 17 | 18 | def load_checkpoint(checkpoint_path, model, optimizer=None): 19 | assert os.path.isfile(checkpoint_path) 20 | checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") 21 | iteration = checkpoint_dict["iteration"] 22 | learning_rate = checkpoint_dict["learning_rate"] 23 | if optimizer is not None: 24 | optimizer.load_state_dict(checkpoint_dict["optimizer"]) 25 | saved_state_dict = checkpoint_dict["model"] 26 | if hasattr(model, "module"): 27 | state_dict = model.module.state_dict() 28 | else: 29 | state_dict = model.state_dict() 30 | new_state_dict = {} 31 | for k, v in state_dict.items(): 32 | try: 33 | new_state_dict[k] = saved_state_dict[k] 34 | except: 35 | logger.info("%s is not in the checkpoint" % k) 36 | new_state_dict[k] = v 37 | if hasattr(model, "module"): 38 | model.module.load_state_dict(new_state_dict) 39 | else: 40 | model.load_state_dict(new_state_dict) 41 | logger.info( 42 | "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration) 43 | ) 44 | return model, optimizer, learning_rate, iteration 45 | 46 | 47 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 48 | logger.info( 49 | "Saving model and optimizer state at iteration {} to {}".format( 50 | iteration, checkpoint_path 51 | ) 52 | ) 53 | if hasattr(model, "module"): 54 | state_dict = model.module.state_dict() 55 | else: 56 | state_dict = model.state_dict() 57 | torch.save( 58 | { 59 | "model": state_dict, 60 | "iteration": iteration, 61 | "optimizer": optimizer.state_dict(), 62 | "learning_rate": learning_rate, 63 | }, 64 | checkpoint_path, 65 | ) 66 | 67 | 68 | def summarize( 69 | writer, 70 | global_step, 71 | scalars={}, 72 | histograms={}, 73 | images={}, 74 | audios={}, 75 | audio_sampling_rate=22050, 76 | ): 77 | for k, v in scalars.items(): 78 | writer.add_scalar(k, v, global_step) 79 | for k, v in histograms.items(): 80 | writer.add_histogram(k, v, global_step) 81 | for k, v in images.items(): 82 | writer.add_image(k, v, global_step, dataformats="HWC") 83 | for k, v in audios.items(): 84 | writer.add_audio(k, v, global_step, audio_sampling_rate) 85 | 86 | 87 | def scan_checkpoint(dir_path, regex): 88 | f_list = glob.glob(os.path.join(dir_path, regex)) 89 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 90 | if len(f_list) == 0: 91 | return None 92 | return f_list 93 | 94 | 95 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 96 | f_list = scan_checkpoint(dir_path, regex) 97 | if not f_list: 98 | return None 99 | x = f_list[-1] 100 | print(x) 101 | return x 102 | 103 | 104 | def remove_old_checkpoints(cp_dir, prefixes=['G_*.pth', 'D_*.pth', 'DUR_*.pth']): 105 | for prefix in prefixes: 106 | sorted_ckpts = scan_checkpoint(cp_dir, prefix) 107 | if sorted_ckpts and len(sorted_ckpts) > 3: 108 | for ckpt_path in sorted_ckpts[:-3]: 109 | os.remove(ckpt_path) 110 | print("removed {}".format(ckpt_path)) 111 | 112 | 113 | def plot_spectrogram_to_numpy(spectrogram): 114 | global MATPLOTLIB_FLAG 115 | if not MATPLOTLIB_FLAG: 116 | import matplotlib 117 | 118 | matplotlib.use("Agg") 119 | MATPLOTLIB_FLAG = True 120 | mpl_logger = logging.getLogger("matplotlib") 121 | mpl_logger.setLevel(logging.WARNING) 122 | import matplotlib.pylab as plt 
123 | import numpy as np 124 | 125 | fig, ax = plt.subplots(figsize=(10, 2)) 126 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 127 | plt.colorbar(im, ax=ax) 128 | plt.xlabel("Frames") 129 | plt.ylabel("Channels") 130 | plt.tight_layout() 131 | 132 | fig.canvas.draw() 133 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 134 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 135 | plt.close() 136 | return data 137 | 138 | 139 | def plot_alignment_to_numpy(alignment, info=None): 140 | global MATPLOTLIB_FLAG 141 | if not MATPLOTLIB_FLAG: 142 | import matplotlib 143 | 144 | matplotlib.use("Agg") 145 | MATPLOTLIB_FLAG = True 146 | mpl_logger = logging.getLogger("matplotlib") 147 | mpl_logger.setLevel(logging.WARNING) 148 | import matplotlib.pylab as plt 149 | import numpy as np 150 | 151 | fig, ax = plt.subplots(figsize=(6, 4)) 152 | im = ax.imshow( 153 | alignment.transpose(), aspect="auto", origin="lower", interpolation="none" 154 | ) 155 | fig.colorbar(im, ax=ax) 156 | xlabel = "Decoder timestep" 157 | if info is not None: 158 | xlabel += "\n\n" + info 159 | plt.xlabel(xlabel) 160 | plt.ylabel("Encoder timestep") 161 | plt.tight_layout() 162 | 163 | fig.canvas.draw() 164 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 165 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 166 | plt.close() 167 | return data 168 | 169 | 170 | def load_wav_to_torch(full_path): 171 | sampling_rate, wav = read(full_path.replace("\\", "/")) ### modify .replace("\\", "/") ### 172 | 173 | if len(wav.shape) == 2: 174 | wav = wav[:, 0] 175 | if wav.dtype == np.int16: 176 | wav = wav / 32768.0 177 | elif wav.dtype == np.int32: 178 | wav = wav / 2147483648.0 179 | elif wav.dtype == np.uint8: 180 | wav = (wav - 128) / 128.0 181 | wav = wav.astype(np.float32) 182 | 183 | return torch.FloatTensor(wav), sampling_rate 184 | 185 | 186 | def load_filepaths_and_text(filename, split="||||"): 187 | with open(filename, encoding="utf-8") as f: 188 | filepaths_and_text = [line.strip().split(split) for line in f] 189 | return filepaths_and_text 190 | 191 | 192 | def get_hparams(init=True): 193 | parser = argparse.ArgumentParser() 194 | parser.add_argument( 195 | "-c", 196 | "--config", 197 | type=str, 198 | default="./configs/jvnv_base.json", 199 | help="JSON file for configuration", 200 | ) 201 | parser.add_argument("-m", "--model", type=str, default="test", help="Model name") 202 | 203 | args = parser.parse_args() 204 | model_dir = os.path.join("./logs", args.model) 205 | 206 | if not os.path.exists(model_dir): 207 | os.makedirs(model_dir) 208 | 209 | config_path = args.config 210 | config_save_path = os.path.join(model_dir, "config.json") 211 | if init: 212 | with open(config_path, "r") as f: 213 | data = f.read() 214 | with open(config_save_path, "w") as f: 215 | f.write(data) 216 | else: 217 | with open(config_save_path, "r") as f: 218 | data = f.read() 219 | config = json.loads(data) 220 | 221 | hparams = HParams(**config) 222 | hparams.model_dir = model_dir 223 | return hparams 224 | 225 | 226 | def get_hparams_from_dir(model_dir): 227 | config_save_path = os.path.join(model_dir, "config.json") 228 | with open(config_save_path, "r") as f: 229 | data = f.read() 230 | config = json.loads(data) 231 | 232 | hparams = HParams(**config) 233 | hparams.model_dir = model_dir 234 | return hparams 235 | 236 | 237 | def get_hparams_from_file(config_path): 238 | with open(config_path, "r") as f: 239 | data = f.read() 240 | 
config = json.loads(data) 241 | 242 | hparams = HParams(**config) 243 | return hparams 244 | 245 | 246 | def check_git_hash(model_dir): 247 | source_dir = os.path.dirname(os.path.realpath(__file__)) 248 | if not os.path.exists(os.path.join(source_dir, ".git")): 249 | logger.warn( 250 | "{} is not a git repository, therefore hash value comparison will be ignored.".format( 251 | source_dir 252 | ) 253 | ) 254 | return 255 | 256 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 257 | 258 | path = os.path.join(model_dir, "githash") 259 | if os.path.exists(path): 260 | saved_hash = open(path).read() 261 | if saved_hash != cur_hash: 262 | logger.warn( 263 | "git hash values are different. {}(saved) != {}(current)".format( 264 | saved_hash[:8], cur_hash[:8] 265 | ) 266 | ) 267 | else: 268 | open(path, "w").write(cur_hash) 269 | 270 | 271 | def get_logger(model_dir, filename="train.log"): 272 | global logger 273 | logger = logging.getLogger(os.path.basename(model_dir)) 274 | logger.setLevel(logging.DEBUG) 275 | 276 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 277 | if not os.path.exists(model_dir): 278 | os.makedirs(model_dir) 279 | h = logging.FileHandler(os.path.join(model_dir, filename)) 280 | h.setLevel(logging.DEBUG) 281 | h.setFormatter(formatter) 282 | logger.addHandler(h) 283 | return logger 284 | 285 | 286 | class HParams: 287 | def __init__(self, **kwargs): 288 | for k, v in kwargs.items(): 289 | if type(v) == dict: 290 | v = HParams(**v) 291 | self[k] = v 292 | 293 | def keys(self): 294 | return self.__dict__.keys() 295 | 296 | def items(self): 297 | return self.__dict__.items() 298 | 299 | def values(self): 300 | return self.__dict__.values() 301 | 302 | def __len__(self): 303 | return len(self.__dict__) 304 | 305 | def __getitem__(self, key): 306 | return getattr(self, key) 307 | 308 | def __setitem__(self, key, value): 309 | return setattr(self, key, value) 310 | 311 | def __contains__(self, key): 312 | return key in self.__dict__ 313 | 314 | def __repr__(self): 315 | return self.__dict__.__repr__() 316 | --------------------------------------------------------------------------------
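
The HParams container defined at the end of utils.py recursively wraps nested dictionaries, so configuration values can be read either as attributes or with dictionary-style indexing. A minimal illustration follows; the keys used here are invented for the example and are not taken from configs/jvnv_base.json:

from utils import HParams

# Nested dicts are converted to nested HParams instances in __init__.
hps = HParams(train={"learning_rate": 2e-4}, data={"sampling_rate": 22050})
assert hps.train.learning_rate == 2e-4        # attribute access on the nested object
assert "data" in hps                          # __contains__ checks the wrapped dict
assert hps["data"].sampling_rate == 22050     # item access falls back to getattr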
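
For orientation, the checkpoint helpers above are normally chained as sketched below when a run is started or resumed. This is a minimal sketch rather than code from the repository: the placeholder module stands in for the generator built in models.py, and the config path and the train.learning_rate key follow the usual VITS config layout and are assumptions here, not something utils.py itself guarantees.

import torch
import utils

# Read the run configuration; get_hparams_from_file returns a nested HParams,
# so values such as the learning rate are reachable as attributes.
hps = utils.get_hparams_from_file("./configs/jvnv_base.json")

# Stand-in for the generator defined in models.py (not shown in this file).
net_g = torch.nn.Linear(8, 8)
optim_g = torch.optim.AdamW(net_g.parameters(), lr=hps.train.learning_rate)

# Resume from the newest G_*.pth in the run directory, if one exists, and
# prune everything but the three most recent G_/D_/DUR_ checkpoints on disk.
ckpt = utils.latest_checkpoint_path("./logs/test", "G_*.pth")
if ckpt is not None:
    net_g, optim_g, lr, iteration = utils.load_checkpoint(ckpt, net_g, optim_g)
utils.remove_old_checkpoints("./logs/test")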
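
Returning to transforms.py at the top of this excerpt: the else branch of rational_quadratic_spline evaluates the forward monotonic rational-quadratic spline (as in Durkan et al.'s neural spline flows, from which this kind of implementation is commonly adapted). The recap below is hand-written notation, not part of the source. In the file's variable names, input_cumwidths and input_bin_widths are the selected bin's left edge x_k and width w_k, input_cumheights and input_heights its bottom edge y_k and height h_k, input_delta the slope s_k = h_k / w_k, and input_derivatives / input_derivatives_plus_one the knot derivatives d_k, d_{k+1}. With theta = (x - x_k) / w_k the code computes:

\[
y = y_k + \frac{h_k\left[s_k\theta^{2} + d_k\,\theta(1-\theta)\right]}
             {s_k + (d_{k+1} + d_k - 2s_k)\,\theta(1-\theta)},
\qquad
\log\left|\frac{dy}{dx}\right|
  = \log\!\left(s_k^{2}\left[d_{k+1}\theta^{2} + 2s_k\theta(1-\theta) + d_k(1-\theta)^{2}\right]\right)
    - 2\log\!\left(s_k + (d_{k+1} + d_k - 2s_k)\,\theta(1-\theta)\right).
\]

The inverse branch solves the same relation for theta as the numerically stable root of the quadratic a*theta^2 + b*theta + c = 0 assembled just above it (hence the assertion that the discriminant is non-negative), maps the root back through x = x_k + theta * w_k, and returns the negated log-determinant.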