├── .gitignore ├── GPT_SoVITS ├── AR │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── bucket_sampler.py │ │ ├── data_module.py │ │ └── dataset.py │ ├── models │ │ ├── __init__.py │ │ ├── t2s_lightning_module.py │ │ ├── t2s_lightning_module_onnx.py │ │ ├── t2s_model.py │ │ ├── t2s_model_onnx.py │ │ └── utils.py │ ├── modules │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── activation_onnx.py │ │ ├── embedding.py │ │ ├── embedding_onnx.py │ │ ├── lr_schedulers.py │ │ ├── optim.py │ │ ├── patched_mha_with_cache.py │ │ ├── patched_mha_with_cache_onnx.py │ │ ├── scaling.py │ │ ├── transformer.py │ │ └── transformer_onnx.py │ ├── text_processing │ │ ├── __init__.py │ │ ├── phonemizer.py │ │ └── symbols.py │ └── utils │ │ ├── __init__.py │ │ ├── initialize.py │ │ └── io.py ├── feature_extractor │ ├── __init__.py │ ├── cnhubert.py │ └── whisper_enc.py ├── module │ ├── __init__.py │ ├── attentions.py │ ├── attentions_onnx.py │ ├── commons.py │ ├── core_vq.py │ ├── data_utils.py │ ├── losses.py │ ├── mel_processing.py │ ├── models.py │ ├── models_onnx.py │ ├── modules.py │ ├── mrte_model.py │ ├── quantize.py │ └── transforms.py ├── my_utils.py ├── pretrained_models │ └── .gitignore ├── text │ ├── __init__.py │ ├── chinese.py │ ├── cleaner.py │ ├── cmudict-fast.rep │ ├── cmudict.rep │ ├── cmudict_cache.pickle │ ├── engdict-hot.rep │ ├── engdict_cache.pickle │ ├── english.py │ ├── japanese.py │ ├── opencpop-strict.txt │ ├── symbols.py │ ├── tone_sandhi.py │ └── zh_normalization │ │ ├── README.md │ │ ├── __init__.py │ │ ├── char_convert.py │ │ ├── chronology.py │ │ ├── constants.py │ │ ├── num.py │ │ ├── phonecode.py │ │ ├── quantifier.py │ │ └── text_normlization.py └── utils.py ├── LICENSE ├── README.md ├── __init__.py ├── example.py ├── get_tts_wav.py ├── install.sh ├── requirements.txt ├── requirements_win.txt └── temp └── nltk_data ├── corpora ├── cmudict.zip └── cmudict │ ├── README │ └── cmudict └── taggers ├── averaged_perceptron_tagger.zip └── averaged_perceptron_tagger └── averaged_perceptron_tagger.pickle /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | *.py[cod] 3 | *$py.class 4 | __pycache__/ 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *,cover 45 | .hypothesis/ 46 | 47 | # Translations 48 | *.mo 49 | *.pot 50 | 51 | # Django stuff: 52 | *.log 53 | *.pyrdb 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | .idea/ 91 | *.db 92 | .DS_Store 93 | **/migrations/*.py 94 | !**/migrations/__init__.py 95 | *.pyc 96 | db.sqlite3 97 | media/ 98 | __pypackages__/ 99 | package-lock.json 100 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-D-Lab/GPT_SoVITS_Inference/aab032273341da78c44aabf4fb1c132e65f6f29c/GPT_SoVITS/AR/__init__.py -------------------------------------------------------------------------------- /GPT_SoVITS/AR/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-D-Lab/GPT_SoVITS_Inference/aab032273341da78c44aabf4fb1c132e65f6f29c/GPT_SoVITS/AR/data/__init__.py -------------------------------------------------------------------------------- /GPT_SoVITS/AR/data/bucket_sampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/bucketsampler.py 2 | import itertools 3 | import math 4 | import random 5 | from random import shuffle 6 | from typing import Iterator 7 | from typing import Optional 8 | from typing import TypeVar 9 | 10 | import torch 11 | import torch.distributed as dist 12 | from torch.utils.data import Dataset 13 | from torch.utils.data import Sampler 14 | 15 | __all__ = [ 16 | "DistributedBucketSampler", 17 | ] 18 | 19 | T_co = TypeVar("T_co", covariant=True) 20 | 21 | 22 | class DistributedBucketSampler(Sampler[T_co]): 23 | r""" 24 | sort the dataset wrt. 
input length 25 | divide samples into buckets 26 | sort within buckets 27 | divide buckets into batches 28 | sort batches 29 | """ 30 | 31 | def __init__( 32 | self, 33 | dataset: Dataset, 34 | num_replicas: Optional[int] = None, 35 | rank: Optional[int] = None, 36 | shuffle: bool = True, 37 | seed: int = 0, 38 | drop_last: bool = False, 39 | batch_size: int = 32, 40 | ) -> None: 41 | if num_replicas is None: 42 | if not dist.is_available(): 43 | raise RuntimeError("Requires distributed package to be available") 44 | num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1 45 | if rank is None: 46 | if not dist.is_available(): 47 | raise RuntimeError("Requires distributed package to be available") 48 | rank = dist.get_rank() if torch.cuda.is_available() else 0 49 | if torch.cuda.is_available(): 50 | torch.cuda.set_device(rank) 51 | if rank >= num_replicas or rank < 0: 52 | raise ValueError( 53 | "Invalid rank {}, rank should be in the interval" 54 | " [0, {}]".format(rank, num_replicas - 1) 55 | ) 56 | self.dataset = dataset 57 | self.num_replicas = num_replicas 58 | self.rank = rank 59 | self.epoch = 0 60 | self.drop_last = drop_last 61 | # If the dataset length is evenly divisible by # of replicas, then there 62 | # is no need to drop any data, since the dataset will be split equally. 63 | if ( 64 | self.drop_last and len(self.dataset) % self.num_replicas != 0 65 | ): # type: ignore[arg-type] 66 | # Split to nearest available length that is evenly divisible. 67 | # This is to ensure each rank receives the same amount of data when 68 | # using this Sampler. 69 | self.num_samples = math.ceil( 70 | (len(self.dataset) - self.num_replicas) 71 | / self.num_replicas # type: ignore[arg-type] 72 | ) 73 | else: 74 | self.num_samples = math.ceil( 75 | len(self.dataset) / self.num_replicas 76 | ) # type: ignore[arg-type] 77 | self.total_size = self.num_samples * self.num_replicas 78 | self.shuffle = shuffle 79 | self.seed = seed 80 | self.batch_size = batch_size 81 | self.id_with_length = self._get_sample_lengths() 82 | self.id_buckets = self.make_buckets(bucket_width=2.0) 83 | 84 | def _get_sample_lengths(self): 85 | id_with_lengths = [] 86 | for i in range(len(self.dataset)): 87 | id_with_lengths.append((i, self.dataset.get_sample_length(i))) 88 | id_with_lengths.sort(key=lambda x: x[1]) 89 | return id_with_lengths 90 | 91 | def make_buckets(self, bucket_width: float = 2.0): 92 | buckets = [] 93 | cur = [] 94 | max_sec = bucket_width 95 | for id, sec in self.id_with_length: 96 | if sec < max_sec: 97 | cur.append(id) 98 | else: 99 | buckets.append(cur) 100 | cur = [id] 101 | max_sec += bucket_width 102 | if len(cur) > 0: 103 | buckets.append(cur) 104 | return buckets 105 | 106 | def __iter__(self) -> Iterator[T_co]: 107 | if self.shuffle: 108 | # deterministically shuffle based on epoch and seed 109 | g = torch.Generator() 110 | g.manual_seed(self.seed + self.epoch) 111 | random.seed(self.epoch + self.seed) 112 | shuffled_bucket = [] 113 | for buc in self.id_buckets: 114 | buc_copy = buc.copy() 115 | shuffle(buc_copy) 116 | shuffled_bucket.append(buc_copy) 117 | grouped_batch_size = self.batch_size * self.num_replicas 118 | shuffled_bucket = list(itertools.chain(*shuffled_bucket)) 119 | n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size)) 120 | batches = [ 121 | shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] 122 | for b in range(n_batch) 123 | ] 124 | shuffle(batches) 125 | indices = list(itertools.chain(*batches)) 126 | else: 127 | # type: 
ignore[arg-type] 128 | indices = list(range(len(self.dataset))) 129 | 130 | if not self.drop_last: 131 | # add extra samples to make it evenly divisible 132 | padding_size = self.total_size - len(indices) 133 | if padding_size <= len(indices): 134 | indices += indices[:padding_size] 135 | else: 136 | indices += (indices * math.ceil(padding_size / len(indices)))[ 137 | :padding_size 138 | ] 139 | else: 140 | # remove tail of data to make it evenly divisible. 141 | indices = indices[: self.total_size] 142 | assert len(indices) == self.total_size 143 | 144 | # subsample 145 | indices = indices[self.rank : self.total_size : self.num_replicas] 146 | assert len(indices) == self.num_samples 147 | 148 | return iter(indices) 149 | 150 | def __len__(self) -> int: 151 | return self.num_samples 152 | 153 | def set_epoch(self, epoch: int) -> None: 154 | r""" 155 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas 156 | use a different random ordering for each epoch. Otherwise, the next iteration of this 157 | sampler will yield the same ordering. 158 | 159 | Args: 160 | epoch (int): Epoch number. 161 | """ 162 | self.epoch = epoch 163 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/data/data_module.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py 2 | from pytorch_lightning import LightningDataModule 3 | from AR.data.bucket_sampler import DistributedBucketSampler 4 | from AR.data.dataset import Text2SemanticDataset 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | class Text2SemanticDataModule(LightningDataModule): 9 | def __init__( 10 | self, 11 | config, 12 | train_semantic_path, 13 | train_phoneme_path, 14 | dev_semantic_path=None, 15 | dev_phoneme_path=None, 16 | ): 17 | super().__init__() 18 | self.config = config 19 | self.train_semantic_path = train_semantic_path 20 | self.train_phoneme_path = train_phoneme_path 21 | self.dev_semantic_path = dev_semantic_path 22 | self.dev_phoneme_path = dev_phoneme_path 23 | self.num_workers = self.config["data"]["num_workers"] 24 | 25 | def prepare_data(self): 26 | pass 27 | 28 | def setup(self, stage=None, output_logs=False): 29 | self._train_dataset = Text2SemanticDataset( 30 | phoneme_path=self.train_phoneme_path, 31 | semantic_path=self.train_semantic_path, 32 | max_sec=self.config["data"]["max_sec"], 33 | pad_val=self.config["data"]["pad_val"], 34 | ) 35 | self._dev_dataset = self._train_dataset 36 | # self._dev_dataset = Text2SemanticDataset( 37 | # phoneme_path=self.dev_phoneme_path, 38 | # semantic_path=self.dev_semantic_path, 39 | # max_sample=self.config['data']['max_eval_sample'], 40 | # max_sec=self.config['data']['max_sec'], 41 | # pad_val=self.config['data']['pad_val']) 42 | 43 | def train_dataloader(self): 44 | batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"] 45 | batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存 46 | sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size) 47 | return DataLoader( 48 | self._train_dataset, 49 | batch_size=batch_size, 50 | sampler=sampler, 51 | collate_fn=self._train_dataset.collate, 52 | num_workers=self.num_workers, 53 | persistent_workers=True, 54 | prefetch_factor=16, 55 | ) 56 | 57 | def val_dataloader(self): 58 | return DataLoader( 59 | self._dev_dataset, 60 
| batch_size=1, 61 | shuffle=False, 62 | collate_fn=self._train_dataset.collate, 63 | num_workers=max(self.num_workers, 12), 64 | persistent_workers=True, 65 | prefetch_factor=16, 66 | ) 67 | 68 | # 这个会使用到嘛? 69 | def test_dataloader(self): 70 | return DataLoader( 71 | self._dev_dataset, 72 | batch_size=1, 73 | shuffle=False, 74 | collate_fn=self._train_dataset.collate, 75 | ) 76 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/data/dataset.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py 2 | import pdb 3 | import sys 4 | 5 | # sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert") 6 | import traceback, os 7 | from typing import Dict 8 | from typing import List 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import torch, json 13 | from torch.utils.data import DataLoader 14 | from torch.utils.data import Dataset 15 | from transformers import AutoTokenizer 16 | 17 | from text import cleaned_text_to_sequence 18 | 19 | # from config import exp_dir 20 | 21 | 22 | def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0): 23 | seq = sequences[0] 24 | ndim = seq.ndim 25 | if axis < 0: 26 | axis += ndim 27 | dtype = seq.dtype 28 | pad_value = dtype.type(pad_value) 29 | seq_lengths = [seq.shape[axis] for seq in sequences] 30 | max_length = np.max(seq_lengths) 31 | 32 | padded_sequences = [] 33 | for seq, length in zip(sequences, seq_lengths): 34 | padding = ( 35 | [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1) 36 | ) 37 | padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value) 38 | padded_sequences.append(padded_seq) 39 | batch = np.stack(padded_sequences) 40 | return batch 41 | 42 | 43 | class Text2SemanticDataset(Dataset): 44 | """dataset class for text tokens to semantic model training.""" 45 | 46 | def __init__( 47 | self, 48 | phoneme_path: str, 49 | semantic_path: str, 50 | max_sample: int = None, 51 | max_sec: int = 100, 52 | pad_val: int = 1024, 53 | # min value of phoneme/sec 54 | min_ps_ratio: int = 3, 55 | # max value of phoneme/sec 56 | max_ps_ratio: int = 25, 57 | ) -> None: 58 | super().__init__() 59 | 60 | self.semantic_data = pd.read_csv( 61 | semantic_path, delimiter="\t", encoding="utf-8" 62 | ) 63 | # get dict 64 | self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path 65 | self.path3 = "%s/3-bert" % ( 66 | os.path.basename(phoneme_path) 67 | ) # "%s/3-bert"%exp_dir#bert_dir 68 | self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path 69 | assert os.path.exists(self.path2) 70 | assert os.path.exists(self.path6) 71 | self.phoneme_data = {} 72 | with open(self.path2, "r", encoding="utf8") as f: 73 | lines = f.read().strip("\n").split("\n") 74 | 75 | for line in lines: 76 | tmp = line.split("\t") 77 | if len(tmp) != 4: 78 | continue 79 | self.phoneme_data[tmp[0]] = [tmp[1], tmp[2], tmp[3]] 80 | 81 | # self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item() 82 | # pad for semantic tokens 83 | self.PAD: int = pad_val 84 | # self.hz = 25 85 | # with open("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert/configs/s2.json", "r") as f:data = f.read() 86 | # data=json.loads(data)["model"]["semantic_frame_rate"]#50hz 87 | # self.hz=int(data[:-2])# 88 | self.hz = int(os.environ.get("hz", "25hz")[:-2]) 89 | 90 | # max seconds of semantic token 91 | self.max_sec 
= max_sec 92 | self.min_ps_ratio = min_ps_ratio 93 | self.max_ps_ratio = max_ps_ratio 94 | 95 | if max_sample is not None: 96 | self.semantic_data = self.semantic_data[:max_sample] 97 | 98 | # {idx: (semantic, phoneme)} 99 | # semantic list, phoneme list 100 | self.semantic_phoneme = [] 101 | self.item_names = [] 102 | 103 | self.inited = False 104 | 105 | if not self.inited: 106 | # 调用初始化函数 107 | self.init_batch() 108 | self.inited = True 109 | del self.semantic_data 110 | del self.phoneme_data 111 | # self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large") 112 | # self.tokenizer = AutoTokenizer.from_pretrained("/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large") 113 | 114 | def init_batch(self): 115 | semantic_data_len = len(self.semantic_data) 116 | phoneme_data_len = len(self.phoneme_data.keys()) 117 | print("semantic_data_len:", semantic_data_len) 118 | print("phoneme_data_len:", phoneme_data_len) 119 | print(self.semantic_data) 120 | idx = 0 121 | num_not_in = 0 122 | num_deleted_bigger = 0 123 | num_deleted_ps = 0 124 | for i in range(semantic_data_len): 125 | # 先依次遍历 126 | # get str 127 | item_name = self.semantic_data.iloc[i,0] 128 | # print(self.phoneme_data) 129 | try: 130 | phoneme, word2ph, text = self.phoneme_data[item_name] 131 | except Exception: 132 | traceback.print_exc() 133 | # print(f"{item_name} not in self.phoneme_data !") 134 | num_not_in += 1 135 | continue 136 | 137 | semantic_str = self.semantic_data.iloc[i,1] 138 | # get token list 139 | semantic_ids = [int(idx) for idx in semantic_str.split(" ")] 140 | # (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len 141 | # 过滤掉太长的样本 142 | if ( 143 | len(semantic_ids) > self.max_sec * self.hz 144 | ): #########1###根据token个数推测总时长过滤时长60s(config里)#40*25=1k 145 | num_deleted_bigger += 1 146 | continue 147 | # (T, ), 这个速度不会很慢,所以可以在一开始就处理,无需在 __getitem__ 里面单个处理#### 148 | phoneme = phoneme.split(" ") 149 | 150 | try: 151 | phoneme_ids = cleaned_text_to_sequence(phoneme) 152 | except: 153 | traceback.print_exc() 154 | # print(f"{item_name} not in self.phoneme_data !") 155 | num_not_in += 1 156 | continue 157 | # if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行 158 | if ( 159 | len(phoneme_ids) > self.max_sec * self.hz / 2.5 160 | ): ###########2:改为恒定限制为semantic/2.5就行 161 | num_deleted_ps += 1 162 | continue 163 | # if len(semantic_ids) > 1000:###########3 164 | # num_deleted_bigger += 1 165 | # continue 166 | 167 | ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz) 168 | 169 | if ( 170 | ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio 171 | ): ##########4#3~25#每秒多少个phone 172 | num_deleted_ps += 1 173 | # print(item_name) 174 | continue 175 | 176 | self.semantic_phoneme.append((semantic_ids, phoneme_ids)) 177 | idx += 1 178 | self.item_names.append(item_name) 179 | 180 | min_num = 100 # 20直接不补#30补了也不存ckpt 181 | leng = len(self.semantic_phoneme) 182 | if leng < min_num: 183 | tmp1 = self.semantic_phoneme 184 | tmp2 = self.item_names 185 | self.semantic_phoneme = [] 186 | self.item_names = [] 187 | for _ in range(max(2, int(min_num / leng))): 188 | self.semantic_phoneme += tmp1 189 | self.item_names += tmp2 190 | if num_not_in > 0: 191 | print(f"there are {num_not_in} semantic datas not in phoneme datas") 192 | if num_deleted_bigger > 0: 193 | print( 194 | f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds" 195 | ) 196 | if num_deleted_ps > 0: 197 | # 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 
100 的极端值 198 | print( 199 | f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}" 200 | ) 201 | """ 202 | there are 31 semantic datas not in phoneme datas 203 | deleted 34 audios who's duration are bigger than 54 seconds 204 | deleted 3190 audios who's phoneme/sec are bigger than 25 or smaller than 3 205 | dataset.__len__(): 366463 206 | 207 | """ 208 | # 345410 for LibriTTS 209 | print("dataset.__len__():", self.__len__()) 210 | 211 | def __get_item_names__(self) -> List[str]: 212 | return self.item_names 213 | 214 | def __len__(self) -> int: 215 | return len(self.semantic_phoneme) 216 | 217 | def __getitem__(self, idx: int) -> Dict: 218 | semantic_ids, phoneme_ids = self.semantic_phoneme[idx] 219 | item_name = self.item_names[idx] 220 | phoneme_ids_len = len(phoneme_ids) 221 | # semantic tokens target 222 | semantic_ids_len = len(semantic_ids) 223 | 224 | flag = 0 225 | path_bert = "%s/%s.pt" % (self.path3, item_name) 226 | if os.path.exists(path_bert) == True: 227 | bert_feature = torch.load(path_bert, map_location="cpu") 228 | else: 229 | flag = 1 230 | if flag == 1: 231 | # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32) 232 | bert_feature = None 233 | else: 234 | assert bert_feature.shape[-1] == len(phoneme_ids) 235 | return { 236 | "idx": idx, 237 | "phoneme_ids": phoneme_ids, 238 | "phoneme_ids_len": phoneme_ids_len, 239 | "semantic_ids": semantic_ids, 240 | "semantic_ids_len": semantic_ids_len, 241 | "bert_feature": bert_feature, 242 | } 243 | 244 | def get_sample_length(self, idx: int): 245 | semantic_ids = self.semantic_phoneme[idx][0] 246 | sec = 1.0 * len(semantic_ids) / self.hz 247 | return sec 248 | 249 | def collate(self, examples: List[Dict]) -> Dict: 250 | sample_index: List[int] = [] 251 | phoneme_ids: List[torch.Tensor] = [] 252 | phoneme_ids_lens: List[int] = [] 253 | semantic_ids: List[torch.Tensor] = [] 254 | semantic_ids_lens: List[int] = [] 255 | # return 256 | 257 | for item in examples: 258 | sample_index.append(item["idx"]) 259 | phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64)) 260 | semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64)) 261 | phoneme_ids_lens.append(item["phoneme_ids_len"]) 262 | semantic_ids_lens.append(item["semantic_ids_len"]) 263 | 264 | # pad 0 265 | phoneme_ids = batch_sequences(phoneme_ids) 266 | semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD) 267 | 268 | # # convert each batch to torch.tensor 269 | phoneme_ids = torch.tensor(phoneme_ids) 270 | semantic_ids = torch.tensor(semantic_ids) 271 | phoneme_ids_lens = torch.tensor(phoneme_ids_lens) 272 | semantic_ids_lens = torch.tensor(semantic_ids_lens) 273 | bert_padded = torch.FloatTensor(len(examples), 1024, max(phoneme_ids_lens)) 274 | bert_padded.zero_() 275 | 276 | for idx, item in enumerate(examples): 277 | bert = item["bert_feature"] 278 | if bert != None: 279 | bert_padded[idx, :, : bert.shape[-1]] = bert 280 | 281 | return { 282 | # List[int] 283 | "ids": sample_index, 284 | # torch.Tensor (B, max_phoneme_length) 285 | "phoneme_ids": phoneme_ids, 286 | # torch.Tensor (B) 287 | "phoneme_ids_len": phoneme_ids_lens, 288 | # torch.Tensor (B, max_semantic_ids_length) 289 | "semantic_ids": semantic_ids, 290 | # torch.Tensor (B) 291 | "semantic_ids_len": semantic_ids_lens, 292 | # torch.Tensor (B, 1024, max_phoneme_length) 293 | "bert_feature": bert_padded, 294 | } 295 | 296 | 297 | if __name__ == "__main__": 298 | root_dir = 
"/data/docker/liujing04/gpt-vits/prepare/dump_mix/" 299 | dataset = Text2SemanticDataset( 300 | phoneme_path=root_dir + "phoneme_train.npy", 301 | semantic_path=root_dir + "semantic_train.tsv", 302 | ) 303 | 304 | batch_size = 12 305 | dataloader = DataLoader( 306 | dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False 307 | ) 308 | for i, batch in enumerate(dataloader): 309 | if i % 1000 == 0: 310 | print(i) 311 | # if i == 0: 312 | # print('batch["ids"]:', batch["ids"]) 313 | # print('batch["phoneme_ids"]:', batch["phoneme_ids"], 314 | # batch["phoneme_ids"].shape) 315 | # print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"], 316 | # batch["phoneme_ids_len"].shape) 317 | # print('batch["semantic_ids"]:', batch["semantic_ids"], 318 | # batch["semantic_ids"].shape) 319 | # print('batch["semantic_ids_len"]:', batch["semantic_ids_len"], 320 | # batch["semantic_ids_len"].shape) 321 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-D-Lab/GPT_SoVITS_Inference/aab032273341da78c44aabf4fb1c132e65f6f29c/GPT_SoVITS/AR/models/__init__.py -------------------------------------------------------------------------------- /GPT_SoVITS/AR/models/t2s_lightning_module.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py 2 | import os, sys 3 | 4 | now_dir = os.getcwd() 5 | sys.path.append(now_dir) 6 | from typing import Dict 7 | 8 | import torch 9 | from pytorch_lightning import LightningModule 10 | from AR.models.t2s_model import Text2SemanticDecoder 11 | from AR.modules.lr_schedulers import WarmupCosineLRSchedule 12 | from AR.modules.optim import ScaledAdam 13 | 14 | class Text2SemanticLightningModule(LightningModule): 15 | def __init__(self, config, output_dir, is_train=True): 16 | super().__init__() 17 | self.config = config 18 | self.top_k = 3 19 | self.model = Text2SemanticDecoder(config=config, top_k=self.top_k) 20 | pretrained_s1 = config.get("pretrained_s1") 21 | if pretrained_s1 and is_train: 22 | # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"])) 23 | print( 24 | self.load_state_dict( 25 | torch.load(pretrained_s1, map_location="cpu")["weight"] 26 | ) 27 | ) 28 | if is_train: 29 | self.automatic_optimization = False 30 | self.save_hyperparameters() 31 | self.eval_dir = output_dir / "eval" 32 | self.eval_dir.mkdir(parents=True, exist_ok=True) 33 | 34 | def training_step(self, batch: Dict, batch_idx: int): 35 | opt = self.optimizers() 36 | scheduler = self.lr_schedulers() 37 | forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old 38 | loss, acc = forward( 39 | batch["phoneme_ids"], 40 | batch["phoneme_ids_len"], 41 | batch["semantic_ids"], 42 | batch["semantic_ids_len"], 43 | batch["bert_feature"], 44 | ) 45 | self.manual_backward(loss) 46 | if batch_idx > 0 and batch_idx % 4 == 0: 47 | opt.step() 48 | opt.zero_grad() 49 | scheduler.step() 50 | 51 | self.log( 52 | "total_loss", 53 | loss, 54 | on_step=True, 55 | on_epoch=True, 56 | prog_bar=True, 57 | sync_dist=True, 58 | ) 59 | self.log( 60 | "lr", 61 | scheduler.get_last_lr()[0], 62 | on_epoch=True, 63 | prog_bar=True, 64 | sync_dist=True, 65 | ) 66 | self.log( 67 | f"top_{self.top_k}_acc", 68 | acc, 69 | 
on_step=True, 70 | on_epoch=True, 71 | prog_bar=True, 72 | sync_dist=True, 73 | ) 74 | 75 | def validation_step(self, batch: Dict, batch_idx: int): 76 | return 77 | 78 | # # get loss 79 | # loss, acc = self.model.forward( 80 | # batch['phoneme_ids'], batch['phoneme_ids_len'], 81 | # batch['semantic_ids'], batch['semantic_ids_len'], 82 | # batch['bert_feature'] 83 | # ) 84 | # 85 | # self.log( 86 | # "val_total_loss", 87 | # loss, 88 | # on_step=True, 89 | # on_epoch=True, 90 | # prog_bar=True, 91 | # sync_dist=True) 92 | # self.log( 93 | # f"val_top_{self.top_k}_acc", 94 | # acc, 95 | # on_step=True, 96 | # on_epoch=True, 97 | # prog_bar=True, 98 | # sync_dist=True) 99 | # 100 | # # get infer output 101 | # semantic_len = batch['semantic_ids'].size(1) 102 | # prompt_len = min(int(semantic_len * 0.5), 150) 103 | # prompt = batch['semantic_ids'][:, :prompt_len] 104 | # pred_semantic = self.model.infer(batch['phoneme_ids'], 105 | # batch['phoneme_ids_len'], prompt, 106 | # batch['bert_feature'] 107 | # ) 108 | # save_name = f'semantic_toks_{batch_idx}.pt' 109 | # save_path = os.path.join(self.eval_dir, save_name) 110 | # torch.save(pred_semantic.detach().cpu(), save_path) 111 | 112 | def configure_optimizers(self): 113 | model_parameters = self.model.parameters() 114 | parameters_names = [] 115 | parameters_names.append( 116 | [name_param_pair[0] for name_param_pair in self.model.named_parameters()] 117 | ) 118 | lm_opt = ScaledAdam( 119 | model_parameters, 120 | lr=0.01, 121 | betas=(0.9, 0.95), 122 | clipping_scale=2.0, 123 | parameters_names=parameters_names, 124 | show_dominant_parameters=False, 125 | clipping_update_period=1000, 126 | ) 127 | 128 | return { 129 | "optimizer": lm_opt, 130 | "lr_scheduler": { 131 | "scheduler": WarmupCosineLRSchedule( 132 | lm_opt, 133 | init_lr=self.config["optimizer"]["lr_init"], 134 | peak_lr=self.config["optimizer"]["lr"], 135 | end_lr=self.config["optimizer"]["lr_end"], 136 | warmup_steps=self.config["optimizer"]["warmup_steps"], 137 | total_steps=self.config["optimizer"]["decay_steps"], 138 | ) 139 | }, 140 | } 141 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/models/t2s_lightning_module_onnx.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py 2 | import os, sys 3 | 4 | now_dir = os.getcwd() 5 | sys.path.append(now_dir) 6 | from typing import Dict 7 | 8 | import torch 9 | from pytorch_lightning import LightningModule 10 | from AR.models.t2s_model_onnx import Text2SemanticDecoder 11 | from AR.modules.lr_schedulers import WarmupCosineLRSchedule 12 | from AR.modules.optim import ScaledAdam 13 | 14 | 15 | class Text2SemanticLightningModule(LightningModule): 16 | def __init__(self, config, output_dir, is_train=True): 17 | super().__init__() 18 | self.config = config 19 | self.top_k = 3 20 | self.model = Text2SemanticDecoder(config=config, top_k=self.top_k) 21 | pretrained_s1 = config.get("pretrained_s1") 22 | if pretrained_s1 and is_train: 23 | # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"])) 24 | print( 25 | self.load_state_dict( 26 | torch.load(pretrained_s1, map_location="cpu")["weight"] 27 | ) 28 | ) 29 | if is_train: 30 | self.automatic_optimization = False 31 | self.save_hyperparameters() 32 | self.eval_dir = output_dir / "eval" 33 | self.eval_dir.mkdir(parents=True, exist_ok=True) 34 | 35 | def training_step(self, 
batch: Dict, batch_idx: int): 36 | opt = self.optimizers() 37 | scheduler = self.lr_schedulers() 38 | loss, acc = self.model.forward( 39 | batch["phoneme_ids"], 40 | batch["phoneme_ids_len"], 41 | batch["semantic_ids"], 42 | batch["semantic_ids_len"], 43 | batch["bert_feature"], 44 | ) 45 | self.manual_backward(loss) 46 | if batch_idx > 0 and batch_idx % 4 == 0: 47 | opt.step() 48 | opt.zero_grad() 49 | scheduler.step() 50 | 51 | self.log( 52 | "total_loss", 53 | loss, 54 | on_step=True, 55 | on_epoch=True, 56 | prog_bar=True, 57 | sync_dist=True, 58 | ) 59 | self.log( 60 | "lr", 61 | scheduler.get_last_lr()[0], 62 | on_epoch=True, 63 | prog_bar=True, 64 | sync_dist=True, 65 | ) 66 | self.log( 67 | f"top_{self.top_k}_acc", 68 | acc, 69 | on_step=True, 70 | on_epoch=True, 71 | prog_bar=True, 72 | sync_dist=True, 73 | ) 74 | 75 | def validation_step(self, batch: Dict, batch_idx: int): 76 | return 77 | 78 | def configure_optimizers(self): 79 | model_parameters = self.model.parameters() 80 | parameters_names = [] 81 | parameters_names.append( 82 | [name_param_pair[0] for name_param_pair in self.model.named_parameters()] 83 | ) 84 | lm_opt = ScaledAdam( 85 | model_parameters, 86 | lr=0.01, 87 | betas=(0.9, 0.95), 88 | clipping_scale=2.0, 89 | parameters_names=parameters_names, 90 | show_dominant_parameters=False, 91 | clipping_update_period=1000, 92 | ) 93 | 94 | return { 95 | "optimizer": lm_opt, 96 | "lr_scheduler": { 97 | "scheduler": WarmupCosineLRSchedule( 98 | lm_opt, 99 | init_lr=self.config["optimizer"]["lr_init"], 100 | peak_lr=self.config["optimizer"]["lr"], 101 | end_lr=self.config["optimizer"]["lr_end"], 102 | warmup_steps=self.config["optimizer"]["warmup_steps"], 103 | total_steps=self.config["optimizer"]["decay_steps"], 104 | ) 105 | }, 106 | } 107 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/models/utils.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/utils.py\ 2 | import torch 3 | import torch.nn.functional as F 4 | from typing import Tuple 5 | 6 | def sequence_mask(length, max_length=None): 7 | if max_length is None: 8 | max_length = length.max() 9 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 10 | return x.unsqueeze(0) < length.unsqueeze(1) 11 | 12 | 13 | def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: 14 | """ 15 | Args: 16 | lengths: 17 | A 1-D tensor containing sentence lengths. 18 | max_len: 19 | The length of masks. 20 | Returns: 21 | Return a 2-D bool tensor, where masked positions 22 | are filled with `True` and non-masked positions are 23 | filled with `False`. 
24 | 25 | #>>> lengths = torch.tensor([1, 3, 2, 5]) 26 | #>>> make_pad_mask(lengths) 27 | tensor([[False, True, True, True, True], 28 | [False, False, False, True, True], 29 | [False, False, True, True, True], 30 | [False, False, False, False, False]]) 31 | """ 32 | assert lengths.ndim == 1, lengths.ndim 33 | max_len = max(max_len, lengths.max()) 34 | n = lengths.size(0) 35 | seq_range = torch.arange(0, max_len, device=lengths.device) 36 | expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) 37 | 38 | return expaned_lengths >= lengths.unsqueeze(-1) 39 | 40 | 41 | # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py 42 | def top_k_top_p_filtering( 43 | logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 44 | ): 45 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 46 | Args: 47 | logits: logits distribution shape (batch size, vocabulary size) 48 | if top_k > 0: keep only top k tokens with highest probability (top-k filtering). 49 | if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 50 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 51 | Make sure we keep at least min_tokens_to_keep per batch example in the output 52 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 53 | """ 54 | if top_k > 0: 55 | top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check 56 | # Remove all tokens with a probability less than the last token of the top-k 57 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 58 | logits[indices_to_remove] = filter_value 59 | 60 | if top_p < 1.0: 61 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 62 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 63 | 64 | # Remove tokens with cumulative probability above the threshold (token with 0 are kept) 65 | sorted_indices_to_remove = cumulative_probs > top_p 66 | if min_tokens_to_keep > 1: 67 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) 68 | sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 69 | # Shift the indices to the right to keep also the first token above the threshold 70 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 71 | sorted_indices_to_remove[..., 0] = 0 72 | 73 | # scatter sorted tensors to original indexing 74 | indices_to_remove = sorted_indices_to_remove.scatter( 75 | 1, sorted_indices, sorted_indices_to_remove 76 | ) 77 | logits[indices_to_remove] = filter_value 78 | return logits 79 | 80 | 81 | def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): 82 | # temperature: (`optional`) float 83 | # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. 84 | # top_k: (`optional`) int 85 | # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. 86 | # top_p: (`optional`) float 87 | # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. 
88 | 89 | # Temperature (higher temperature => more likely to sample low probability tokens) 90 | if temperature != 1.0: 91 | logits = logits / temperature 92 | # Top-p/top-k filtering 93 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) 94 | # Sample 95 | token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) 96 | return token 97 | 98 | 99 | from typing import Optional, Tuple 100 | 101 | 102 | def multinomial_sample_one_no_sync( 103 | probs_sort, 104 | ): # Does multinomial sampling without a cuda synchronization 105 | q = torch.empty_like(probs_sort).exponential_(1) 106 | return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) 107 | 108 | 109 | def logits_to_probs( 110 | logits, 111 | previous_tokens: Optional[torch.Tensor] = None, 112 | temperature: float = 1.0, 113 | top_k: Optional[int] = None, 114 | top_p: Optional[int] = None, 115 | repetition_penalty: float = 1.0, 116 | ): 117 | if previous_tokens is not None: 118 | previous_tokens = previous_tokens.squeeze() 119 | # print(logits.shape,previous_tokens.shape) 120 | # pdb.set_trace() 121 | if previous_tokens is not None and repetition_penalty != 1.0: 122 | previous_tokens = previous_tokens.long() 123 | score = torch.gather(logits, dim=0, index=previous_tokens) 124 | score = torch.where( 125 | score < 0, score * repetition_penalty, score / repetition_penalty 126 | ) 127 | logits.scatter_(dim=0, index=previous_tokens, src=score) 128 | 129 | if top_p is not None and top_p < 1.0: 130 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 131 | cum_probs = torch.cumsum( 132 | torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1 133 | ) 134 | sorted_indices_to_remove = cum_probs > top_p 135 | sorted_indices_to_remove[0] = False # keep at least one option 136 | indices_to_remove = sorted_indices_to_remove.scatter( 137 | dim=0, index=sorted_indices, src=sorted_indices_to_remove 138 | ) 139 | logits = logits.masked_fill(indices_to_remove, -float("Inf")) 140 | 141 | logits = logits / max(temperature, 1e-5) 142 | 143 | if top_k is not None: 144 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 145 | pivot = v.select(-1, -1).unsqueeze(-1) 146 | logits = torch.where(logits < pivot, -float("Inf"), logits) 147 | 148 | probs = torch.nn.functional.softmax(logits, dim=-1) 149 | return probs 150 | 151 | 152 | def sample( 153 | logits, 154 | previous_tokens: Optional[torch.Tensor] = None, 155 | **sampling_kwargs, 156 | ) -> Tuple[torch.Tensor, torch.Tensor]: 157 | probs = logits_to_probs( 158 | logits=logits, previous_tokens=previous_tokens, **sampling_kwargs 159 | ) 160 | idx_next = multinomial_sample_one_no_sync(probs) 161 | return idx_next, probs 162 | 163 | def dpo_loss(policy_chosen_logps: torch.FloatTensor, 164 | policy_rejected_logps: torch.FloatTensor, 165 | reference_chosen_logps: torch.FloatTensor, 166 | reference_rejected_logps: torch.FloatTensor, 167 | beta: float, 168 | reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 169 | pi_logratios = policy_chosen_logps - policy_rejected_logps 170 | ref_logratios = reference_chosen_logps - reference_rejected_logps 171 | 172 | if reference_free: 173 | ref_logratios = 0 174 | 175 | logits = pi_logratios - ref_logratios 176 | 177 | losses = -F.logsigmoid(beta * logits) 178 | chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach() 179 | rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach() 180 | 181 | return losses.mean(), 
chosen_rewards, rejected_rewards 182 | 183 | def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 184 | 185 | # dummy token; we'll ignore the losses on these tokens later 186 | 187 | per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2) 188 | per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2) 189 | 190 | return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1) 191 | 192 | def make_reject_y(y_o, y_lens): 193 | def repeat_P(y): 194 | range_idx, _ = torch.randint(0, len(y), size=(2,)).sort() 195 | pre = y[:range_idx[0]] 196 | shf = y[range_idx[1]:] 197 | range_text = y[range_idx[0]:range_idx[1]] 198 | new_y = torch.cat([pre, range_text, range_text, shf]) 199 | return new_y 200 | def lost_P(y): 201 | range_idx, _ = torch.randint(0, len(y), size=(2,)).sort() 202 | pre = y[:range_idx[0]] 203 | shf = y[range_idx[1]:] 204 | range_text = y[range_idx[0]:range_idx[1]] 205 | new_y = torch.cat([pre, shf]) 206 | return new_y 207 | bs = len(y_lens) 208 | reject_y = [] 209 | reject_y_lens = [] 210 | for b in range(bs): 211 | process_item_idx = torch.randint(0, 1, size=(1, ))[0] 212 | if process_item_idx == 0: 213 | new_y = repeat_P(y_o[b]) 214 | reject_y.append(new_y) 215 | reject_y_lens.append(len(new_y)) 216 | elif process_item_idx==1: 217 | new_y = lost_P(y_o[b]) 218 | reject_y.append(new_y) 219 | reject_y_lens.append(len(new_y)) 220 | max_length = max(reject_y_lens) 221 | for b in range(bs): 222 | pad_length = max_length - reject_y_lens[b] 223 | reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0) 224 | 225 | reject_y = torch.stack(reject_y, dim = 0) 226 | reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device) 227 | 228 | return reject_y, reject_y_lens 229 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-D-Lab/GPT_SoVITS_Inference/aab032273341da78c44aabf4fb1c132e65f6f29c/GPT_SoVITS/AR/modules/__init__.py -------------------------------------------------------------------------------- /GPT_SoVITS/AR/modules/activation_onnx.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py 2 | from typing import Optional 3 | from typing import Tuple 4 | import torch 5 | from torch import Tensor 6 | from torch.nn import Linear 7 | from torch.nn import Module 8 | from torch.nn.init import constant_ 9 | from torch.nn.init import xavier_normal_ 10 | from torch.nn.init import xavier_uniform_ 11 | from torch.nn.modules.linear import NonDynamicallyQuantizableLinear 12 | from torch.nn.parameter import Parameter 13 | 14 | from torch.nn import functional as F 15 | from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched 16 | 17 | 18 | class MultiheadAttention(Module): 19 | __constants__ = ["batch_first"] 20 | bias_k: Optional[torch.Tensor] 21 | bias_v: Optional[torch.Tensor] 22 | 23 | def __init__( 24 | self, 25 | embed_dim, 26 | num_heads, 27 | dropout=0.0, 28 | bias=True, 29 | add_bias_kv=False, 30 
| add_zero_attn=False, 31 | kdim=None, 32 | vdim=None, 33 | batch_first=False, 34 | linear1_cls=Linear, 35 | linear2_cls=Linear, 36 | device=None, 37 | dtype=None, 38 | ) -> None: 39 | factory_kwargs = {"device": device, "dtype": dtype} 40 | super(MultiheadAttention, self).__init__() 41 | self.embed_dim = embed_dim 42 | self.kdim = kdim if kdim is not None else embed_dim 43 | self.vdim = vdim if vdim is not None else embed_dim 44 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim 45 | 46 | self.num_heads = num_heads 47 | self.dropout = dropout 48 | self.batch_first = batch_first 49 | self.head_dim = embed_dim // num_heads 50 | assert ( 51 | self.head_dim * num_heads == self.embed_dim 52 | ), "embed_dim must be divisible by num_heads" 53 | 54 | if add_bias_kv: 55 | self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) 56 | self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) 57 | else: 58 | self.bias_k = self.bias_v = None 59 | 60 | if linear1_cls == Linear: 61 | if not self._qkv_same_embed_dim: 62 | self.q_proj_weight = Parameter( 63 | torch.empty((embed_dim, embed_dim), **factory_kwargs) 64 | ) 65 | self.k_proj_weight = Parameter( 66 | torch.empty((embed_dim, self.kdim), **factory_kwargs) 67 | ) 68 | self.v_proj_weight = Parameter( 69 | torch.empty((embed_dim, self.vdim), **factory_kwargs) 70 | ) 71 | self.register_parameter("in_proj_weight", None) 72 | else: 73 | self.in_proj_weight = Parameter( 74 | torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) 75 | ) 76 | self.register_parameter("q_proj_weight", None) 77 | self.register_parameter("k_proj_weight", None) 78 | self.register_parameter("v_proj_weight", None) 79 | 80 | if bias: 81 | self.in_proj_bias = Parameter( 82 | torch.empty(3 * embed_dim, **factory_kwargs) 83 | ) 84 | else: 85 | self.register_parameter("in_proj_bias", None) 86 | self.out_proj = NonDynamicallyQuantizableLinear( 87 | embed_dim, embed_dim, bias=bias, **factory_kwargs 88 | ) 89 | 90 | self._reset_parameters() 91 | else: 92 | if not self._qkv_same_embed_dim: 93 | raise NotImplementedError 94 | else: 95 | self.in_proj_linear = linear1_cls( 96 | embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs 97 | ) 98 | self.in_proj_weight = self.in_proj_linear.weight 99 | 100 | self.register_parameter("q_proj_weight", None) 101 | self.register_parameter("k_proj_weight", None) 102 | self.register_parameter("v_proj_weight", None) 103 | 104 | if bias: 105 | self.in_proj_bias = self.in_proj_linear.bias 106 | else: 107 | self.register_parameter("in_proj_bias", None) 108 | 109 | self.out_proj = linear2_cls( 110 | embed_dim, embed_dim, bias=bias, **factory_kwargs 111 | ) 112 | 113 | if self.bias_k is not None: 114 | xavier_normal_(self.bias_k) 115 | if self.bias_v is not None: 116 | xavier_normal_(self.bias_v) 117 | 118 | self.add_zero_attn = add_zero_attn 119 | 120 | def _reset_parameters(self): 121 | if self._qkv_same_embed_dim: 122 | xavier_uniform_(self.in_proj_weight) 123 | else: 124 | xavier_uniform_(self.q_proj_weight) 125 | xavier_uniform_(self.k_proj_weight) 126 | xavier_uniform_(self.v_proj_weight) 127 | 128 | if self.in_proj_bias is not None: 129 | constant_(self.in_proj_bias, 0.0) 130 | constant_(self.out_proj.bias, 0.0) 131 | 132 | if self.bias_k is not None: 133 | xavier_normal_(self.bias_k) 134 | if self.bias_v is not None: 135 | xavier_normal_(self.bias_v) 136 | 137 | def __setstate__(self, state): 138 | # Support loading old MultiheadAttention checkpoints generated by v1.1.0 139 | if 
"_qkv_same_embed_dim" not in state: 140 | state["_qkv_same_embed_dim"] = True 141 | 142 | super(MultiheadAttention, self).__setstate__(state) 143 | 144 | def forward( 145 | self, 146 | query: Tensor, 147 | key: Tensor, 148 | value: Tensor, 149 | key_padding_mask: Optional[Tensor] = None, 150 | need_weights: bool = True, 151 | attn_mask: Optional[Tensor] = None, 152 | average_attn_weights: bool = True, 153 | cache=None, 154 | ) -> Tuple[Tensor, Optional[Tensor]]: 155 | any_nested = query.is_nested or key.is_nested or value.is_nested 156 | query = key = value = query.transpose(1, 0) 157 | attn_output = multi_head_attention_forward_patched( 158 | query, 159 | key, 160 | value, 161 | self.embed_dim, 162 | self.num_heads, 163 | self.in_proj_weight, 164 | self.in_proj_bias, 165 | self.bias_k, 166 | self.bias_v, 167 | self.add_zero_attn, 168 | self.dropout, 169 | self.out_proj.weight, 170 | self.out_proj.bias, 171 | training=self.training, 172 | key_padding_mask=key_padding_mask, 173 | need_weights=need_weights, 174 | attn_mask=attn_mask, 175 | average_attn_weights=average_attn_weights, 176 | cache=cache, 177 | ) 178 | return attn_output.transpose(1, 0) 179 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/modules/embedding.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py 2 | import math 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class TokenEmbedding(nn.Module): 9 | def __init__( 10 | self, 11 | embedding_dim: int, 12 | vocab_size: int, 13 | dropout: float = 0.0, 14 | ): 15 | super().__init__() 16 | 17 | self.vocab_size = vocab_size 18 | self.embedding_dim = embedding_dim 19 | 20 | self.dropout = torch.nn.Dropout(p=dropout) 21 | self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim) 22 | 23 | @property 24 | def weight(self) -> torch.Tensor: 25 | return self.word_embeddings.weight 26 | 27 | def embedding(self, index: int) -> torch.Tensor: 28 | return self.word_embeddings.weight[index : index + 1] 29 | 30 | def forward(self, x: torch.Tensor): 31 | x = self.word_embeddings(x) 32 | x = self.dropout(x) 33 | return x 34 | 35 | 36 | class SinePositionalEmbedding(nn.Module): 37 | def __init__( 38 | self, 39 | embedding_dim: int, 40 | dropout: float = 0.0, 41 | scale: bool = False, 42 | alpha: bool = False, 43 | ): 44 | super().__init__() 45 | self.embedding_dim = embedding_dim 46 | self.x_scale = math.sqrt(embedding_dim) if scale else 1.0 47 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) 48 | self.dropout = torch.nn.Dropout(p=dropout) 49 | 50 | self.reverse = False 51 | self.pe = None 52 | self.extend_pe(torch.tensor(0.0).expand(1, 4000)) 53 | 54 | def extend_pe(self, x): 55 | """Reset the positional encodings.""" 56 | if self.pe is not None: 57 | if self.pe.size(1) >= x.size(1): 58 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 59 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 60 | return 61 | pe = torch.zeros(x.size(1), self.embedding_dim) 62 | if self.reverse: 63 | position = torch.arange( 64 | x.size(1) - 1, -1, -1.0, dtype=torch.float32 65 | ).unsqueeze(1) 66 | else: 67 | position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) 68 | div_term = torch.exp( 69 | torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) 70 | * -(math.log(10000.0) / self.embedding_dim) 71 | ) 72 | pe[:, 0::2] = torch.sin(position * div_term) 73 | pe[:, 1::2] = 
torch.cos(position * div_term) 74 | pe = pe.unsqueeze(0) 75 | self.pe = pe.to(device=x.device, dtype=x.dtype).detach() 76 | 77 | def forward(self, x: torch.Tensor) -> torch.Tensor: 78 | self.extend_pe(x) 79 | output = x.unsqueeze(-1) if x.ndim == 2 else x 80 | output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] 81 | return self.dropout(output) 82 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/modules/embedding_onnx.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py 2 | import math 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class TokenEmbedding(nn.Module): 9 | def __init__( 10 | self, 11 | embedding_dim: int, 12 | vocab_size: int, 13 | dropout: float = 0.0, 14 | ): 15 | super().__init__() 16 | 17 | self.vocab_size = vocab_size 18 | self.embedding_dim = embedding_dim 19 | 20 | self.dropout = torch.nn.Dropout(p=dropout) 21 | self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim) 22 | 23 | @property 24 | def weight(self) -> torch.Tensor: 25 | return self.word_embeddings.weight 26 | 27 | def embedding(self, index: int) -> torch.Tensor: 28 | return self.word_embeddings.weight[index : index + 1] 29 | 30 | def forward(self, x: torch.Tensor): 31 | x = self.word_embeddings(x) 32 | x = self.dropout(x) 33 | return x 34 | 35 | 36 | class SinePositionalEmbedding(nn.Module): 37 | def __init__( 38 | self, 39 | embedding_dim: int, 40 | dropout: float = 0.0, 41 | scale: bool = False, 42 | alpha: bool = False, 43 | ): 44 | super().__init__() 45 | self.embedding_dim = embedding_dim 46 | self.x_scale = math.sqrt(embedding_dim) if scale else 1.0 47 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) 48 | self.dropout = torch.nn.Dropout(p=dropout) 49 | self.reverse = False 50 | self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim)) 51 | 52 | def extend_pe(self, x): 53 | position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1) 54 | scpe = (position * self.div_term).unsqueeze(0) 55 | pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0) 56 | pe = pe.contiguous().view(1, -1, self.embedding_dim) 57 | return pe 58 | 59 | def forward(self, x: torch.Tensor) -> torch.Tensor: 60 | pe = self.extend_pe(x) 61 | output = x.unsqueeze(-1) if x.ndim == 2 else x 62 | output = output * self.x_scale + self.alpha * pe 63 | return self.dropout(output) 64 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/modules/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/lr_schedulers.py 2 | import math 3 | 4 | import torch 5 | from matplotlib import pyplot as plt 6 | from torch import nn 7 | from torch.optim import Adam 8 | 9 | 10 | class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler): 11 | """ 12 | Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers. 
13 | """ 14 | 15 | def __init__( 16 | self, 17 | optimizer, 18 | init_lr, 19 | peak_lr, 20 | end_lr, 21 | warmup_steps=10000, 22 | total_steps=400000, 23 | current_step=0, 24 | ): 25 | self.init_lr = init_lr 26 | self.peak_lr = peak_lr 27 | self.end_lr = end_lr 28 | self.optimizer = optimizer 29 | self._warmup_rate = (peak_lr - init_lr) / warmup_steps 30 | self._decay_rate = (end_lr - peak_lr) / (total_steps - warmup_steps) 31 | self._current_step = current_step 32 | self.lr = init_lr 33 | self.warmup_steps = warmup_steps 34 | self.total_steps = total_steps 35 | self._last_lr = [self.lr] 36 | 37 | def set_lr(self, lr): 38 | self._last_lr = [g["lr"] for g in self.optimizer.param_groups] 39 | for g in self.optimizer.param_groups: 40 | # g['lr'] = lr 41 | g["lr"] = self.end_lr ###锁定用线性 42 | 43 | def step(self): 44 | if self._current_step < self.warmup_steps: 45 | lr = self.init_lr + self._warmup_rate * self._current_step 46 | 47 | elif self._current_step > self.total_steps: 48 | lr = self.end_lr 49 | 50 | else: 51 | decay_ratio = (self._current_step - self.warmup_steps) / ( 52 | self.total_steps - self.warmup_steps 53 | ) 54 | if decay_ratio < 0.0 or decay_ratio > 1.0: 55 | raise RuntimeError( 56 | "Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings." 57 | ) 58 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) 59 | lr = self.end_lr + coeff * (self.peak_lr - self.end_lr) 60 | 61 | self.lr = lr = self.end_lr = 0.002 ###锁定用线性###不听话,直接锁定! 62 | self.set_lr(lr) 63 | self.lr = lr 64 | self._current_step += 1 65 | return self.lr 66 | 67 | 68 | if __name__ == "__main__": 69 | m = nn.Linear(10, 10) 70 | opt = Adam(m.parameters(), lr=1e-4) 71 | s = WarmupCosineLRSchedule( 72 | opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0 73 | ) 74 | lrs = [] 75 | for i in range(25000): 76 | s.step() 77 | lrs.append(s.lr) 78 | print(s.lr) 79 | 80 | plt.plot(lrs) 81 | plt.plot(range(0, 25000), lrs) 82 | plt.show() 83 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/modules/patched_mha_with_cache_onnx.py: -------------------------------------------------------------------------------- 1 | from torch.nn.functional import * 2 | from torch.nn.functional import ( 3 | _mha_shape_check, 4 | _canonical_mask, 5 | _none_or_dtype, 6 | _in_projection_packed, 7 | ) 8 | 9 | def multi_head_attention_forward_patched( 10 | query, 11 | key, 12 | value, 13 | embed_dim_to_check: int, 14 | num_heads: int, 15 | in_proj_weight, 16 | in_proj_bias: Optional[Tensor], 17 | bias_k: Optional[Tensor], 18 | bias_v: Optional[Tensor], 19 | add_zero_attn: bool, 20 | dropout_p: float, 21 | out_proj_weight: Tensor, 22 | out_proj_bias: Optional[Tensor], 23 | training: bool = True, 24 | key_padding_mask: Optional[Tensor] = None, 25 | need_weights: bool = True, 26 | attn_mask: Optional[Tensor] = None, 27 | use_separate_proj_weight: bool = False, 28 | q_proj_weight: Optional[Tensor] = None, 29 | k_proj_weight: Optional[Tensor] = None, 30 | v_proj_weight: Optional[Tensor] = None, 31 | static_k: Optional[Tensor] = None, 32 | static_v: Optional[Tensor] = None, 33 | average_attn_weights: bool = True, 34 | is_causal: bool = False, 35 | cache=None, 36 | ) -> Tuple[Tensor, Optional[Tensor]]: 37 | 38 | # set up shape vars 39 | _, _, embed_dim = query.shape 40 | attn_mask = _canonical_mask( 41 | mask=attn_mask, 42 | mask_name="attn_mask", 43 | other_type=None, 44 | other_name="", 45 | target_type=query.dtype, 46 | check_other=False, 47 | ) 48 | head_dim = embed_dim // 
num_heads 49 | 50 | proj_qkv = linear(query, in_proj_weight, in_proj_bias) 51 | proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() 52 | q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2] 53 | 54 | if cache["first_infer"] == 1: 55 | cache["k"][cache["stage"]] = k 56 | cache["v"][cache["stage"]] = v 57 | else: 58 | cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]][:-1], k], 0) 59 | cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]][:-1], v], 0) 60 | k = cache["k"][cache["stage"]] 61 | v = cache["v"][cache["stage"]] 62 | cache["stage"] = (cache["stage"] + 1) % cache["all_stage"] 63 | 64 | attn_mask = _canonical_mask( 65 | mask=attn_mask, 66 | mask_name="attn_mask", 67 | other_type=None, 68 | other_name="", 69 | target_type=q.dtype, 70 | check_other=False, 71 | ) 72 | attn_mask = attn_mask.unsqueeze(0) 73 | 74 | q = q.view(-1, num_heads, head_dim).transpose(0, 1) 75 | k = k.view(-1, num_heads, head_dim).transpose(0, 1) 76 | v = v.view(-1, num_heads, head_dim).transpose(0, 1) 77 | 78 | dropout_p = 0.0 79 | attn_mask = attn_mask.unsqueeze(0) 80 | q = q.view(num_heads, -1, head_dim).unsqueeze(0) 81 | k = k.view(num_heads, -1, head_dim).unsqueeze(0) 82 | v = v.view(num_heads, -1, head_dim).unsqueeze(0) 83 | attn_output = scaled_dot_product_attention( 84 | q, k, v, attn_mask, dropout_p, is_causal 85 | ) 86 | attn_output = ( 87 | attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim) 88 | ) 89 | attn_output = linear(attn_output, out_proj_weight, out_proj_bias) 90 | attn_output = attn_output.view(-1, 1, attn_output.size(1)) 91 | 92 | return attn_output 93 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/modules/transformer_onnx.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py 2 | import copy 3 | import numbers 4 | from functools import partial 5 | from typing import Any 6 | from typing import Callable 7 | from typing import List 8 | from typing import Optional 9 | from typing import Tuple 10 | from typing import Union 11 | 12 | import torch 13 | from AR.modules.activation_onnx import MultiheadAttention 14 | from AR.modules.scaling import BalancedDoubleSwish 15 | from torch import nn 16 | from torch import Tensor 17 | from torch.nn import functional as F 18 | 19 | _shape_t = Union[int, List[int], torch.Size] 20 | 21 | 22 | class LayerNorm(nn.Module): 23 | __constants__ = ["normalized_shape", "eps", "elementwise_affine"] 24 | normalized_shape: Tuple[int, ...] 
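    # Note (added comment): unlike torch.nn.LayerNorm, forward() below also accepts an (input, embedding)
    # tuple and passes the embedding through unchanged, so this class can be used interchangeably with
    # AdaptiveLayerNorm inside the encoder layers.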
25 | eps: float 26 | elementwise_affine: bool 27 | 28 | def __init__( 29 | self, 30 | normalized_shape: _shape_t, 31 | eps: float = 1e-5, 32 | elementwise_affine: bool = True, 33 | device=None, 34 | dtype=None, 35 | ) -> None: 36 | factory_kwargs = {"device": device, "dtype": dtype} 37 | super(LayerNorm, self).__init__() 38 | if isinstance(normalized_shape, numbers.Integral): 39 | # mypy error: incompatible types in assignment 40 | normalized_shape = (normalized_shape,) # type: ignore[assignment] 41 | self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] 42 | self.eps = eps 43 | self.elementwise_affine = elementwise_affine 44 | if self.elementwise_affine: 45 | self.weight = nn.Parameter( 46 | torch.empty(self.normalized_shape, **factory_kwargs) 47 | ) 48 | self.bias = nn.Parameter( 49 | torch.empty(self.normalized_shape, **factory_kwargs) 50 | ) 51 | else: 52 | self.register_parameter("weight", None) 53 | self.register_parameter("bias", None) 54 | 55 | self.reset_parameters() 56 | 57 | def reset_parameters(self) -> None: 58 | if self.elementwise_affine: 59 | nn.init.ones_(self.weight) 60 | nn.init.zeros_(self.bias) 61 | 62 | def forward(self, input: Tensor, embedding: Any = None) -> Tensor: 63 | if isinstance(input, tuple): 64 | input, embedding = input 65 | return ( 66 | F.layer_norm( 67 | input, 68 | self.normalized_shape, 69 | self.weight, 70 | self.bias, 71 | self.eps, 72 | ), 73 | embedding, 74 | ) 75 | 76 | assert embedding is None 77 | return F.layer_norm( 78 | input, self.normalized_shape, self.weight, self.bias, self.eps 79 | ) 80 | 81 | def extra_repr(self) -> str: 82 | return ( 83 | "{normalized_shape}, eps={eps}, " 84 | "elementwise_affine={elementwise_affine}".format(**self.__dict__) 85 | ) 86 | 87 | 88 | class IdentityNorm(nn.Module): 89 | def __init__( 90 | self, 91 | d_model: int, 92 | eps: float = 1e-5, 93 | device=None, 94 | dtype=None, 95 | ) -> None: 96 | super(IdentityNorm, self).__init__() 97 | 98 | def forward(self, input: Tensor, embedding: Any = None) -> Tensor: 99 | if isinstance(input, tuple): 100 | return input 101 | 102 | assert embedding is None 103 | return input 104 | 105 | 106 | class TransformerEncoder(nn.Module): 107 | r"""TransformerEncoder is a stack of N encoder layers. Users can build the 108 | BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. 109 | 110 | Args: 111 | encoder_layer: an instance of the TransformerEncoderLayer() class (required). 112 | num_layers: the number of sub-encoder-layers in the encoder (required). 113 | norm: the layer normalization component (optional). 114 | enable_nested_tensor: if True, input will automatically convert to nested tensor 115 | (and convert back on output). This will improve the overall performance of 116 | TransformerEncoder when padding rate is high. Default: ``True`` (enabled). 
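    Note: this ONNX-export variant does not actually take ``enable_nested_tensor`` in its constructor; it simply stacks ``num_layers`` deep copies of ``encoder_layer`` and threads an attention ``cache`` through each layer's forward call.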
117 | 118 | Examples:: 119 | >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) 120 | >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) 121 | >>> src = torch.rand(10, 32, 512) 122 | >>> out = transformer_encoder(src) 123 | """ 124 | __constants__ = ["norm"] 125 | 126 | def __init__(self, encoder_layer, num_layers, norm=None): 127 | super(TransformerEncoder, self).__init__() 128 | self.layers = _get_clones(encoder_layer, num_layers) 129 | self.num_layers = num_layers 130 | self.norm = norm 131 | 132 | def forward( 133 | self, 134 | src: Tensor, 135 | mask: Optional[Tensor] = None, 136 | src_key_padding_mask: Optional[Tensor] = None, 137 | return_layer_states: bool = False, 138 | cache=None, 139 | ) -> Tensor: 140 | output = src 141 | for mod in self.layers: 142 | output = mod( 143 | output, 144 | src_mask=mask, 145 | src_key_padding_mask=src_key_padding_mask, 146 | cache=cache, 147 | ) 148 | 149 | if self.norm is not None: 150 | output = self.norm(output) 151 | 152 | return output 153 | 154 | 155 | class TransformerEncoderLayer(nn.Module): 156 | __constants__ = ["batch_first", "norm_first"] 157 | def __init__( 158 | self, 159 | d_model: int, 160 | nhead: int, 161 | dim_feedforward: int = 2048, 162 | dropout: float = 0.1, 163 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 164 | batch_first: bool = False, 165 | norm_first: bool = False, 166 | device=None, 167 | dtype=None, 168 | linear1_self_attention_cls: nn.Module = nn.Linear, 169 | linear2_self_attention_cls: nn.Module = nn.Linear, 170 | linear1_feedforward_cls: nn.Module = nn.Linear, 171 | linear2_feedforward_cls: nn.Module = nn.Linear, 172 | layer_norm_cls: nn.Module = LayerNorm, 173 | layer_norm_eps: float = 1e-5, 174 | adaptive_layer_norm=False, 175 | ) -> None: 176 | factory_kwargs = {"device": device, "dtype": dtype} 177 | super(TransformerEncoderLayer, self).__init__() 178 | self.self_attn = MultiheadAttention( 179 | d_model, # 512 16 180 | nhead, 181 | dropout=dropout, 182 | batch_first=batch_first, 183 | linear1_cls=linear1_self_attention_cls, 184 | linear2_cls=linear2_self_attention_cls, 185 | **factory_kwargs, 186 | ) 187 | self.linear1 = linear1_feedforward_cls( 188 | d_model, dim_feedforward, **factory_kwargs 189 | ) 190 | self.dropout = nn.Dropout(dropout) 191 | self.linear2 = linear2_feedforward_cls( 192 | dim_feedforward, d_model, **factory_kwargs 193 | ) 194 | self.norm_first = norm_first 195 | self.dropout1 = nn.Dropout(dropout) 196 | self.dropout2 = nn.Dropout(dropout) 197 | if isinstance(activation, str): 198 | activation = _get_activation_fn(activation) 199 | elif isinstance(activation, partial): 200 | activation = activation(d_model) 201 | elif activation == BalancedDoubleSwish: 202 | activation = BalancedDoubleSwish(d_model) 203 | self.activation = activation 204 | 205 | norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) 206 | if layer_norm_cls == IdentityNorm: 207 | norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs) 208 | else: 209 | norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) 210 | 211 | if adaptive_layer_norm: 212 | self.norm1 = AdaptiveLayerNorm(d_model, norm1) 213 | self.norm2 = AdaptiveLayerNorm(d_model, norm2) 214 | else: 215 | self.norm1 = norm1 216 | self.norm2 = norm2 217 | 218 | def __setstate__(self, state): 219 | super(TransformerEncoderLayer, self).__setstate__(state) 220 | if not hasattr(self, "activation"): 221 | self.activation = F.relu 222 | 223 | def forward( 224 | self, 225 | 
src: Tensor, 226 | src_mask: Optional[Tensor] = None, 227 | src_key_padding_mask: Optional[Tensor] = None, 228 | cache=None, 229 | ) -> Tensor: 230 | x = src 231 | stage_embedding = None 232 | x = self.norm1( 233 | x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache), 234 | stage_embedding, 235 | ) 236 | x = self.norm2(x + self._ff_block(x), stage_embedding) 237 | 238 | return x 239 | 240 | def _sa_block( 241 | self, 242 | x: Tensor, 243 | attn_mask: Optional[Tensor], 244 | key_padding_mask: Optional[Tensor], 245 | cache=None, 246 | ) -> Tensor: 247 | x = self.self_attn( 248 | x, 249 | x, 250 | x, 251 | attn_mask=attn_mask, 252 | key_padding_mask=key_padding_mask, 253 | need_weights=False, 254 | cache=cache, 255 | ) 256 | return self.dropout1(x) 257 | 258 | def _ff_block(self, x: Tensor) -> Tensor: 259 | x = self.linear2(self.dropout(self.activation(self.linear1(x)))) 260 | return self.dropout2(x) 261 | 262 | 263 | class AdaptiveLayerNorm(nn.Module): 264 | r"""Adaptive Layer Normalization""" 265 | 266 | def __init__(self, d_model, norm) -> None: 267 | super(AdaptiveLayerNorm, self).__init__() 268 | self.project_layer = nn.Linear(d_model, 2 * d_model) 269 | self.norm = norm 270 | self.d_model = d_model 271 | self.eps = self.norm.eps 272 | 273 | def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: 274 | if isinstance(input, tuple): 275 | input, embedding = input 276 | weight, bias = torch.split( 277 | self.project_layer(embedding), 278 | split_size_or_sections=self.d_model, 279 | dim=-1, 280 | ) 281 | return (weight * self.norm(input) + bias, embedding) 282 | 283 | weight, bias = torch.split( 284 | self.project_layer(embedding), 285 | split_size_or_sections=self.d_model, 286 | dim=-1, 287 | ) 288 | return weight * self.norm(input) + bias 289 | 290 | 291 | def _get_clones(module, N): 292 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 293 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/text_processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-D-Lab/GPT_SoVITS_Inference/aab032273341da78c44aabf4fb1c132e65f6f29c/GPT_SoVITS/AR/text_processing/__init__.py -------------------------------------------------------------------------------- /GPT_SoVITS/AR/text_processing/phonemizer.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/phonemizer.py 2 | import itertools 3 | import re 4 | from typing import Dict 5 | from typing import List 6 | 7 | import regex 8 | from gruut import sentences 9 | from gruut.const import Sentence 10 | from gruut.const import Word 11 | from AR.text_processing.symbols import SYMBOL_TO_ID 12 | 13 | 14 | class GruutPhonemizer: 15 | def __init__(self, language: str): 16 | self._phonemizer = sentences 17 | self.lang = language 18 | self.symbol_to_id = SYMBOL_TO_ID 19 | self._special_cases_dict: Dict[str] = { 20 | r"\.\.\.": "... ", 21 | ";": "; ", 22 | ":": ": ", 23 | ",": ", ", 24 | r"\.": ". ", 25 | "!": "! ", 26 | r"\?": "? 
", 27 | "—": "—", 28 | "…": "… ", 29 | "«": "«", 30 | "»": "»", 31 | } 32 | self._punctuation_regexp: str = ( 33 | rf"([{''.join(self._special_cases_dict.keys())}])" 34 | ) 35 | 36 | def _normalize_punctuation(self, text: str) -> str: 37 | text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text) 38 | text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text) 39 | text = regex.sub(r"\pZ+", r" ", text) 40 | return text.strip() 41 | 42 | def _convert_punctuation(self, word: Word) -> str: 43 | if not word.phonemes: 44 | return "" 45 | if word.phonemes[0] in ["‖", "|"]: 46 | return word.text.strip() 47 | 48 | phonemes = "".join(word.phonemes) 49 | # remove modifier characters ˈˌː with regex 50 | phonemes = re.sub(r"[ˈˌː͡]", "", phonemes) 51 | return phonemes.strip() 52 | 53 | def phonemize(self, text: str, espeak: bool = False) -> str: 54 | text_to_phonemize: str = self._normalize_punctuation(text) 55 | sents: List[Sentence] = [ 56 | sent 57 | for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak) 58 | ] 59 | words: List[str] = [ 60 | self._convert_punctuation(word) for word in itertools.chain(*sents) 61 | ] 62 | return " ".join(words) 63 | 64 | def transform(self, phonemes): 65 | # convert phonemes to ids 66 | # dictionary is in symbols.py 67 | return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()] 68 | 69 | 70 | if __name__ == "__main__": 71 | phonemizer = GruutPhonemizer("en-us") 72 | # text -> IPA 73 | phonemes = phonemizer.phonemize("Hello, wor-ld ?") 74 | print("phonemes:", phonemes) 75 | print("len(phonemes):", len(phonemes)) 76 | phoneme_ids = phonemizer.transform(phonemes) 77 | print("phoneme_ids:", phoneme_ids) 78 | print("len(phoneme_ids):", len(phoneme_ids)) 79 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/text_processing/symbols.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/symbols.py 2 | PAD = "_" 3 | PUNCTUATION = ';:,.!?¡¿—…"«»“” ' 4 | LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 5 | IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 6 | SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS) 7 | SPACE_ID = SYMBOLS.index(" ") 8 | SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)} 9 | ID_TO_SYMBOL = {i: s for i, s in enumerate(SYMBOLS)} 10 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def str2bool(str): 5 | return True if str.lower() == 'true' else False 6 | 7 | 8 | def get_newest_ckpt(string_list): 9 | # 定义一个正则表达式模式,用于匹配字符串中的数字 10 | pattern = r'epoch=(\d+)-step=(\d+)\.ckpt' 11 | 12 | # 使用正则表达式提取每个字符串中的数字信息,并创建一个包含元组的列表 13 | extracted_info = [] 14 | for string in string_list: 15 | match = re.match(pattern, string) 16 | if match: 17 | epoch = int(match.group(1)) 18 | step = int(match.group(2)) 19 | extracted_info.append((epoch, step, string)) 20 | # 按照 epoch 后面的数字和 step 后面的数字进行排序 21 | sorted_info = sorted( 22 | extracted_info, key=lambda x: (x[0], x[1]), reverse=True) 23 | # 获取最新的 ckpt 文件名 24 | newest_ckpt = sorted_info[0][2] 25 | return newest_ckpt 26 | 27 | 28 | # 文本存在且不为空时 return True 29 | def check_txt_file(file_path): 30 | try: 31 | with 
open(file_path, 'r') as file: 32 | text = file.readline().strip() 33 | assert text.strip() != '' 34 | return text 35 | except Exception: 36 | return False 37 | return False 38 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/utils/initialize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Initialize modules for espnet2 neural networks.""" 3 | import torch 4 | from typeguard import check_argument_types 5 | 6 | 7 | def initialize(model: torch.nn.Module, init: str): 8 | """Initialize weights of a neural network module. 9 | 10 | Parameters are initialized using the given method or distribution. 11 | 12 | Custom initialization routines can be implemented into submodules 13 | as function `espnet_initialization_fn` within the custom module. 14 | 15 | Args: 16 | model: Target. 17 | init: Method of initialization. 18 | """ 19 | assert check_argument_types() 20 | print("init with", init) 21 | 22 | # weight init 23 | for p in model.parameters(): 24 | if p.dim() > 1: 25 | if init == "xavier_uniform": 26 | torch.nn.init.xavier_uniform_(p.data) 27 | elif init == "xavier_normal": 28 | torch.nn.init.xavier_normal_(p.data) 29 | elif init == "kaiming_uniform": 30 | torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu") 31 | elif init == "kaiming_normal": 32 | torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu") 33 | else: 34 | raise ValueError("Unknown initialization: " + init) 35 | # bias init 36 | for name, p in model.named_parameters(): 37 | if ".bias" in name and p.dim() == 1: 38 | p.data.zero_() 39 | -------------------------------------------------------------------------------- /GPT_SoVITS/AR/utils/io.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | import yaml 5 | 6 | 7 | def load_yaml_config(path): 8 | with open(path) as f: 9 | config = yaml.full_load(f) 10 | return config 11 | 12 | 13 | def save_config_to_yaml(config, path): 14 | assert path.endswith(".yaml") 15 | with open(path, "w") as f: 16 | f.write(yaml.dump(config)) 17 | f.close() 18 | 19 | 20 | def write_args(args, path): 21 | args_dict = dict( 22 | (name, getattr(args, name)) for name in dir(args) if not name.startswith("_") 23 | ) 24 | with open(path, "a") as args_file: 25 | args_file.write("==> torch version: {}\n".format(torch.__version__)) 26 | args_file.write( 27 | "==> cudnn version: {}\n".format(torch.backends.cudnn.version()) 28 | ) 29 | args_file.write("==> Cmd:\n") 30 | args_file.write(str(sys.argv)) 31 | args_file.write("\n==> args:\n") 32 | for k, v in sorted(args_dict.items()): 33 | args_file.write(" %s: %s\n" % (str(k), str(v))) 34 | args_file.close() 35 | -------------------------------------------------------------------------------- /GPT_SoVITS/feature_extractor/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import cnhubert, whisper_enc 2 | 3 | content_module_map = { 4 | 'cnhubert': cnhubert, 5 | 'whisper': whisper_enc 6 | } -------------------------------------------------------------------------------- /GPT_SoVITS/feature_extractor/cnhubert.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import librosa 4 | import torch 5 | import torch.nn.functional as F 6 | import soundfile as sf 7 | import logging 8 | 9 | logging.getLogger("numba").setLevel(logging.WARNING) 10 | 11 | from transformers import ( 12 | Wav2Vec2FeatureExtractor, 13 | HubertModel, 14 | ) 15 | 16 | import utils 17 | import torch.nn as nn 18 | 19 | cnhubert_base_path = None 20 | 21 | 22 | class CNHubert(nn.Module): 23 | def __init__(self): 24 | super().__init__() 25 | self.model = HubertModel.from_pretrained(cnhubert_base_path) 26 | self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( 27 | cnhubert_base_path 28 | ) 29 | 30 | def forward(self, x): 31 | input_values = self.feature_extractor( 32 | x, return_tensors="pt", sampling_rate=16000 33 | ).input_values.to(x.device) 34 | feats = self.model(input_values)["last_hidden_state"] 35 | return feats 36 | 37 | 38 | # class CNHubertLarge(nn.Module): 39 | # def __init__(self): 40 | # super().__init__() 41 | # self.model = HubertModel.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large") 42 | # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large") 43 | # def forward(self, x): 44 | # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device) 45 | # feats = self.model(input_values)["last_hidden_state"] 46 | # return feats 47 | # 48 | # class CVec(nn.Module): 49 | # def __init__(self): 50 | # super().__init__() 51 | # self.model = HubertModel.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base") 52 | # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base") 53 | # def forward(self, x): 54 | # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device) 55 | # feats = self.model(input_values)["last_hidden_state"] 56 | # return feats 57 | # 58 | # class cnw2v2base(nn.Module): 59 | # def __init__(self): 60 | # super().__init__() 61 | # self.model = Wav2Vec2Model.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base") 62 | # self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base") 63 | # def forward(self, x): 64 | # input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device) 65 | # feats = self.model(input_values)["last_hidden_state"] 66 | # return feats 67 | 68 | 69 | def get_model(): 70 | model = CNHubert() 71 | model.eval() 72 | return model 73 | 74 | 75 | # def get_large_model(): 76 | # model = CNHubertLarge() 77 | # model.eval() 78 | # return model 79 | # 80 | # def get_model_cvec(): 81 | # model = CVec() 82 | # model.eval() 83 | # return model 84 | # 85 | # def get_model_cnw2v2base(): 86 | # model = cnw2v2base() 87 | # model.eval() 88 | # return model 89 | 90 | 91 | def get_content(hmodel, wav_16k_tensor): 92 | with torch.no_grad(): 93 | feats = hmodel(wav_16k_tensor) 94 | return feats.transpose(1, 2) 95 | 96 | 97 | if __name__ == "__main__": 98 | model = get_model() 99 | src_path = "/Users/Shared/原音频2.wav" 100 | wav_16k_tensor = 
utils.load_wav_to_torch_and_resample(src_path, 16000) 101 | model = model 102 | wav_16k_tensor = wav_16k_tensor 103 | feats = get_content(model, wav_16k_tensor) 104 | print(feats.shape) 105 | -------------------------------------------------------------------------------- /GPT_SoVITS/feature_extractor/whisper_enc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_model(): 5 | import whisper 6 | 7 | model = whisper.load_model("small", device="cpu") 8 | 9 | return model.encoder 10 | 11 | 12 | def get_content(model=None, wav_16k_tensor=None): 13 | from whisper import log_mel_spectrogram, pad_or_trim 14 | 15 | dev = next(model.parameters()).device 16 | mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000] 17 | # if torch.cuda.is_available(): 18 | # mel = mel.to(torch.float16) 19 | feature_len = mel.shape[-1] // 2 20 | assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频" 21 | with torch.no_grad(): 22 | feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[ 23 | :1, :feature_len, : 24 | ].transpose(1, 2) 25 | return feature 26 | -------------------------------------------------------------------------------- /GPT_SoVITS/module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-D-Lab/GPT_SoVITS_Inference/aab032273341da78c44aabf4fb1c132e65f6f29c/GPT_SoVITS/module/__init__.py -------------------------------------------------------------------------------- /GPT_SoVITS/module/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def init_weights(m, mean=0.0, std=0.01): 7 | classname = m.__class__.__name__ 8 | if classname.find("Conv") != -1: 9 | m.weight.data.normal_(mean, std) 10 | 11 | 12 | def get_padding(kernel_size, dilation=1): 13 | return int((kernel_size * dilation - dilation) / 2) 14 | 15 | 16 | def convert_pad_shape(pad_shape): 17 | l = pad_shape[::-1] 18 | pad_shape = [item for sublist in l for item in sublist] 19 | return pad_shape 20 | 21 | 22 | def intersperse(lst, item): 23 | result = [item] * (len(lst) * 2 + 1) 24 | result[1::2] = lst 25 | return result 26 | 27 | 28 | def kl_divergence(m_p, logs_p, m_q, logs_q): 29 | """KL(P||Q)""" 30 | kl = (logs_q - logs_p) - 0.5 31 | kl += ( 32 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 33 | ) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | ret = torch.zeros_like(x[:, :, :segment_size]) 50 | for i in range(x.size(0)): 51 | idx_str = ids_str[i] 52 | idx_end = idx_str + segment_size 53 | ret[i] = x[i, :, idx_str:idx_end] 54 | return ret 55 | 56 | 57 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 58 | b, d, t = x.size() 59 | if x_lengths is None: 60 | x_lengths = t 61 | ids_str_max = x_lengths - segment_size + 1 62 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 63 | ret = slice_segments(x, ids_str, segment_size) 64 | return ret, ids_str 65 | 66 | 67 | def get_timing_signal_1d(length, channels, min_timescale=1.0, 
max_timescale=1.0e4): 68 | position = torch.arange(length, dtype=torch.float) 69 | num_timescales = channels // 2 70 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( 71 | num_timescales - 1 72 | ) 73 | inv_timescales = min_timescale * torch.exp( 74 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment 75 | ) 76 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 77 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 78 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 79 | signal = signal.view(1, channels, length) 80 | return signal 81 | 82 | 83 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 84 | b, channels, length = x.size() 85 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 86 | return x + signal.to(dtype=x.dtype, device=x.device) 87 | 88 | 89 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 90 | b, channels, length = x.size() 91 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 92 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 93 | 94 | 95 | def subsequent_mask(length): 96 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 97 | return mask 98 | 99 | 100 | @torch.jit.script 101 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 102 | n_channels_int = n_channels[0] 103 | in_act = input_a + input_b 104 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 105 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 106 | acts = t_act * s_act 107 | return acts 108 | 109 | 110 | def convert_pad_shape(pad_shape): 111 | l = pad_shape[::-1] 112 | pad_shape = [item for sublist in l for item in sublist] 113 | return pad_shape 114 | 115 | 116 | def shift_1d(x): 117 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 118 | return x 119 | 120 | 121 | def sequence_mask(length, max_length=None): 122 | if max_length is None: 123 | max_length = length.max() 124 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 125 | return x.unsqueeze(0) < length.unsqueeze(1) 126 | 127 | 128 | def generate_path(duration, mask): 129 | """ 130 | duration: [b, 1, t_x] 131 | mask: [b, 1, t_y, t_x] 132 | """ 133 | device = duration.device 134 | 135 | b, _, t_y, t_x = mask.shape 136 | cum_duration = torch.cumsum(duration, -1) 137 | 138 | cum_duration_flat = cum_duration.view(b * t_x) 139 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 140 | path = path.view(b, t_x, t_y) 141 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 142 | path = path.unsqueeze(1).transpose(2, 3) * mask 143 | return path 144 | 145 | 146 | def clip_grad_value_(parameters, clip_value, norm_type=2): 147 | if isinstance(parameters, torch.Tensor): 148 | parameters = [parameters] 149 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 150 | norm_type = float(norm_type) 151 | if clip_value is not None: 152 | clip_value = float(clip_value) 153 | 154 | total_norm = 0 155 | for p in parameters: 156 | param_norm = p.grad.data.norm(norm_type) 157 | total_norm += param_norm.item() ** norm_type 158 | if clip_value is not None: 159 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 160 | total_norm = total_norm ** (1.0 / norm_type) 161 | return total_norm 162 | 163 | 164 | def squeeze(x, x_mask=None, n_sqz=2): 165 | b, c, t = x.size() 166 | 167 | t = (t // n_sqz) * n_sqz 168 | x = x[:, :, 
:t] 169 | x_sqz = x.view(b, c, t // n_sqz, n_sqz) 170 | x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz) 171 | 172 | if x_mask is not None: 173 | x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz] 174 | else: 175 | x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype) 176 | return x_sqz * x_mask, x_mask 177 | 178 | 179 | def unsqueeze(x, x_mask=None, n_sqz=2): 180 | b, c, t = x.size() 181 | 182 | x_unsqz = x.view(b, n_sqz, c // n_sqz, t) 183 | x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz) 184 | 185 | if x_mask is not None: 186 | x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz) 187 | else: 188 | x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype) 189 | return x_unsqz * x_mask, x_mask 190 | -------------------------------------------------------------------------------- /GPT_SoVITS/module/losses.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def feature_loss(fmap_r, fmap_g): 8 | loss = 0 9 | for dr, dg in zip(fmap_r, fmap_g): 10 | for rl, gl in zip(dr, dg): 11 | rl = rl.float().detach() 12 | gl = gl.float() 13 | loss += torch.mean(torch.abs(rl - gl)) 14 | 15 | return loss * 2 16 | 17 | 18 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 19 | loss = 0 20 | r_losses = [] 21 | g_losses = [] 22 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 23 | dr = dr.float() 24 | dg = dg.float() 25 | r_loss = torch.mean((1 - dr) ** 2) 26 | g_loss = torch.mean(dg**2) 27 | loss += r_loss + g_loss 28 | r_losses.append(r_loss.item()) 29 | g_losses.append(g_loss.item()) 30 | 31 | return loss, r_losses, g_losses 32 | 33 | 34 | def generator_loss(disc_outputs): 35 | loss = 0 36 | gen_losses = [] 37 | for dg in disc_outputs: 38 | dg = dg.float() 39 | l = torch.mean((1 - dg) ** 2) 40 | gen_losses.append(l) 41 | loss += l 42 | 43 | return loss, gen_losses 44 | 45 | 46 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): 47 | """ 48 | z_p, logs_q: [b, h, t_t] 49 | m_p, logs_p: [b, h, t_t] 50 | """ 51 | z_p = z_p.float() 52 | logs_q = logs_q.float() 53 | m_p = m_p.float() 54 | logs_p = logs_p.float() 55 | z_mask = z_mask.float() 56 | 57 | kl = logs_p - logs_q - 0.5 58 | kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) 59 | kl = torch.sum(kl * z_mask) 60 | l = kl / torch.sum(z_mask) 61 | return l 62 | 63 | 64 | def mle_loss(z, m, logs, logdet, mask): 65 | l = torch.sum(logs) + 0.5 * torch.sum( 66 | torch.exp(-2 * logs) * ((z - m) ** 2) 67 | ) # neg normal likelihood w/o the constant term 68 | l = l - torch.sum(logdet) # log jacobian determinant 69 | l = l / torch.sum( 70 | torch.ones_like(z) * mask 71 | ) # averaging across batch, channel and time axes 72 | l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term 73 | return l 74 | -------------------------------------------------------------------------------- /GPT_SoVITS/module/mel_processing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import torch.utils.data 8 | import numpy as np 9 | import librosa 10 | import librosa.util as librosa_util 11 | from librosa.util import normalize, pad_center, tiny 12 | from scipy.signal import get_window 13 | from scipy.io.wavfile import read 14 | from librosa.filters 
import mel as librosa_mel_fn 15 | 16 | MAX_WAV_VALUE = 32768.0 17 | 18 | 19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 20 | """ 21 | PARAMS 22 | ------ 23 | C: compression factor 24 | """ 25 | return torch.log(torch.clamp(x, min=clip_val) * C) 26 | 27 | 28 | def dynamic_range_decompression_torch(x, C=1): 29 | """ 30 | PARAMS 31 | ------ 32 | C: compression factor used to compress 33 | """ 34 | return torch.exp(x) / C 35 | 36 | 37 | def spectral_normalize_torch(magnitudes): 38 | output = dynamic_range_compression_torch(magnitudes) 39 | return output 40 | 41 | 42 | def spectral_de_normalize_torch(magnitudes): 43 | output = dynamic_range_decompression_torch(magnitudes) 44 | return output 45 | 46 | 47 | mel_basis = {} 48 | hann_window = {} 49 | 50 | 51 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 52 | if torch.min(y) < -1.0: 53 | print("min value is ", torch.min(y)) 54 | if torch.max(y) > 1.0: 55 | print("max value is ", torch.max(y)) 56 | 57 | global hann_window 58 | dtype_device = str(y.dtype) + "_" + str(y.device) 59 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 60 | if wnsize_dtype_device not in hann_window: 61 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 62 | dtype=y.dtype, device=y.device 63 | ) 64 | 65 | y = torch.nn.functional.pad( 66 | y.unsqueeze(1), 67 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 68 | mode="reflect", 69 | ) 70 | y = y.squeeze(1) 71 | spec = torch.stft( 72 | y, 73 | n_fft, 74 | hop_length=hop_size, 75 | win_length=win_size, 76 | window=hann_window[wnsize_dtype_device], 77 | center=center, 78 | pad_mode="reflect", 79 | normalized=False, 80 | onesided=True, 81 | return_complex=False, 82 | ) 83 | 84 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 85 | return spec 86 | 87 | 88 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 89 | global mel_basis 90 | dtype_device = str(spec.dtype) + "_" + str(spec.device) 91 | fmax_dtype_device = str(fmax) + "_" + dtype_device 92 | if fmax_dtype_device not in mel_basis: 93 | mel = librosa_mel_fn( 94 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax 95 | ) 96 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 97 | dtype=spec.dtype, device=spec.device 98 | ) 99 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 100 | spec = spectral_normalize_torch(spec) 101 | return spec 102 | 103 | 104 | def mel_spectrogram_torch( 105 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False 106 | ): 107 | if torch.min(y) < -1.0: 108 | print("min value is ", torch.min(y)) 109 | if torch.max(y) > 1.0: 110 | print("max value is ", torch.max(y)) 111 | 112 | global mel_basis, hann_window 113 | dtype_device = str(y.dtype) + "_" + str(y.device) 114 | fmax_dtype_device = str(fmax) + "_" + dtype_device 115 | wnsize_dtype_device = str(win_size) + "_" + dtype_device 116 | if fmax_dtype_device not in mel_basis: 117 | mel = librosa_mel_fn( 118 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax 119 | ) 120 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( 121 | dtype=y.dtype, device=y.device 122 | ) 123 | if wnsize_dtype_device not in hann_window: 124 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( 125 | dtype=y.dtype, device=y.device 126 | ) 127 | 128 | y = torch.nn.functional.pad( 129 | y.unsqueeze(1), 130 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 131 | mode="reflect", 132 | ) 133 | y = y.squeeze(1) 
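    # The reflect padding above adds (n_fft - hop_size) / 2 samples on each side so that, with
    # center=False (the default here), the torch.stft call below sees frames aligned as if they
    # had been centered; squeeze(1) restores the (batch, samples) layout torch.stft expects.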
134 | 135 | spec = torch.stft( 136 | y, 137 | n_fft, 138 | hop_length=hop_size, 139 | win_length=win_size, 140 | window=hann_window[wnsize_dtype_device], 141 | center=center, 142 | pad_mode="reflect", 143 | normalized=False, 144 | onesided=True, 145 | return_complex=False, 146 | ) 147 | 148 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 149 | 150 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 151 | spec = spectral_normalize_torch(spec) 152 | 153 | return spec 154 | -------------------------------------------------------------------------------- /GPT_SoVITS/module/mrte_model.py: -------------------------------------------------------------------------------- 1 | # This is Multi-reference timbre encoder 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn.utils import remove_weight_norm, weight_norm 6 | from module.attentions import MultiHeadAttention 7 | 8 | 9 | class MRTE(nn.Module): 10 | def __init__( 11 | self, 12 | content_enc_channels=192, 13 | hidden_size=512, 14 | out_channels=192, 15 | kernel_size=5, 16 | n_heads=4, 17 | ge_layer=2, 18 | ): 19 | super(MRTE, self).__init__() 20 | self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads) 21 | self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1) 22 | self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1) 23 | self.c_post = nn.Conv1d(hidden_size, out_channels, 1) 24 | 25 | def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None): 26 | if ge == None: 27 | ge = 0 28 | attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1) 29 | 30 | ssl_enc = self.c_pre(ssl_enc * ssl_mask) 31 | text_enc = self.text_pre(text * text_mask) 32 | if test != None: 33 | if test == 0: 34 | x = ( 35 | self.cross_attention( 36 | ssl_enc * ssl_mask, text_enc * text_mask, attn_mask 37 | ) 38 | + ssl_enc 39 | + ge 40 | ) 41 | elif test == 1: 42 | x = ssl_enc + ge 43 | elif test == 2: 44 | x = ( 45 | self.cross_attention( 46 | ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask 47 | ) 48 | + ge 49 | ) 50 | else: 51 | raise ValueError("test should be 0,1,2") 52 | else: 53 | x = ( 54 | self.cross_attention( 55 | ssl_enc * ssl_mask, text_enc * text_mask, attn_mask 56 | ) 57 | + ssl_enc 58 | + ge 59 | ) 60 | x = self.c_post(x * ssl_mask) 61 | return x 62 | 63 | 64 | class SpeakerEncoder(torch.nn.Module): 65 | def __init__( 66 | self, 67 | mel_n_channels=80, 68 | model_num_layers=2, 69 | model_hidden_size=256, 70 | model_embedding_size=256, 71 | ): 72 | super(SpeakerEncoder, self).__init__() 73 | self.lstm = nn.LSTM( 74 | mel_n_channels, model_hidden_size, model_num_layers, batch_first=True 75 | ) 76 | self.linear = nn.Linear(model_hidden_size, model_embedding_size) 77 | self.relu = nn.ReLU() 78 | 79 | def forward(self, mels): 80 | self.lstm.flatten_parameters() 81 | _, (hidden, _) = self.lstm(mels.transpose(-1, -2)) 82 | embeds_raw = self.relu(self.linear(hidden[-1])) 83 | return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) 84 | 85 | 86 | class MELEncoder(nn.Module): 87 | def __init__( 88 | self, 89 | in_channels, 90 | out_channels, 91 | hidden_channels, 92 | kernel_size, 93 | dilation_rate, 94 | n_layers, 95 | ): 96 | super().__init__() 97 | self.in_channels = in_channels 98 | self.out_channels = out_channels 99 | self.hidden_channels = hidden_channels 100 | self.kernel_size = kernel_size 101 | self.dilation_rate = dilation_rate 102 | self.n_layers = n_layers 103 | 104 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 105 | self.enc = WN(hidden_channels, kernel_size, 
dilation_rate, n_layers) 106 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 107 | 108 | def forward(self, x): 109 | # print(x.shape,x_lengths.shape) 110 | x = self.pre(x) 111 | x = self.enc(x) 112 | x = self.proj(x) 113 | return x 114 | 115 | 116 | class WN(torch.nn.Module): 117 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers): 118 | super(WN, self).__init__() 119 | assert kernel_size % 2 == 1 120 | self.hidden_channels = hidden_channels 121 | self.kernel_size = kernel_size 122 | self.dilation_rate = dilation_rate 123 | self.n_layers = n_layers 124 | 125 | self.in_layers = torch.nn.ModuleList() 126 | self.res_skip_layers = torch.nn.ModuleList() 127 | 128 | for i in range(n_layers): 129 | dilation = dilation_rate**i 130 | padding = int((kernel_size * dilation - dilation) / 2) 131 | in_layer = nn.Conv1d( 132 | hidden_channels, 133 | 2 * hidden_channels, 134 | kernel_size, 135 | dilation=dilation, 136 | padding=padding, 137 | ) 138 | in_layer = weight_norm(in_layer) 139 | self.in_layers.append(in_layer) 140 | 141 | # last one is not necessary 142 | if i < n_layers - 1: 143 | res_skip_channels = 2 * hidden_channels 144 | else: 145 | res_skip_channels = hidden_channels 146 | 147 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 148 | res_skip_layer = weight_norm(res_skip_layer, name="weight") 149 | self.res_skip_layers.append(res_skip_layer) 150 | 151 | def forward(self, x): 152 | output = torch.zeros_like(x) 153 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 154 | 155 | for i in range(self.n_layers): 156 | x_in = self.in_layers[i](x) 157 | 158 | acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor) 159 | 160 | res_skip_acts = self.res_skip_layers[i](acts) 161 | if i < self.n_layers - 1: 162 | res_acts = res_skip_acts[:, : self.hidden_channels, :] 163 | x = x + res_acts 164 | output = output + res_skip_acts[:, self.hidden_channels :, :] 165 | else: 166 | output = output + res_skip_acts 167 | return output 168 | 169 | def remove_weight_norm(self): 170 | for l in self.in_layers: 171 | remove_weight_norm(l) 172 | for l in self.res_skip_layers: 173 | remove_weight_norm(l) 174 | 175 | 176 | @torch.jit.script 177 | def fused_add_tanh_sigmoid_multiply(input, n_channels): 178 | n_channels_int = n_channels[0] 179 | t_act = torch.tanh(input[:, :n_channels_int, :]) 180 | s_act = torch.sigmoid(input[:, n_channels_int:, :]) 181 | acts = t_act * s_act 182 | return acts 183 | 184 | 185 | if __name__ == "__main__": 186 | content_enc = torch.randn(3, 192, 100) 187 | content_mask = torch.ones(3, 1, 100) 188 | ref_mel = torch.randn(3, 128, 30) 189 | ref_mask = torch.ones(3, 1, 30) 190 | model = MRTE() 191 | out = model(content_enc, content_mask, ref_mel, ref_mask) 192 | print(out.shape) 193 | -------------------------------------------------------------------------------- /GPT_SoVITS/module/quantize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
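# Illustrative usage sketch (not part of the original file; assumes an input tensor x of shape
# [batch, dimension, frames], the layout the underlying ResidualVectorQuantization operates on):
#   rvq = ResidualVectorQuantizer(dimension=256, n_q=8, bins=1024)
#   quantized, codes, commit_loss, quantized_list = rvq(x)
#   codes = rvq.encode(x)       # codebook indices, one set per quantizer
#   x_hat = rvq.decode(codes)   # approximate reconstruction of x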
6 | 7 | """Residual vector quantizer implementation.""" 8 | 9 | from dataclasses import dataclass, field 10 | import math 11 | import typing as tp 12 | 13 | import torch 14 | from torch import nn 15 | 16 | from module.core_vq import ResidualVectorQuantization 17 | 18 | 19 | @dataclass 20 | class QuantizedResult: 21 | quantized: torch.Tensor 22 | codes: torch.Tensor 23 | bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. 24 | penalty: tp.Optional[torch.Tensor] = None 25 | metrics: dict = field(default_factory=dict) 26 | 27 | 28 | class ResidualVectorQuantizer(nn.Module): 29 | """Residual Vector Quantizer. 30 | Args: 31 | dimension (int): Dimension of the codebooks. 32 | n_q (int): Number of residual vector quantizers used. 33 | bins (int): Codebook size. 34 | decay (float): Decay for exponential moving average over the codebooks. 35 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks. 36 | kmeans_iters (int): Number of iterations used for kmeans initialization. 37 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 38 | that have an exponential moving average cluster size less than the specified threshold with 39 | a randomly selected vector from the current batch. 40 | """ 41 | 42 | def __init__( 43 | self, 44 | dimension: int = 256, 45 | n_q: int = 8, 46 | bins: int = 1024, 47 | decay: float = 0.99, 48 | kmeans_init: bool = True, 49 | kmeans_iters: int = 50, 50 | threshold_ema_dead_code: int = 2, 51 | ): 52 | super().__init__() 53 | self.n_q = n_q 54 | self.dimension = dimension 55 | self.bins = bins 56 | self.decay = decay 57 | self.kmeans_init = kmeans_init 58 | self.kmeans_iters = kmeans_iters 59 | self.threshold_ema_dead_code = threshold_ema_dead_code 60 | self.vq = ResidualVectorQuantization( 61 | dim=self.dimension, 62 | codebook_size=self.bins, 63 | num_quantizers=self.n_q, 64 | decay=self.decay, 65 | kmeans_init=self.kmeans_init, 66 | kmeans_iters=self.kmeans_iters, 67 | threshold_ema_dead_code=self.threshold_ema_dead_code, 68 | ) 69 | 70 | def forward( 71 | self, 72 | x: torch.Tensor, 73 | n_q: tp.Optional[int] = None, 74 | layers: tp.Optional[list] = None, 75 | ) -> QuantizedResult: 76 | """Residual vector quantization on the given input tensor. 77 | Args: 78 | x (torch.Tensor): Input tensor. 79 | n_q (int): Number of quantizers used to quantize. Default: all quantizers. 80 | layers (list): Layers whose quantized output should also be returned. Default: None. 81 | Returns: 82 | A tuple (quantized, codes, commit_loss, quantized_list): the quantized 83 | (or approximately quantized) representation, the codebook indices, the 84 | mean commitment loss, and the per-layer outputs requested via `layers`. 85 | """ 86 | n_q = n_q if n_q else self.n_q 87 | if layers and max(layers) >= n_q: 88 | raise ValueError( 89 | f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must be less than B." 90 | ) 91 | quantized, codes, commit_loss, quantized_list = self.vq( 92 | x, n_q=n_q, layers=layers 93 | ) 94 | return quantized, codes, torch.mean(commit_loss), quantized_list 95 | 96 | def encode( 97 | self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None 98 | ) -> torch.Tensor: 99 | """Encode the given input tensor into codebook indices. 100 | The RVQ encode method sets the appropriate number of quantizers to use 101 | and returns indices for each quantizer. 102 | Args: 103 | x (torch.Tensor): Input tensor. 104 | n_q (int): Number of quantizers used to quantize. Default: all quantizers.
105 | st (int): Start to encode input from which layers. Default: 0. 106 | """ 107 | n_q = n_q if n_q else self.n_q 108 | st = st or 0 109 | codes = self.vq.encode(x, n_q=n_q, st=st) 110 | return codes 111 | 112 | def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor: 113 | """Decode the given codes to the quantized representation. 114 | Args: 115 | codes (torch.Tensor): Input indices for each quantizer. 116 | st (int): Start to decode input codes from which layers. Default: 0. 117 | """ 118 | quantized = self.vq.decode(codes, st=st) 119 | return quantized 120 | -------------------------------------------------------------------------------- /GPT_SoVITS/module/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform( 13 | inputs, 14 | unnormalized_widths, 15 | unnormalized_heights, 16 | unnormalized_derivatives, 17 | inverse=False, 18 | tails=None, 19 | tail_bound=1.0, 20 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 21 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 22 | min_derivative=DEFAULT_MIN_DERIVATIVE, 23 | ): 24 | if tails is None: 25 | spline_fn = rational_quadratic_spline 26 | spline_kwargs = {} 27 | else: 28 | spline_fn = unconstrained_rational_quadratic_spline 29 | spline_kwargs = {"tails": tails, "tail_bound": tail_bound} 30 | 31 | outputs, logabsdet = spline_fn( 32 | inputs=inputs, 33 | unnormalized_widths=unnormalized_widths, 34 | unnormalized_heights=unnormalized_heights, 35 | unnormalized_derivatives=unnormalized_derivatives, 36 | inverse=inverse, 37 | min_bin_width=min_bin_width, 38 | min_bin_height=min_bin_height, 39 | min_derivative=min_derivative, 40 | **spline_kwargs 41 | ) 42 | return outputs, logabsdet 43 | 44 | 45 | def searchsorted(bin_locations, inputs, eps=1e-6): 46 | bin_locations[..., -1] += eps 47 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 48 | 49 | 50 | def unconstrained_rational_quadratic_spline( 51 | inputs, 52 | unnormalized_widths, 53 | unnormalized_heights, 54 | unnormalized_derivatives, 55 | inverse=False, 56 | tails="linear", 57 | tail_bound=1.0, 58 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 59 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 60 | min_derivative=DEFAULT_MIN_DERIVATIVE, 61 | ): 62 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 63 | outside_interval_mask = ~inside_interval_mask 64 | 65 | outputs = torch.zeros_like(inputs) 66 | logabsdet = torch.zeros_like(inputs) 67 | 68 | if tails == "linear": 69 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 70 | constant = np.log(np.exp(1 - min_derivative) - 1) 71 | unnormalized_derivatives[..., 0] = constant 72 | unnormalized_derivatives[..., -1] = constant 73 | 74 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 75 | logabsdet[outside_interval_mask] = 0 76 | else: 77 | raise RuntimeError("{} tails are not implemented.".format(tails)) 78 | 79 | ( 80 | outputs[inside_interval_mask], 81 | logabsdet[inside_interval_mask], 82 | ) = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | 
inverse=inverse, 88 | left=-tail_bound, 89 | right=tail_bound, 90 | bottom=-tail_bound, 91 | top=tail_bound, 92 | min_bin_width=min_bin_width, 93 | min_bin_height=min_bin_height, 94 | min_derivative=min_derivative, 95 | ) 96 | 97 | return outputs, logabsdet 98 | 99 | 100 | def rational_quadratic_spline( 101 | inputs, 102 | unnormalized_widths, 103 | unnormalized_heights, 104 | unnormalized_derivatives, 105 | inverse=False, 106 | left=0.0, 107 | right=1.0, 108 | bottom=0.0, 109 | top=1.0, 110 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 111 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 112 | min_derivative=DEFAULT_MIN_DERIVATIVE, 113 | ): 114 | if torch.min(inputs) < left or torch.max(inputs) > right: 115 | raise ValueError("Input to a transform is not within its domain") 116 | 117 | num_bins = unnormalized_widths.shape[-1] 118 | 119 | if min_bin_width * num_bins > 1.0: 120 | raise ValueError("Minimal bin width too large for the number of bins") 121 | if min_bin_height * num_bins > 1.0: 122 | raise ValueError("Minimal bin height too large for the number of bins") 123 | 124 | widths = F.softmax(unnormalized_widths, dim=-1) 125 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 126 | cumwidths = torch.cumsum(widths, dim=-1) 127 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) 128 | cumwidths = (right - left) * cumwidths + left 129 | cumwidths[..., 0] = left 130 | cumwidths[..., -1] = right 131 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 132 | 133 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 134 | 135 | heights = F.softmax(unnormalized_heights, dim=-1) 136 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 137 | cumheights = torch.cumsum(heights, dim=-1) 138 | cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) 139 | cumheights = (top - bottom) * cumheights + bottom 140 | cumheights[..., 0] = bottom 141 | cumheights[..., -1] = top 142 | heights = cumheights[..., 1:] - cumheights[..., :-1] 143 | 144 | if inverse: 145 | bin_idx = searchsorted(cumheights, inputs)[..., None] 146 | else: 147 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 148 | 149 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 150 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 151 | 152 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 153 | delta = heights / widths 154 | input_delta = delta.gather(-1, bin_idx)[..., 0] 155 | 156 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 157 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 158 | 159 | input_heights = heights.gather(-1, bin_idx)[..., 0] 160 | 161 | if inverse: 162 | a = (inputs - input_cumheights) * ( 163 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 164 | ) + input_heights * (input_delta - input_derivatives) 165 | b = input_heights * input_derivatives - (inputs - input_cumheights) * ( 166 | input_derivatives + input_derivatives_plus_one - 2 * input_delta 167 | ) 168 | c = -input_delta * (inputs - input_cumheights) 169 | 170 | discriminant = b.pow(2) - 4 * a * c 171 | assert (discriminant >= 0).all() 172 | 173 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 174 | outputs = root * input_bin_widths + input_cumwidths 175 | 176 | theta_one_minus_theta = root * (1 - root) 177 | denominator = input_delta + ( 178 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 179 | * theta_one_minus_theta 180 | ) 181 | derivative_numerator = input_delta.pow(2) * ( 182 | 
input_derivatives_plus_one * root.pow(2) 183 | + 2 * input_delta * theta_one_minus_theta 184 | + input_derivatives * (1 - root).pow(2) 185 | ) 186 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 187 | 188 | return outputs, -logabsdet 189 | else: 190 | theta = (inputs - input_cumwidths) / input_bin_widths 191 | theta_one_minus_theta = theta * (1 - theta) 192 | 193 | numerator = input_heights * ( 194 | input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta 195 | ) 196 | denominator = input_delta + ( 197 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta) 198 | * theta_one_minus_theta 199 | ) 200 | outputs = input_cumheights + numerator / denominator 201 | 202 | derivative_numerator = input_delta.pow(2) * ( 203 | input_derivatives_plus_one * theta.pow(2) 204 | + 2 * input_delta * theta_one_minus_theta 205 | + input_derivatives * (1 - theta).pow(2) 206 | ) 207 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 208 | 209 | return outputs, logabsdet 210 | -------------------------------------------------------------------------------- /GPT_SoVITS/my_utils.py: -------------------------------------------------------------------------------- 1 | import ffmpeg 2 | import numpy as np 3 | 4 | 5 | def load_audio(file, sr): 6 | try: 7 | # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26 8 | # This launches a subprocess to decode audio while down-mixing and resampling as necessary. 9 | # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. 10 | file = ( 11 | file.strip(" ").strip('"').strip("\n").strip('"').strip(" ") 12 | ) # strip stray leading/trailing spaces, quotes, and newlines that often get pasted along with the path 13 | out, _ = ( 14 | ffmpeg.input(file, threads=0) 15 | .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr) 16 | .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) 17 | ) 18 | except Exception as e: 19 | raise RuntimeError(f"Failed to load audio: {e}") 20 | 21 | return np.frombuffer(out, np.float32).flatten() 22 | -------------------------------------------------------------------------------- /GPT_SoVITS/pretrained_models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /GPT_SoVITS/text/__init__.py: -------------------------------------------------------------------------------- 1 | from text.symbols import * 2 | 3 | 4 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 5 | 6 | def cleaned_text_to_sequence(cleaned_text): 7 | '''Converts a sequence of cleaned phoneme symbols to a sequence of IDs corresponding to the symbols in the text.
8 | Args: 9 | text: string to convert to a sequence 10 | Returns: 11 | List of integers corresponding to the symbols in the text 12 | ''' 13 | phones = [_symbol_to_id[symbol] for symbol in cleaned_text] 14 | return phones 15 | 16 | -------------------------------------------------------------------------------- /GPT_SoVITS/text/chinese.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import re 4 | 5 | import cn2an 6 | from pypinyin import lazy_pinyin, Style 7 | 8 | from text.symbols import punctuation 9 | from text.tone_sandhi import ToneSandhi 10 | from text.zh_normalization.text_normlization import TextNormalizer 11 | 12 | normalizer = lambda x: cn2an.transform(x, "an2cn") 13 | 14 | current_file_path = os.path.dirname(__file__) 15 | pinyin_to_symbol_map = { 16 | line.split("\t")[0]: line.strip().split("\t")[1] 17 | for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines() 18 | } 19 | 20 | import jieba_fast.posseg as psg 21 | 22 | 23 | rep_map = { 24 | ":": ",", 25 | ";": ",", 26 | ",": ",", 27 | "。": ".", 28 | "!": "!", 29 | "?": "?", 30 | "\n": ".", 31 | "·": ",", 32 | "、": ",", 33 | # "...": "…", 34 | "$": ".", 35 | "/": ",", 36 | "—": "-", 37 | } 38 | 39 | tone_modifier = ToneSandhi() 40 | 41 | 42 | def replace_punctuation(text): 43 | text = text.replace("嗯", "恩").replace("呣", "母") 44 | pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys())) 45 | 46 | replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) 47 | 48 | replaced_text = re.sub( 49 | r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text 50 | ) 51 | 52 | return replaced_text 53 | 54 | 55 | def g2p(text): 56 | pattern = r"(?<=[{0}])\s*".format("".join(punctuation)) 57 | sentences = [i for i in re.split(pattern, text) if i.strip() != ""] 58 | phones, word2ph = _g2p(sentences) 59 | return phones, word2ph 60 | 61 | 62 | def _get_initials_finals(word): 63 | initials = [] 64 | finals = [] 65 | orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS) 66 | orig_finals = lazy_pinyin( 67 | word, neutral_tone_with_five=True, style=Style.FINALS_TONE3 68 | ) 69 | for c, v in zip(orig_initials, orig_finals): 70 | initials.append(c) 71 | finals.append(v) 72 | return initials, finals 73 | 74 | 75 | def _g2p(segments): 76 | phones_list = [] 77 | word2ph = [] 78 | for seg in segments: 79 | pinyins = [] 80 | # Replace all English words in the sentence 81 | seg = re.sub("[a-zA-Z]+", "", seg) 82 | seg_cut = psg.lcut(seg) 83 | initials = [] 84 | finals = [] 85 | seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) 86 | for word, pos in seg_cut: 87 | if pos == "eng": 88 | continue 89 | sub_initials, sub_finals = _get_initials_finals(word) 90 | sub_finals = tone_modifier.modified_tone(word, pos, sub_finals) 91 | initials.append(sub_initials) 92 | finals.append(sub_finals) 93 | 94 | # assert len(sub_initials) == len(sub_finals) == len(word) 95 | initials = sum(initials, []) 96 | finals = sum(finals, []) 97 | # 98 | for c, v in zip(initials, finals): 99 | raw_pinyin = c + v 100 | # NOTE: post process for pypinyin outputs 101 | # we discriminate i, ii and iii 102 | if c == v: 103 | assert c in punctuation 104 | phone = [c] 105 | word2ph.append(1) 106 | else: 107 | v_without_tone = v[:-1] 108 | tone = v[-1] 109 | 110 | pinyin = c + v_without_tone 111 | assert tone in "12345" 112 | 113 | if c: 114 | # 多音节 115 | v_rep_map = { 116 | "uei": "ui", 117 | "iou": "iu", 118 | "uen": "un", 119 | } 120 | 
if v_without_tone in v_rep_map.keys(): 121 | pinyin = c + v_rep_map[v_without_tone] 122 | else: 123 | # 单音节 124 | pinyin_rep_map = { 125 | "ing": "ying", 126 | "i": "yi", 127 | "in": "yin", 128 | "u": "wu", 129 | } 130 | if pinyin in pinyin_rep_map.keys(): 131 | pinyin = pinyin_rep_map[pinyin] 132 | else: 133 | single_rep_map = { 134 | "v": "yu", 135 | "e": "e", 136 | "i": "y", 137 | "u": "w", 138 | } 139 | if pinyin[0] in single_rep_map.keys(): 140 | pinyin = single_rep_map[pinyin[0]] + pinyin[1:] 141 | 142 | assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin) 143 | new_c, new_v = pinyin_to_symbol_map[pinyin].split(" ") 144 | new_v = new_v + tone 145 | phone = [new_c, new_v] 146 | word2ph.append(len(phone)) 147 | 148 | phones_list += phone 149 | return phones_list, word2ph 150 | 151 | 152 | def text_normalize(text): 153 | # https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization 154 | tx = TextNormalizer() 155 | sentences = tx.normalize(text) 156 | dest_text = "" 157 | for sentence in sentences: 158 | dest_text += replace_punctuation(sentence) 159 | return dest_text 160 | 161 | 162 | if __name__ == "__main__": 163 | text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏" 164 | text = "呣呣呣~就是…大人的鼹鼠党吧?" 165 | text = "你好" 166 | text = text_normalize(text) 167 | print(g2p(text)) 168 | 169 | 170 | # # 示例用法 171 | # text = "这是一个示例文本:,你好!这是一个测试..." 172 | # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试 173 | -------------------------------------------------------------------------------- /GPT_SoVITS/text/cleaner.py: -------------------------------------------------------------------------------- 1 | from text import chinese, japanese, cleaned_text_to_sequence, symbols, english 2 | 3 | language_module_map = {"zh": chinese, "ja": japanese, "en": english} 4 | special = [ 5 | # ("%", "zh", "SP"), 6 | ("¥", "zh", "SP2"), 7 | ("^", "zh", "SP3"), 8 | # ('@', 'zh', "SP4")#不搞鬼畜了,和第二版保持一致吧 9 | ] 10 | 11 | 12 | def clean_text(text, language): 13 | if(language not in language_module_map): 14 | language="en" 15 | text=" " 16 | for special_s, special_l, target_symbol in special: 17 | if special_s in text and language == special_l: 18 | return clean_special(text, language, special_s, target_symbol) 19 | language_module = language_module_map[language] 20 | norm_text = language_module.text_normalize(text) 21 | if language == "zh": 22 | phones, word2ph = language_module.g2p(norm_text) 23 | assert len(phones) == sum(word2ph) 24 | assert len(norm_text) == len(word2ph) 25 | else: 26 | phones = language_module.g2p(norm_text) 27 | word2ph = None 28 | 29 | for ph in phones: 30 | assert ph in symbols 31 | return phones, word2ph, norm_text 32 | 33 | 34 | def clean_special(text, language, special_s, target_symbol): 35 | """ 36 | 特殊静音段sp符号处理 37 | """ 38 | text = text.replace(special_s, ",") 39 | language_module = language_module_map[language] 40 | norm_text = language_module.text_normalize(text) 41 | phones = language_module.g2p(norm_text) 42 | new_ph = [] 43 | for ph in phones[0]: 44 | assert ph in symbols 45 | if ph == ",": 46 | new_ph.append(target_symbol) 47 | else: 48 | new_ph.append(ph) 49 | return new_ph, phones[1], norm_text 50 | 51 | 52 | def text_to_sequence(text, language): 53 | phones = clean_text(text) 54 | return cleaned_text_to_sequence(phones) 55 | 56 | 57 | if __name__ == "__main__": 58 | print(clean_text("你好%啊啊啊额、还是到付红四方。", "zh")) 59 | -------------------------------------------------------------------------------- 
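A minimal usage sketch tying together the two files above (text/chinese.py and text/cleaner.py). It assumes the interpreter is started from the GPT_SoVITS directory so that the `text` package imports resolve, and that cn2an, pypinyin and jieba_fast are installed; the phone values shown in the comments are indicative rather than verbatim output.

from text.cleaner import clean_text
from text import cleaned_text_to_sequence

# For language "zh", clean_text normalizes the text, runs the pinyin-based G2P and
# returns per-character phone counts; len(phones) == sum(word2ph) and
# len(norm_text) == len(word2ph), as asserted in cleaner.py.
phones, word2ph, norm_text = clean_text("你好%啊啊啊额、还是到付红四方。", "zh")
print(norm_text)   # punctuation unified by replace_punctuation
print(phones)      # e.g. ['n', 'i2', 'h', 'ao3', ...], symbols derived from opencpop-strict.txt
print(word2ph)     # one entry per character of norm_text

# The phone strings map to integer IDs via text/__init__.py.
ids = cleaned_text_to_sequence(phones)
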
/GPT_SoVITS/text/cmudict_cache.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-D-Lab/GPT_SoVITS_Inference/aab032273341da78c44aabf4fb1c132e65f6f29c/GPT_SoVITS/text/cmudict_cache.pickle -------------------------------------------------------------------------------- /GPT_SoVITS/text/engdict-hot.rep: -------------------------------------------------------------------------------- 1 | CHATGPT CH AE1 T JH IY1 P IY1 T IY1 -------------------------------------------------------------------------------- /GPT_SoVITS/text/engdict_cache.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-D-Lab/GPT_SoVITS_Inference/aab032273341da78c44aabf4fb1c132e65f6f29c/GPT_SoVITS/text/engdict_cache.pickle -------------------------------------------------------------------------------- /GPT_SoVITS/text/english.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import re 4 | from g2p_en import G2p 5 | 6 | from string import punctuation 7 | 8 | from text import symbols 9 | 10 | current_file_path = os.path.dirname(__file__) 11 | CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep") 12 | CMU_DICT_FAST_PATH = os.path.join(current_file_path, "cmudict-fast.rep") 13 | CMU_DICT_HOT_PATH = os.path.join(current_file_path, "engdict-hot.rep") 14 | CACHE_PATH = os.path.join(current_file_path, "engdict_cache.pickle") 15 | _g2p = G2p() 16 | 17 | arpa = { 18 | "AH0", 19 | "S", 20 | "AH1", 21 | "EY2", 22 | "AE2", 23 | "EH0", 24 | "OW2", 25 | "UH0", 26 | "NG", 27 | "B", 28 | "G", 29 | "AY0", 30 | "M", 31 | "AA0", 32 | "F", 33 | "AO0", 34 | "ER2", 35 | "UH1", 36 | "IY1", 37 | "AH2", 38 | "DH", 39 | "IY0", 40 | "EY1", 41 | "IH0", 42 | "K", 43 | "N", 44 | "W", 45 | "IY2", 46 | "T", 47 | "AA1", 48 | "ER1", 49 | "EH2", 50 | "OY0", 51 | "UH2", 52 | "UW1", 53 | "Z", 54 | "AW2", 55 | "AW1", 56 | "V", 57 | "UW2", 58 | "AA2", 59 | "ER", 60 | "AW0", 61 | "UW0", 62 | "R", 63 | "OW1", 64 | "EH1", 65 | "ZH", 66 | "AE0", 67 | "IH2", 68 | "IH", 69 | "Y", 70 | "JH", 71 | "P", 72 | "AY1", 73 | "EY0", 74 | "OY2", 75 | "TH", 76 | "HH", 77 | "D", 78 | "ER0", 79 | "CH", 80 | "AO1", 81 | "AE1", 82 | "AO2", 83 | "OY1", 84 | "AY2", 85 | "IH1", 86 | "OW0", 87 | "L", 88 | "SH", 89 | } 90 | 91 | 92 | def replace_phs(phs): 93 | rep_map = {";": ",", ":": ",", "'": "-", '"': "-"} 94 | phs_new = [] 95 | for ph in phs: 96 | if ph in symbols: 97 | phs_new.append(ph) 98 | elif ph in rep_map.keys(): 99 | phs_new.append(rep_map[ph]) 100 | else: 101 | print("ph not in symbols: ", ph) 102 | return phs_new 103 | 104 | 105 | def read_dict(): 106 | g2p_dict = {} 107 | start_line = 49 108 | with open(CMU_DICT_PATH) as f: 109 | line = f.readline() 110 | line_index = 1 111 | while line: 112 | if line_index >= start_line: 113 | line = line.strip() 114 | word_split = line.split(" ") 115 | word = word_split[0] 116 | 117 | syllable_split = word_split[1].split(" - ") 118 | g2p_dict[word] = [] 119 | for syllable in syllable_split: 120 | phone_split = syllable.split(" ") 121 | g2p_dict[word].append(phone_split) 122 | 123 | line_index = line_index + 1 124 | line = f.readline() 125 | 126 | return g2p_dict 127 | 128 | 129 | def read_dict_new(): 130 | g2p_dict = {} 131 | with open(CMU_DICT_PATH) as f: 132 | line = f.readline() 133 | line_index = 1 134 | while line: 135 | if line_index >= 49: 136 | line = line.strip() 137 | word_split = line.split(" ") 138 | word = 
word_split[0] 139 | 140 | syllable_split = word_split[1].split(" - ") 141 | g2p_dict[word] = [] 142 | for syllable in syllable_split: 143 | phone_split = syllable.split(" ") 144 | g2p_dict[word].append(phone_split) 145 | 146 | line_index = line_index + 1 147 | line = f.readline() 148 | 149 | with open(CMU_DICT_FAST_PATH) as f: 150 | line = f.readline() 151 | line_index = 1 152 | while line: 153 | if line_index >= 0: 154 | line = line.strip() 155 | word_split = line.split(" ") 156 | word = word_split[0] 157 | if word not in g2p_dict: 158 | g2p_dict[word] = [] 159 | g2p_dict[word].append(word_split[1:]) 160 | 161 | line_index = line_index + 1 162 | line = f.readline() 163 | 164 | with open(CMU_DICT_HOT_PATH) as f: 165 | line = f.readline() 166 | line_index = 1 167 | while line: 168 | if line_index >= 0: 169 | line = line.strip() 170 | word_split = line.split(" ") 171 | word = word_split[0] 172 | #if word not in g2p_dict: 173 | g2p_dict[word] = [] 174 | g2p_dict[word].append(word_split[1:]) 175 | 176 | line_index = line_index + 1 177 | line = f.readline() 178 | 179 | return g2p_dict 180 | 181 | 182 | def cache_dict(g2p_dict, file_path): 183 | with open(file_path, "wb") as pickle_file: 184 | pickle.dump(g2p_dict, pickle_file) 185 | 186 | 187 | def get_dict(): 188 | if os.path.exists(CACHE_PATH): 189 | with open(CACHE_PATH, "rb") as pickle_file: 190 | g2p_dict = pickle.load(pickle_file) 191 | else: 192 | g2p_dict = read_dict_new() 193 | cache_dict(g2p_dict, CACHE_PATH) 194 | 195 | return g2p_dict 196 | 197 | 198 | eng_dict = get_dict() 199 | 200 | 201 | def text_normalize(text): 202 | # todo: eng text normalize 203 | return text.replace(";", ",") 204 | 205 | 206 | def g2p(text): 207 | phones = [] 208 | words = re.split(r"([,;.\-\?\!\s+])", text) 209 | for w in words: 210 | if w.upper() in eng_dict: 211 | phns = eng_dict[w.upper()] 212 | for ph in phns: 213 | phones += ph 214 | else: 215 | phone_list = list(filter(lambda p: p != " ", _g2p(w))) 216 | for ph in phone_list: 217 | if ph in arpa: 218 | phones.append(ph) 219 | else: 220 | phones.append(ph) 221 | 222 | return replace_phs(phones) 223 | 224 | 225 | if __name__ == "__main__": 226 | # print(get_dict()) 227 | print(g2p("hello")) 228 | print(g2p("In this; paper, we propose 1 DSPGAN, a GAN-based universal vocoder.")) 229 | # all_phones = set() 230 | # for k, syllables in eng_dict.items(): 231 | # for group in syllables: 232 | # for ph in group: 233 | # all_phones.add(ph) 234 | # print(all_phones) 235 | -------------------------------------------------------------------------------- /GPT_SoVITS/text/japanese.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/CjangCjengh/vits/blob/main/text/japanese.py 2 | import re 3 | import sys 4 | 5 | import pyopenjtalk 6 | 7 | 8 | from text import symbols 9 | # Regular expression matching Japanese without punctuation marks: 10 | _japanese_characters = re.compile( 11 | r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]" 12 | ) 13 | 14 | # Regular expression matching non-Japanese characters or punctuation marks: 15 | _japanese_marks = re.compile( 16 | r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]" 17 | ) 18 | 19 | # List of (symbol, Japanese) pairs for marks: 20 | _symbols_to_japanese = [(re.compile("%s" % x[0]), x[1]) for x in [("%", "パーセント")]] 21 | 22 | 23 | # List of (consonant, sokuon) pairs: 24 | _real_sokuon = [ 25 | (re.compile("%s" 
% x[0]), x[1]) 26 | for x in [ 27 | (r"Q([↑↓]*[kg])", r"k#\1"), 28 | (r"Q([↑↓]*[tdjʧ])", r"t#\1"), 29 | (r"Q([↑↓]*[sʃ])", r"s\1"), 30 | (r"Q([↑↓]*[pb])", r"p#\1"), 31 | ] 32 | ] 33 | 34 | # List of (consonant, hatsuon) pairs: 35 | _real_hatsuon = [ 36 | (re.compile("%s" % x[0]), x[1]) 37 | for x in [ 38 | (r"N([↑↓]*[pbm])", r"m\1"), 39 | (r"N([↑↓]*[ʧʥj])", r"n^\1"), 40 | (r"N([↑↓]*[tdn])", r"n\1"), 41 | (r"N([↑↓]*[kg])", r"ŋ\1"), 42 | ] 43 | ] 44 | 45 | 46 | def post_replace_ph(ph): 47 | rep_map = { 48 | ":": ",", 49 | ";": ",", 50 | ",": ",", 51 | "。": ".", 52 | "!": "!", 53 | "?": "?", 54 | "\n": ".", 55 | "·": ",", 56 | "、": ",", 57 | "...": "…", 58 | } 59 | if ph in rep_map.keys(): 60 | ph = rep_map[ph] 61 | if ph in symbols: 62 | return ph 63 | if ph not in symbols: 64 | ph = "UNK" 65 | return ph 66 | 67 | 68 | def symbols_to_japanese(text): 69 | for regex, replacement in _symbols_to_japanese: 70 | text = re.sub(regex, replacement, text) 71 | return text 72 | 73 | 74 | def preprocess_jap(text, with_prosody=False): 75 | """Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html""" 76 | text = symbols_to_japanese(text) 77 | sentences = re.split(_japanese_marks, text) 78 | marks = re.findall(_japanese_marks, text) 79 | text = [] 80 | for i, sentence in enumerate(sentences): 81 | if re.match(_japanese_characters, sentence): 82 | if with_prosody: 83 | text += pyopenjtalk_g2p_prosody(sentence)[1:-1] 84 | else: 85 | p = pyopenjtalk.g2p(sentence) 86 | text += p.split(" ") 87 | 88 | if i < len(marks): 89 | if marks[i] == " ":# 防止意外的UNK 90 | continue 91 | text += [marks[i].replace(" ", "")] 92 | return text 93 | 94 | 95 | def text_normalize(text): 96 | # todo: jap text normalize 97 | return text 98 | 99 | # Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py 100 | def pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True): 101 | """Extract phoneme + prosoody symbol sequence from input full-context labels. 102 | 103 | The algorithm is based on `Prosodic features control by symbols as input of 104 | sequence-to-sequence acoustic modeling for neural TTS`_ with some r9y9's tweaks. 105 | 106 | Args: 107 | text (str): Input text. 108 | drop_unvoiced_vowels (bool): whether to drop unvoiced vowels. 109 | 110 | Returns: 111 | List[str]: List of phoneme + prosody symbols. 112 | 113 | Examples: 114 | >>> from espnet2.text.phoneme_tokenizer import pyopenjtalk_g2p_prosody 115 | >>> pyopenjtalk_g2p_prosody("こんにちは。") 116 | ['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$'] 117 | 118 | .. 
_`Prosodic features control by symbols as input of sequence-to-sequence acoustic 119 | modeling for neural TTS`: https://doi.org/10.1587/transinf.2020EDP7104 120 | 121 | """ 122 | labels = pyopenjtalk.make_label(pyopenjtalk.run_frontend(text)) 123 | N = len(labels) 124 | 125 | phones = [] 126 | for n in range(N): 127 | lab_curr = labels[n] 128 | 129 | # current phoneme 130 | p3 = re.search(r"\-(.*?)\+", lab_curr).group(1) 131 | # deal unvoiced vowels as normal vowels 132 | if drop_unvoiced_vowels and p3 in "AEIOU": 133 | p3 = p3.lower() 134 | 135 | # deal with sil at the beginning and the end of text 136 | if p3 == "sil": 137 | assert n == 0 or n == N - 1 138 | if n == 0: 139 | phones.append("^") 140 | elif n == N - 1: 141 | # check question form or not 142 | e3 = _numeric_feature_by_regex(r"!(\d+)_", lab_curr) 143 | if e3 == 0: 144 | phones.append("$") 145 | elif e3 == 1: 146 | phones.append("?") 147 | continue 148 | elif p3 == "pau": 149 | phones.append("_") 150 | continue 151 | else: 152 | phones.append(p3) 153 | 154 | # accent type and position info (forward or backward) 155 | a1 = _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", lab_curr) 156 | a2 = _numeric_feature_by_regex(r"\+(\d+)\+", lab_curr) 157 | a3 = _numeric_feature_by_regex(r"\+(\d+)/", lab_curr) 158 | 159 | # number of mora in accent phrase 160 | f1 = _numeric_feature_by_regex(r"/F:(\d+)_", lab_curr) 161 | 162 | a2_next = _numeric_feature_by_regex(r"\+(\d+)\+", labels[n + 1]) 163 | # accent phrase border 164 | if a3 == 1 and a2_next == 1 and p3 in "aeiouAEIOUNcl": 165 | phones.append("#") 166 | # pitch falling 167 | elif a1 == 0 and a2_next == a2 + 1 and a2 != f1: 168 | phones.append("]") 169 | # pitch rising 170 | elif a2 == 1 and a2_next == 2: 171 | phones.append("[") 172 | 173 | return phones 174 | 175 | # Copied from espnet https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py 176 | def _numeric_feature_by_regex(regex, s): 177 | match = re.search(regex, s) 178 | if match is None: 179 | return -50 180 | return int(match.group(1)) 181 | 182 | def g2p(norm_text, with_prosody=False): 183 | phones = preprocess_jap(norm_text, with_prosody) 184 | phones = [post_replace_ph(i) for i in phones] 185 | # todo: implement tones and word2ph 186 | return phones 187 | 188 | 189 | if __name__ == "__main__": 190 | phones = g2p("こんにちは, hello, AKITOです,よろしくお願いしますね!") 191 | print(phones) -------------------------------------------------------------------------------- /GPT_SoVITS/text/opencpop-strict.txt: -------------------------------------------------------------------------------- 1 | a AA a 2 | ai AA ai 3 | an AA an 4 | ang AA ang 5 | ao AA ao 6 | ba b a 7 | bai b ai 8 | ban b an 9 | bang b ang 10 | bao b ao 11 | bei b ei 12 | ben b en 13 | beng b eng 14 | bi b i 15 | bian b ian 16 | biao b iao 17 | bie b ie 18 | bin b in 19 | bing b ing 20 | bo b o 21 | bu b u 22 | ca c a 23 | cai c ai 24 | can c an 25 | cang c ang 26 | cao c ao 27 | ce c e 28 | cei c ei 29 | cen c en 30 | ceng c eng 31 | cha ch a 32 | chai ch ai 33 | chan ch an 34 | chang ch ang 35 | chao ch ao 36 | che ch e 37 | chen ch en 38 | cheng ch eng 39 | chi ch ir 40 | chong ch ong 41 | chou ch ou 42 | chu ch u 43 | chua ch ua 44 | chuai ch uai 45 | chuan ch uan 46 | chuang ch uang 47 | chui ch ui 48 | chun ch un 49 | chuo ch uo 50 | ci c i0 51 | cong c ong 52 | cou c ou 53 | cu c u 54 | cuan c uan 55 | cui c ui 56 | cun c un 57 | cuo c uo 58 | da d a 59 | dai d ai 60 | dan d an 61 | dang d ang 62 | dao d ao 63 | de d e 64 | dei d ei 65 | den d en 66 | 
deng d eng 67 | di d i 68 | dia d ia 69 | dian d ian 70 | diao d iao 71 | die d ie 72 | ding d ing 73 | diu d iu 74 | dong d ong 75 | dou d ou 76 | du d u 77 | duan d uan 78 | dui d ui 79 | dun d un 80 | duo d uo 81 | e EE e 82 | ei EE ei 83 | en EE en 84 | eng EE eng 85 | er EE er 86 | fa f a 87 | fan f an 88 | fang f ang 89 | fei f ei 90 | fen f en 91 | feng f eng 92 | fo f o 93 | fou f ou 94 | fu f u 95 | ga g a 96 | gai g ai 97 | gan g an 98 | gang g ang 99 | gao g ao 100 | ge g e 101 | gei g ei 102 | gen g en 103 | geng g eng 104 | gong g ong 105 | gou g ou 106 | gu g u 107 | gua g ua 108 | guai g uai 109 | guan g uan 110 | guang g uang 111 | gui g ui 112 | gun g un 113 | guo g uo 114 | ha h a 115 | hai h ai 116 | han h an 117 | hang h ang 118 | hao h ao 119 | he h e 120 | hei h ei 121 | hen h en 122 | heng h eng 123 | hong h ong 124 | hou h ou 125 | hu h u 126 | hua h ua 127 | huai h uai 128 | huan h uan 129 | huang h uang 130 | hui h ui 131 | hun h un 132 | huo h uo 133 | ji j i 134 | jia j ia 135 | jian j ian 136 | jiang j iang 137 | jiao j iao 138 | jie j ie 139 | jin j in 140 | jing j ing 141 | jiong j iong 142 | jiu j iu 143 | ju j v 144 | jv j v 145 | juan j van 146 | jvan j van 147 | jue j ve 148 | jve j ve 149 | jun j vn 150 | jvn j vn 151 | ka k a 152 | kai k ai 153 | kan k an 154 | kang k ang 155 | kao k ao 156 | ke k e 157 | kei k ei 158 | ken k en 159 | keng k eng 160 | kong k ong 161 | kou k ou 162 | ku k u 163 | kua k ua 164 | kuai k uai 165 | kuan k uan 166 | kuang k uang 167 | kui k ui 168 | kun k un 169 | kuo k uo 170 | la l a 171 | lai l ai 172 | lan l an 173 | lang l ang 174 | lao l ao 175 | le l e 176 | lei l ei 177 | leng l eng 178 | li l i 179 | lia l ia 180 | lian l ian 181 | liang l iang 182 | liao l iao 183 | lie l ie 184 | lin l in 185 | ling l ing 186 | liu l iu 187 | lo l o 188 | long l ong 189 | lou l ou 190 | lu l u 191 | luan l uan 192 | lun l un 193 | luo l uo 194 | lv l v 195 | lve l ve 196 | ma m a 197 | mai m ai 198 | man m an 199 | mang m ang 200 | mao m ao 201 | me m e 202 | mei m ei 203 | men m en 204 | meng m eng 205 | mi m i 206 | mian m ian 207 | miao m iao 208 | mie m ie 209 | min m in 210 | ming m ing 211 | miu m iu 212 | mo m o 213 | mou m ou 214 | mu m u 215 | na n a 216 | nai n ai 217 | nan n an 218 | nang n ang 219 | nao n ao 220 | ne n e 221 | nei n ei 222 | nen n en 223 | neng n eng 224 | ni n i 225 | nian n ian 226 | niang n iang 227 | niao n iao 228 | nie n ie 229 | nin n in 230 | ning n ing 231 | niu n iu 232 | nong n ong 233 | nou n ou 234 | nu n u 235 | nuan n uan 236 | nun n un 237 | nuo n uo 238 | nv n v 239 | nve n ve 240 | o OO o 241 | ou OO ou 242 | pa p a 243 | pai p ai 244 | pan p an 245 | pang p ang 246 | pao p ao 247 | pei p ei 248 | pen p en 249 | peng p eng 250 | pi p i 251 | pian p ian 252 | piao p iao 253 | pie p ie 254 | pin p in 255 | ping p ing 256 | po p o 257 | pou p ou 258 | pu p u 259 | qi q i 260 | qia q ia 261 | qian q ian 262 | qiang q iang 263 | qiao q iao 264 | qie q ie 265 | qin q in 266 | qing q ing 267 | qiong q iong 268 | qiu q iu 269 | qu q v 270 | qv q v 271 | quan q van 272 | qvan q van 273 | que q ve 274 | qve q ve 275 | qun q vn 276 | qvn q vn 277 | ran r an 278 | rang r ang 279 | rao r ao 280 | re r e 281 | ren r en 282 | reng r eng 283 | ri r ir 284 | rong r ong 285 | rou r ou 286 | ru r u 287 | rua r ua 288 | ruan r uan 289 | rui r ui 290 | run r un 291 | ruo r uo 292 | sa s a 293 | sai s ai 294 | san s an 295 | sang s ang 296 | sao s ao 297 | se s e 298 | sen s en 299 | seng s eng 300 | sha sh 
a 301 | shai sh ai 302 | shan sh an 303 | shang sh ang 304 | shao sh ao 305 | she sh e 306 | shei sh ei 307 | shen sh en 308 | sheng sh eng 309 | shi sh ir 310 | shou sh ou 311 | shu sh u 312 | shua sh ua 313 | shuai sh uai 314 | shuan sh uan 315 | shuang sh uang 316 | shui sh ui 317 | shun sh un 318 | shuo sh uo 319 | si s i0 320 | song s ong 321 | sou s ou 322 | su s u 323 | suan s uan 324 | sui s ui 325 | sun s un 326 | suo s uo 327 | ta t a 328 | tai t ai 329 | tan t an 330 | tang t ang 331 | tao t ao 332 | te t e 333 | tei t ei 334 | teng t eng 335 | ti t i 336 | tian t ian 337 | tiao t iao 338 | tie t ie 339 | ting t ing 340 | tong t ong 341 | tou t ou 342 | tu t u 343 | tuan t uan 344 | tui t ui 345 | tun t un 346 | tuo t uo 347 | wa w a 348 | wai w ai 349 | wan w an 350 | wang w ang 351 | wei w ei 352 | wen w en 353 | weng w eng 354 | wo w o 355 | wu w u 356 | xi x i 357 | xia x ia 358 | xian x ian 359 | xiang x iang 360 | xiao x iao 361 | xie x ie 362 | xin x in 363 | xing x ing 364 | xiong x iong 365 | xiu x iu 366 | xu x v 367 | xv x v 368 | xuan x van 369 | xvan x van 370 | xue x ve 371 | xve x ve 372 | xun x vn 373 | xvn x vn 374 | ya y a 375 | yan y En 376 | yang y ang 377 | yao y ao 378 | ye y E 379 | yi y i 380 | yin y in 381 | ying y ing 382 | yo y o 383 | yong y ong 384 | you y ou 385 | yu y v 386 | yv y v 387 | yuan y van 388 | yvan y van 389 | yue y ve 390 | yve y ve 391 | yun y vn 392 | yvn y vn 393 | za z a 394 | zai z ai 395 | zan z an 396 | zang z ang 397 | zao z ao 398 | ze z e 399 | zei z ei 400 | zen z en 401 | zeng z eng 402 | zha zh a 403 | zhai zh ai 404 | zhan zh an 405 | zhang zh ang 406 | zhao zh ao 407 | zhe zh e 408 | zhei zh ei 409 | zhen zh en 410 | zheng zh eng 411 | zhi zh ir 412 | zhong zh ong 413 | zhou zh ou 414 | zhu zh u 415 | zhua zh ua 416 | zhuai zh uai 417 | zhuan zh uan 418 | zhuang zh uang 419 | zhui zh ui 420 | zhun zh un 421 | zhuo zh uo 422 | zi z i0 423 | zong z ong 424 | zou z ou 425 | zu z u 426 | zuan z uan 427 | zui z ui 428 | zun z un 429 | zuo z uo 430 | -------------------------------------------------------------------------------- /GPT_SoVITS/text/symbols.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # punctuation = ['!', '?', '…', ",", ".","@"]#@是SP停顿 4 | punctuation = ["!", "?", "…", ",", "."] # @是SP停顿 5 | punctuation.append("-") 6 | pu_symbols = punctuation + ["SP", "SP2", "SP3", "UNK"] 7 | # pu_symbols = punctuation + ["SP", 'SP2', 'SP3','SP4', "UNK"] 8 | pad = "_" 9 | 10 | c = [ 11 | "AA", 12 | "EE", 13 | "OO", 14 | "b", 15 | "c", 16 | "ch", 17 | "d", 18 | "f", 19 | "g", 20 | "h", 21 | "j", 22 | "k", 23 | "l", 24 | "m", 25 | "n", 26 | "p", 27 | "q", 28 | "r", 29 | "s", 30 | "sh", 31 | "t", 32 | "w", 33 | "x", 34 | "y", 35 | "z", 36 | "zh", 37 | ] 38 | v = [ 39 | "E1", 40 | "En1", 41 | "a1", 42 | "ai1", 43 | "an1", 44 | "ang1", 45 | "ao1", 46 | "e1", 47 | "ei1", 48 | "en1", 49 | "eng1", 50 | "er1", 51 | "i1", 52 | "i01", 53 | "ia1", 54 | "ian1", 55 | "iang1", 56 | "iao1", 57 | "ie1", 58 | "in1", 59 | "ing1", 60 | "iong1", 61 | "ir1", 62 | "iu1", 63 | "o1", 64 | "ong1", 65 | "ou1", 66 | "u1", 67 | "ua1", 68 | "uai1", 69 | "uan1", 70 | "uang1", 71 | "ui1", 72 | "un1", 73 | "uo1", 74 | "v1", 75 | "van1", 76 | "ve1", 77 | "vn1", 78 | "E2", 79 | "En2", 80 | "a2", 81 | "ai2", 82 | "an2", 83 | "ang2", 84 | "ao2", 85 | "e2", 86 | "ei2", 87 | "en2", 88 | "eng2", 89 | "er2", 90 | "i2", 91 | "i02", 92 | "ia2", 93 | "ian2", 94 | "iang2", 95 | "iao2", 96 | "ie2", 97 | "in2", 
98 | "ing2", 99 | "iong2", 100 | "ir2", 101 | "iu2", 102 | "o2", 103 | "ong2", 104 | "ou2", 105 | "u2", 106 | "ua2", 107 | "uai2", 108 | "uan2", 109 | "uang2", 110 | "ui2", 111 | "un2", 112 | "uo2", 113 | "v2", 114 | "van2", 115 | "ve2", 116 | "vn2", 117 | "E3", 118 | "En3", 119 | "a3", 120 | "ai3", 121 | "an3", 122 | "ang3", 123 | "ao3", 124 | "e3", 125 | "ei3", 126 | "en3", 127 | "eng3", 128 | "er3", 129 | "i3", 130 | "i03", 131 | "ia3", 132 | "ian3", 133 | "iang3", 134 | "iao3", 135 | "ie3", 136 | "in3", 137 | "ing3", 138 | "iong3", 139 | "ir3", 140 | "iu3", 141 | "o3", 142 | "ong3", 143 | "ou3", 144 | "u3", 145 | "ua3", 146 | "uai3", 147 | "uan3", 148 | "uang3", 149 | "ui3", 150 | "un3", 151 | "uo3", 152 | "v3", 153 | "van3", 154 | "ve3", 155 | "vn3", 156 | "E4", 157 | "En4", 158 | "a4", 159 | "ai4", 160 | "an4", 161 | "ang4", 162 | "ao4", 163 | "e4", 164 | "ei4", 165 | "en4", 166 | "eng4", 167 | "er4", 168 | "i4", 169 | "i04", 170 | "ia4", 171 | "ian4", 172 | "iang4", 173 | "iao4", 174 | "ie4", 175 | "in4", 176 | "ing4", 177 | "iong4", 178 | "ir4", 179 | "iu4", 180 | "o4", 181 | "ong4", 182 | "ou4", 183 | "u4", 184 | "ua4", 185 | "uai4", 186 | "uan4", 187 | "uang4", 188 | "ui4", 189 | "un4", 190 | "uo4", 191 | "v4", 192 | "van4", 193 | "ve4", 194 | "vn4", 195 | "E5", 196 | "En5", 197 | "a5", 198 | "ai5", 199 | "an5", 200 | "ang5", 201 | "ao5", 202 | "e5", 203 | "ei5", 204 | "en5", 205 | "eng5", 206 | "er5", 207 | "i5", 208 | "i05", 209 | "ia5", 210 | "ian5", 211 | "iang5", 212 | "iao5", 213 | "ie5", 214 | "in5", 215 | "ing5", 216 | "iong5", 217 | "ir5", 218 | "iu5", 219 | "o5", 220 | "ong5", 221 | "ou5", 222 | "u5", 223 | "ua5", 224 | "uai5", 225 | "uan5", 226 | "uang5", 227 | "ui5", 228 | "un5", 229 | "uo5", 230 | "v5", 231 | "van5", 232 | "ve5", 233 | "vn5", 234 | ] 235 | 236 | v_without_tone = [ 237 | "E", 238 | "En", 239 | "a", 240 | "ai", 241 | "an", 242 | "ang", 243 | "ao", 244 | "e", 245 | "ei", 246 | "en", 247 | "eng", 248 | "er", 249 | "i", 250 | "i0", 251 | "ia", 252 | "ian", 253 | "iang", 254 | "iao", 255 | "ie", 256 | "in", 257 | "ing", 258 | "iong", 259 | "ir", 260 | "iu", 261 | "o", 262 | "ong", 263 | "ou", 264 | "u", 265 | "ua", 266 | "uai", 267 | "uan", 268 | "uang", 269 | "ui", 270 | "un", 271 | "uo", 272 | "v", 273 | "van", 274 | "ve", 275 | "vn", 276 | ] 277 | 278 | # japanese 279 | ja_symbols = [ 280 | "I", 281 | "N", 282 | "U", 283 | "a", 284 | "b", 285 | "by", 286 | "ch", 287 | "cl", 288 | "d", 289 | "dy", 290 | "e", 291 | "f", 292 | "g", 293 | "gy", 294 | "h", 295 | "hy", 296 | "i", 297 | "j", 298 | "k", 299 | "ky", 300 | "m", 301 | "my", 302 | "n", 303 | "ny", 304 | "o", 305 | "p", 306 | "py", 307 | "r", 308 | "ry", 309 | "s", 310 | "sh", 311 | "t", 312 | "ts", 313 | "u", 314 | "v", 315 | "w", 316 | "y", 317 | "z", 318 | # "[", #上升调型 319 | # "]", #下降调型 320 | # "$", #结束符 321 | # "^", #开始符 322 | ] 323 | 324 | arpa = { 325 | "AH0", 326 | "S", 327 | "AH1", 328 | "EY2", 329 | "AE2", 330 | "EH0", 331 | "OW2", 332 | "UH0", 333 | "NG", 334 | "B", 335 | "G", 336 | "AY0", 337 | "M", 338 | "AA0", 339 | "F", 340 | "AO0", 341 | "ER2", 342 | "UH1", 343 | "IY1", 344 | "AH2", 345 | "DH", 346 | "IY0", 347 | "EY1", 348 | "IH0", 349 | "K", 350 | "N", 351 | "W", 352 | "IY2", 353 | "T", 354 | "AA1", 355 | "ER1", 356 | "EH2", 357 | "OY0", 358 | "UH2", 359 | "UW1", 360 | "Z", 361 | "AW2", 362 | "AW1", 363 | "V", 364 | "UW2", 365 | "AA2", 366 | "ER", 367 | "AW0", 368 | "UW0", 369 | "R", 370 | "OW1", 371 | "EH1", 372 | "ZH", 373 | "AE0", 374 | "IH2", 375 | "IH", 376 | "Y", 377 | "JH", 
378 | "P", 379 | "AY1", 380 | "EY0", 381 | "OY2", 382 | "TH", 383 | "HH", 384 | "D", 385 | "ER0", 386 | "CH", 387 | "AO1", 388 | "AE1", 389 | "AO2", 390 | "OY1", 391 | "AY2", 392 | "IH1", 393 | "OW0", 394 | "L", 395 | "SH", 396 | } 397 | 398 | symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa) 399 | symbols = sorted(set(symbols)) 400 | if __name__ == "__main__": 401 | print(len(symbols)) 402 | -------------------------------------------------------------------------------- /GPT_SoVITS/text/zh_normalization/README.md: -------------------------------------------------------------------------------- 1 | ## Supported NSW (Non-Standard-Word) Normalization 2 | 3 | |NSW type|raw|normalized| 4 | |:--|:-|:-| 5 | |serial number|电影中梁朝伟扮演的陈永仁的编号27149|电影中梁朝伟扮演的陈永仁的编号二七一四九| 6 | |cardinal|这块黄金重达324.75克
<br>我们班的最高总分为583分|这块黄金重达三百二十四点七五克<br>我们班的最高总分为五百八十三分| 7 | |numeric range |12\~23<br>-1.5\~2|十二到二十三<br>负一点五到二| 8 | |date|她出生于86年8月18日,她弟弟出生于1995年3月1日|她出生于八六年八月十八日, 她弟弟出生于一九九五年三月一日| 9 | |time|等会请在12:05请通知我|等会请在十二点零五分请通知我 10 | |temperature|今天的最低气温达到-10°C|今天的最低气温达到零下十度 11 | |fraction|现场有7/12的观众投出了赞成票|现场有十二分之七的观众投出了赞成票| 12 | |percentage|明天有62%的概率降雨|明天有百分之六十二的概率降雨| 13 | |money|随便来几个价格12块5,34.5元,20.1万|随便来几个价格十二块五,三十四点五元,二十点一万| 14 | |telephone|这是固话0421-33441122<br>这是手机+86 18544139121|这是固话零四二一三三四四一一二二<br>
这是手机八六一八五四四一三九一二一| 15 | ## References 16 | [Pull requests #658 of DeepSpeech](https://github.com/PaddlePaddle/DeepSpeech/pull/658/files) 17 | -------------------------------------------------------------------------------- /GPT_SoVITS/text/zh_normalization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from text.zh_normalization.text_normlization import * 15 | -------------------------------------------------------------------------------- /GPT_SoVITS/text/zh_normalization/chronology.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import re 15 | 16 | from .num import DIGITS 17 | from .num import num2str 18 | from .num import verbalize_cardinal 19 | from .num import verbalize_digit 20 | 21 | 22 | def _time_num2str(num_string: str) -> str: 23 | """A special case for verbalizing number in time.""" 24 | result = num2str(num_string.lstrip('0')) 25 | if num_string.startswith('0'): 26 | result = DIGITS['0'] + result 27 | return result 28 | 29 | 30 | # 时刻表达式 31 | RE_TIME = re.compile(r'([0-1]?[0-9]|2[0-3])' 32 | r':([0-5][0-9])' 33 | r'(:([0-5][0-9]))?') 34 | 35 | # 时间范围,如8:30-12:30 36 | RE_TIME_RANGE = re.compile(r'([0-1]?[0-9]|2[0-3])' 37 | r':([0-5][0-9])' 38 | r'(:([0-5][0-9]))?' 
39 | r'(~|-)' 40 | r'([0-1]?[0-9]|2[0-3])' 41 | r':([0-5][0-9])' 42 | r'(:([0-5][0-9]))?') 43 | 44 | 45 | def replace_time(match) -> str: 46 | """ 47 | Args: 48 | match (re.Match) 49 | Returns: 50 | str 51 | """ 52 | 53 | is_range = len(match.groups()) > 5 54 | 55 | hour = match.group(1) 56 | minute = match.group(2) 57 | second = match.group(4) 58 | 59 | if is_range: 60 | hour_2 = match.group(6) 61 | minute_2 = match.group(7) 62 | second_2 = match.group(9) 63 | 64 | result = f"{num2str(hour)}点" 65 | if minute.lstrip('0'): 66 | if int(minute) == 30: 67 | result += "半" 68 | else: 69 | result += f"{_time_num2str(minute)}分" 70 | if second and second.lstrip('0'): 71 | result += f"{_time_num2str(second)}秒" 72 | 73 | if is_range: 74 | result += "至" 75 | result += f"{num2str(hour_2)}点" 76 | if minute_2.lstrip('0'): 77 | if int(minute) == 30: 78 | result += "半" 79 | else: 80 | result += f"{_time_num2str(minute_2)}分" 81 | if second_2 and second_2.lstrip('0'): 82 | result += f"{_time_num2str(second_2)}秒" 83 | 84 | return result 85 | 86 | 87 | RE_DATE = re.compile(r'(\d{4}|\d{2})年' 88 | r'((0?[1-9]|1[0-2])月)?' 89 | r'(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?') 90 | 91 | 92 | def replace_date(match) -> str: 93 | """ 94 | Args: 95 | match (re.Match) 96 | Returns: 97 | str 98 | """ 99 | year = match.group(1) 100 | month = match.group(3) 101 | day = match.group(5) 102 | result = "" 103 | if year: 104 | result += f"{verbalize_digit(year)}年" 105 | if month: 106 | result += f"{verbalize_cardinal(month)}月" 107 | if day: 108 | result += f"{verbalize_cardinal(day)}{match.group(9)}" 109 | return result 110 | 111 | 112 | # 用 / 或者 - 分隔的 YY/MM/DD 或者 YY-MM-DD 日期 113 | RE_DATE2 = re.compile( 114 | r'(\d{4})([- /.])(0[1-9]|1[012])\2(0[1-9]|[12][0-9]|3[01])') 115 | 116 | 117 | def replace_date2(match) -> str: 118 | """ 119 | Args: 120 | match (re.Match) 121 | Returns: 122 | str 123 | """ 124 | year = match.group(1) 125 | month = match.group(3) 126 | day = match.group(4) 127 | result = "" 128 | if year: 129 | result += f"{verbalize_digit(year)}年" 130 | if month: 131 | result += f"{verbalize_cardinal(month)}月" 132 | if day: 133 | result += f"{verbalize_cardinal(day)}日" 134 | return result 135 | -------------------------------------------------------------------------------- /GPT_SoVITS/text/zh_normalization/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
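# Illustrative note on the mapping tables below (a sketch, not verbatim output):
# they are intended for str.translate and convert full-width ASCII to half-width, e.g.
#   "ABC123".translate(F2H_ASCII_LETTERS).translate(F2H_DIGITS)  ->  "ABC123"
# RE_NSW then matches the remaining non-Chinese spans so NSW segments can be extracted.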
14 | import re 15 | import string 16 | 17 | from pypinyin.constants import SUPPORT_UCS4 18 | 19 | # 全角半角转换 20 | # 英文字符全角 -> 半角映射表 (num: 52) 21 | F2H_ASCII_LETTERS = { 22 | ord(char) + 65248: ord(char) 23 | for char in string.ascii_letters 24 | } 25 | 26 | # 英文字符半角 -> 全角映射表 27 | H2F_ASCII_LETTERS = {value: key for key, value in F2H_ASCII_LETTERS.items()} 28 | 29 | # 数字字符全角 -> 半角映射表 (num: 10) 30 | F2H_DIGITS = {ord(char) + 65248: ord(char) for char in string.digits} 31 | # 数字字符半角 -> 全角映射表 32 | H2F_DIGITS = {value: key for key, value in F2H_DIGITS.items()} 33 | 34 | # 标点符号全角 -> 半角映射表 (num: 32) 35 | F2H_PUNCTUATIONS = {ord(char) + 65248: ord(char) for char in string.punctuation} 36 | # 标点符号半角 -> 全角映射表 37 | H2F_PUNCTUATIONS = {value: key for key, value in F2H_PUNCTUATIONS.items()} 38 | 39 | # 空格 (num: 1) 40 | F2H_SPACE = {'\u3000': ' '} 41 | H2F_SPACE = {' ': '\u3000'} 42 | 43 | # 非"有拼音的汉字"的字符串,可用于NSW提取 44 | if SUPPORT_UCS4: 45 | RE_NSW = re.compile(r'(?:[^' 46 | r'\u3007' # 〇 47 | r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF] 48 | r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF] 49 | r'\uf900-\ufaff' # CJK兼容:[F900-FAFF] 50 | r'\U00020000-\U0002A6DF' # CJK扩展B:[20000-2A6DF] 51 | r'\U0002A703-\U0002B73F' # CJK扩展C:[2A700-2B73F] 52 | r'\U0002B740-\U0002B81D' # CJK扩展D:[2B740-2B81D] 53 | r'\U0002F80A-\U0002FA1F' # CJK兼容扩展:[2F800-2FA1F] 54 | r'])+') 55 | else: 56 | RE_NSW = re.compile( # pragma: no cover 57 | r'(?:[^' 58 | r'\u3007' # 〇 59 | r'\u3400-\u4dbf' # CJK扩展A:[3400-4DBF] 60 | r'\u4e00-\u9fff' # CJK基本:[4E00-9FFF] 61 | r'\uf900-\ufaff' # CJK兼容:[F900-FAFF] 62 | r'])+') 63 | -------------------------------------------------------------------------------- /GPT_SoVITS/text/zh_normalization/num.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Rules to verbalize numbers into Chinese characters. 
16 | https://zh.wikipedia.org/wiki/中文数字#現代中文 17 | """ 18 | import re 19 | from collections import OrderedDict 20 | from typing import List 21 | 22 | DIGITS = {str(i): tran for i, tran in enumerate('零一二三四五六七八九')} 23 | UNITS = OrderedDict({ 24 | 1: '十', 25 | 2: '百', 26 | 3: '千', 27 | 4: '万', 28 | 8: '亿', 29 | }) 30 | 31 | COM_QUANTIFIERS = '(封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)' 32 | 33 | # 分数表达式 34 | RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)') 35 | 36 | 37 | def replace_frac(match) -> str: 38 | """ 39 | Args: 40 | match (re.Match) 41 | Returns: 42 | str 43 | """ 44 | sign = match.group(1) 45 | nominator = match.group(2) 46 | denominator = match.group(3) 47 | sign: str = "负" if sign else "" 48 | nominator: str = num2str(nominator) 49 | denominator: str = num2str(denominator) 50 | result = f"{sign}{denominator}分之{nominator}" 51 | return result 52 | 53 | 54 | # 百分数表达式 55 | RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%') 56 | 57 | 58 | def replace_percentage(match) -> str: 59 | """ 60 | Args: 61 | match (re.Match) 62 | Returns: 63 | str 64 | """ 65 | sign = match.group(1) 66 | percent = match.group(2) 67 | sign: str = "负" if sign else "" 68 | percent: str = num2str(percent) 69 | result = f"{sign}百分之{percent}" 70 | return result 71 | 72 | 73 | # 整数表达式 74 | # 带负号的整数 -10 75 | RE_INTEGER = re.compile(r'(-)' r'(\d+)') 76 | 77 | 78 | def replace_negative_num(match) -> str: 79 | """ 80 | Args: 81 | match (re.Match) 82 | Returns: 83 | str 84 | """ 85 | sign = match.group(1) 86 | number = match.group(2) 87 | sign: str = "负" if sign else "" 88 | number: str = num2str(number) 89 | result = f"{sign}{number}" 90 | return result 91 | 92 | 93 | # 编号-无符号整形 94 | # 00078 95 | RE_DEFAULT_NUM = re.compile(r'\d{3}\d*') 96 | 97 | 98 | def replace_default_num(match): 99 | """ 100 | Args: 101 | match (re.Match) 102 | Returns: 103 | str 104 | """ 105 | number = match.group(0) 106 | return verbalize_digit(number, alt_one=True) 107 | 108 | 109 | # 数字表达式 110 | # 纯小数 111 | RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))') 112 | # 正整数 + 量词 113 | RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" 
+ COM_QUANTIFIERS) 114 | RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))') 115 | 116 | 117 | def replace_positive_quantifier(match) -> str: 118 | """ 119 | Args: 120 | match (re.Match) 121 | Returns: 122 | str 123 | """ 124 | number = match.group(1) 125 | match_2 = match.group(2) 126 | if match_2 == "+": 127 | match_2 = "多" 128 | match_2: str = match_2 if match_2 else "" 129 | quantifiers: str = match.group(3) 130 | number: str = num2str(number) 131 | result = f"{number}{match_2}{quantifiers}" 132 | return result 133 | 134 | 135 | def replace_number(match) -> str: 136 | """ 137 | Args: 138 | match (re.Match) 139 | Returns: 140 | str 141 | """ 142 | sign = match.group(1) 143 | number = match.group(2) 144 | pure_decimal = match.group(5) 145 | if pure_decimal: 146 | result = num2str(pure_decimal) 147 | else: 148 | sign: str = "负" if sign else "" 149 | number: str = num2str(number) 150 | result = f"{sign}{number}" 151 | return result 152 | 153 | 154 | # 范围表达式 155 | # match.group(1) and match.group(8) are copy from RE_NUMBER 156 | 157 | RE_RANGE = re.compile( 158 | r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))[-~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))') 159 | 160 | 161 | def replace_range(match) -> str: 162 | """ 163 | Args: 164 | match (re.Match) 165 | Returns: 166 | str 167 | """ 168 | first, second = match.group(1), match.group(8) 169 | first = RE_NUMBER.sub(replace_number, first) 170 | second = RE_NUMBER.sub(replace_number, second) 171 | result = f"{first}到{second}" 172 | return result 173 | 174 | 175 | def _get_value(value_string: str, use_zero: bool=True) -> List[str]: 176 | stripped = value_string.lstrip('0') 177 | if len(stripped) == 0: 178 | return [] 179 | elif len(stripped) == 1: 180 | if use_zero and len(stripped) < len(value_string): 181 | return [DIGITS['0'], DIGITS[stripped]] 182 | else: 183 | return [DIGITS[stripped]] 184 | else: 185 | largest_unit = next( 186 | power for power in reversed(UNITS.keys()) if power < len(stripped)) 187 | first_part = value_string[:-largest_unit] 188 | second_part = value_string[-largest_unit:] 189 | return _get_value(first_part) + [UNITS[largest_unit]] + _get_value( 190 | second_part) 191 | 192 | 193 | def verbalize_cardinal(value_string: str) -> str: 194 | if not value_string: 195 | return '' 196 | 197 | # 000 -> '零' , 0 -> '零' 198 | value_string = value_string.lstrip('0') 199 | if len(value_string) == 0: 200 | return DIGITS['0'] 201 | 202 | result_symbols = _get_value(value_string) 203 | # verbalized number starting with '一十*' is abbreviated as `十*` 204 | if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[ 205 | '1'] and result_symbols[1] == UNITS[1]: 206 | result_symbols = result_symbols[1:] 207 | return ''.join(result_symbols) 208 | 209 | 210 | def verbalize_digit(value_string: str, alt_one=False) -> str: 211 | result_symbols = [DIGITS[digit] for digit in value_string] 212 | result = ''.join(result_symbols) 213 | if alt_one: 214 | result = result.replace("一", "幺") 215 | return result 216 | 217 | 218 | def num2str(value_string: str) -> str: 219 | integer_decimal = value_string.split('.') 220 | if len(integer_decimal) == 1: 221 | integer = integer_decimal[0] 222 | decimal = '' 223 | elif len(integer_decimal) == 2: 224 | integer, decimal = integer_decimal 225 | else: 226 | raise ValueError( 227 | f"The value string: '${value_string}' has more than one point in it." 
228 | ) 229 | 230 | result = verbalize_cardinal(integer) 231 | 232 | decimal = decimal.rstrip('0') 233 | if decimal: 234 | # '.22' is verbalized as '零点二二' 235 | # '3.20' is verbalized as '三点二 236 | result = result if result else "零" 237 | result += '点' + verbalize_digit(decimal) 238 | return result 239 | -------------------------------------------------------------------------------- /GPT_SoVITS/text/zh_normalization/phonecode.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import re 15 | 16 | from .num import verbalize_digit 17 | 18 | # 规范化固话/手机号码 19 | # 手机 20 | # http://www.jihaoba.com/news/show/13680 21 | # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 22 | # 联通:130、131、132、156、155、186、185、176 23 | # 电信:133、153、189、180、181、177 24 | RE_MOBILE_PHONE = re.compile( 25 | r"(? str: 34 | if mobile: 35 | sp_parts = phone_string.strip('+').split() 36 | result = ','.join( 37 | [verbalize_digit(part, alt_one=True) for part in sp_parts]) 38 | return result 39 | else: 40 | sil_parts = phone_string.split('-') 41 | result = ','.join( 42 | [verbalize_digit(part, alt_one=True) for part in sil_parts]) 43 | return result 44 | 45 | 46 | def replace_phone(match) -> str: 47 | """ 48 | Args: 49 | match (re.Match) 50 | Returns: 51 | str 52 | """ 53 | return phone2str(match.group(0), mobile=False) 54 | 55 | 56 | def replace_mobile(match) -> str: 57 | """ 58 | Args: 59 | match (re.Match) 60 | Returns: 61 | str 62 | """ 63 | return phone2str(match.group(0)) 64 | -------------------------------------------------------------------------------- /GPT_SoVITS/text/zh_normalization/quantifier.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
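# Illustrative usage sketch (outputs are indicative, following the README examples above):
#   RE_TEMPERATURE.sub(replace_temperature, "今天的最低气温达到-10°C")
#       -> "今天的最低气温达到零下十度"
#   replace_measure("他跑了10km")
#       -> "他跑了10千米"   (the remaining digits are verbalized later by the num.py rules)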
14 | import re 15 | 16 | from .num import num2str 17 | 18 | # 温度表达式,温度会影响负号的读法 19 | # -3°C 零下三度 20 | RE_TEMPERATURE = re.compile(r'(-?)(\d+(\.\d+)?)(°C|℃|度|摄氏度)') 21 | measure_dict = { 22 | "cm2": "平方厘米", 23 | "cm²": "平方厘米", 24 | "cm3": "立方厘米", 25 | "cm³": "立方厘米", 26 | "cm": "厘米", 27 | "db": "分贝", 28 | "ds": "毫秒", 29 | "kg": "千克", 30 | "km": "千米", 31 | "m2": "平方米", 32 | "m²": "平方米", 33 | "m³": "立方米", 34 | "m3": "立方米", 35 | "ml": "毫升", 36 | "m": "米", 37 | "mm": "毫米", 38 | "s": "秒" 39 | } 40 | 41 | 42 | def replace_temperature(match) -> str: 43 | """ 44 | Args: 45 | match (re.Match) 46 | Returns: 47 | str 48 | """ 49 | sign = match.group(1) 50 | temperature = match.group(2) 51 | unit = match.group(3) 52 | sign: str = "零下" if sign else "" 53 | temperature: str = num2str(temperature) 54 | unit: str = "摄氏度" if unit == "摄氏度" else "度" 55 | result = f"{sign}{temperature}{unit}" 56 | return result 57 | 58 | 59 | def replace_measure(sentence) -> str: 60 | for q_notation in measure_dict: 61 | if q_notation in sentence: 62 | sentence = sentence.replace(q_notation, measure_dict[q_notation]) 63 | return sentence 64 | -------------------------------------------------------------------------------- /GPT_SoVITS/text/zh_normalization/text_normlization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
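# Illustrative end-to-end usage (a sketch; the inputs and outputs follow the examples
# in zh_normalization/README.md above):
#   tn = TextNormalizer()
#   tn.normalize("明天有62%的概率降雨")    -> ["明天有百分之六十二的概率降雨"]
#   tn.normalize("等会请在12:05请通知我")  -> ["等会请在十二点零五分请通知我"]
# normalize() splits the input on sentence punctuation and returns one normalized
# string per sentence.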
14 | import re 15 | from typing import List 16 | 17 | from .char_convert import tranditional_to_simplified 18 | from .chronology import RE_DATE 19 | from .chronology import RE_DATE2 20 | from .chronology import RE_TIME 21 | from .chronology import RE_TIME_RANGE 22 | from .chronology import replace_date 23 | from .chronology import replace_date2 24 | from .chronology import replace_time 25 | from .constants import F2H_ASCII_LETTERS 26 | from .constants import F2H_DIGITS 27 | from .constants import F2H_SPACE 28 | from .num import RE_DECIMAL_NUM 29 | from .num import RE_DEFAULT_NUM 30 | from .num import RE_FRAC 31 | from .num import RE_INTEGER 32 | from .num import RE_NUMBER 33 | from .num import RE_PERCENTAGE 34 | from .num import RE_POSITIVE_QUANTIFIERS 35 | from .num import RE_RANGE 36 | from .num import replace_default_num 37 | from .num import replace_frac 38 | from .num import replace_negative_num 39 | from .num import replace_number 40 | from .num import replace_percentage 41 | from .num import replace_positive_quantifier 42 | from .num import replace_range 43 | from .phonecode import RE_MOBILE_PHONE 44 | from .phonecode import RE_NATIONAL_UNIFORM_NUMBER 45 | from .phonecode import RE_TELEPHONE 46 | from .phonecode import replace_mobile 47 | from .phonecode import replace_phone 48 | from .quantifier import RE_TEMPERATURE 49 | from .quantifier import replace_measure 50 | from .quantifier import replace_temperature 51 | 52 | 53 | class TextNormalizer(): 54 | def __init__(self): 55 | self.SENTENCE_SPLITOR = re.compile(r'([:、,;。?!,;?!][”’]?)') 56 | 57 | def _split(self, text: str, lang="zh") -> List[str]: 58 | """Split long text into sentences with sentence-splitting punctuations. 59 | Args: 60 | text (str): The input text. 61 | Returns: 62 | List[str]: Sentences. 
63 | """ 64 | # Only for pure Chinese here 65 | if lang == "zh": 66 | text = text.replace(" ", "") 67 | # 过滤掉特殊字符 68 | text = re.sub(r'[——《》【】<=>{}()()#&@“”^_|…\\]', '', text) 69 | text = self.SENTENCE_SPLITOR.sub(r'\1\n', text) 70 | text = text.strip() 71 | sentences = [sentence.strip() for sentence in re.split(r'\n+', text)] 72 | return sentences 73 | 74 | def _post_replace(self, sentence: str) -> str: 75 | sentence = sentence.replace('/', '每') 76 | sentence = sentence.replace('~', '至') 77 | sentence = sentence.replace('~', '至') 78 | sentence = sentence.replace('①', '一') 79 | sentence = sentence.replace('②', '二') 80 | sentence = sentence.replace('③', '三') 81 | sentence = sentence.replace('④', '四') 82 | sentence = sentence.replace('⑤', '五') 83 | sentence = sentence.replace('⑥', '六') 84 | sentence = sentence.replace('⑦', '七') 85 | sentence = sentence.replace('⑧', '八') 86 | sentence = sentence.replace('⑨', '九') 87 | sentence = sentence.replace('⑩', '十') 88 | sentence = sentence.replace('α', '阿尔法') 89 | sentence = sentence.replace('β', '贝塔') 90 | sentence = sentence.replace('γ', '伽玛').replace('Γ', '伽玛') 91 | sentence = sentence.replace('δ', '德尔塔').replace('Δ', '德尔塔') 92 | sentence = sentence.replace('ε', '艾普西龙') 93 | sentence = sentence.replace('ζ', '捷塔') 94 | sentence = sentence.replace('η', '依塔') 95 | sentence = sentence.replace('θ', '西塔').replace('Θ', '西塔') 96 | sentence = sentence.replace('ι', '艾欧塔') 97 | sentence = sentence.replace('κ', '喀帕') 98 | sentence = sentence.replace('λ', '拉姆达').replace('Λ', '拉姆达') 99 | sentence = sentence.replace('μ', '缪') 100 | sentence = sentence.replace('ν', '拗') 101 | sentence = sentence.replace('ξ', '克西').replace('Ξ', '克西') 102 | sentence = sentence.replace('ο', '欧米克伦') 103 | sentence = sentence.replace('π', '派').replace('Π', '派') 104 | sentence = sentence.replace('ρ', '肉') 105 | sentence = sentence.replace('ς', '西格玛').replace('Σ', '西格玛').replace( 106 | 'σ', '西格玛') 107 | sentence = sentence.replace('τ', '套') 108 | sentence = sentence.replace('υ', '宇普西龙') 109 | sentence = sentence.replace('φ', '服艾').replace('Φ', '服艾') 110 | sentence = sentence.replace('χ', '器') 111 | sentence = sentence.replace('ψ', '普赛').replace('Ψ', '普赛') 112 | sentence = sentence.replace('ω', '欧米伽').replace('Ω', '欧米伽') 113 | # re filter special characters, have one more character "-" than line 68 114 | sentence = re.sub(r'[-——《》【】<=>{}()()#&@“”^_|…\\]', '', sentence) 115 | return sentence 116 | 117 | def normalize_sentence(self, sentence: str) -> str: 118 | # basic character conversions 119 | sentence = tranditional_to_simplified(sentence) 120 | sentence = sentence.translate(F2H_ASCII_LETTERS).translate( 121 | F2H_DIGITS).translate(F2H_SPACE) 122 | 123 | # number related NSW verbalization 124 | sentence = RE_DATE.sub(replace_date, sentence) 125 | sentence = RE_DATE2.sub(replace_date2, sentence) 126 | 127 | # range first 128 | sentence = RE_TIME_RANGE.sub(replace_time, sentence) 129 | sentence = RE_TIME.sub(replace_time, sentence) 130 | 131 | sentence = RE_TEMPERATURE.sub(replace_temperature, sentence) 132 | sentence = replace_measure(sentence) 133 | sentence = RE_FRAC.sub(replace_frac, sentence) 134 | sentence = RE_PERCENTAGE.sub(replace_percentage, sentence) 135 | sentence = RE_MOBILE_PHONE.sub(replace_mobile, sentence) 136 | 137 | sentence = RE_TELEPHONE.sub(replace_phone, sentence) 138 | sentence = RE_NATIONAL_UNIFORM_NUMBER.sub(replace_phone, sentence) 139 | 140 | sentence = RE_RANGE.sub(replace_range, sentence) 141 | sentence = RE_INTEGER.sub(replace_negative_num, sentence) 142 | 
sentence = RE_DECIMAL_NUM.sub(replace_number, sentence) 143 | sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier, 144 | sentence) 145 | sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence) 146 | sentence = RE_NUMBER.sub(replace_number, sentence) 147 | sentence = self._post_replace(sentence) 148 | 149 | return sentence 150 | 151 | def normalize(self, text: str) -> List[str]: 152 | sentences = self._split(text) 153 | sentences = [self.normalize_sentence(sent) for sent in sentences] 154 | return sentences 155 | -------------------------------------------------------------------------------- /GPT_SoVITS/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import traceback 9 | 10 | import librosa 11 | import numpy as np 12 | from scipy.io.wavfile import read 13 | import torch 14 | import logging 15 | 16 | logging.getLogger("numba").setLevel(logging.ERROR) 17 | logging.getLogger("matplotlib").setLevel(logging.ERROR) 18 | 19 | MATPLOTLIB_FLAG = False 20 | 21 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 22 | logger = logging 23 | 24 | 25 | def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False): 26 | assert os.path.isfile(checkpoint_path) 27 | checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") 28 | iteration = checkpoint_dict["iteration"] 29 | learning_rate = checkpoint_dict["learning_rate"] 30 | if ( 31 | optimizer is not None 32 | and not skip_optimizer 33 | and checkpoint_dict["optimizer"] is not None 34 | ): 35 | optimizer.load_state_dict(checkpoint_dict["optimizer"]) 36 | saved_state_dict = checkpoint_dict["model"] 37 | if hasattr(model, "module"): 38 | state_dict = model.module.state_dict() 39 | else: 40 | state_dict = model.state_dict() 41 | new_state_dict = {} 42 | for k, v in state_dict.items(): 43 | try: 44 | # assert "quantizer" not in k 45 | # print("load", k) 46 | new_state_dict[k] = saved_state_dict[k] 47 | assert saved_state_dict[k].shape == v.shape, ( 48 | saved_state_dict[k].shape, 49 | v.shape, 50 | ) 51 | except: 52 | traceback.print_exc() 53 | print( 54 | "error, %s is not in the checkpoint" % k 55 | ) # shape不对也会,比如text_embedding当cleaner修改时 56 | new_state_dict[k] = v 57 | if hasattr(model, "module"): 58 | model.module.load_state_dict(new_state_dict) 59 | else: 60 | model.load_state_dict(new_state_dict) 61 | print("load ") 62 | logger.info( 63 | "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration) 64 | ) 65 | return model, optimizer, learning_rate, iteration 66 | 67 | from time import time as ttime 68 | import shutil 69 | def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path 70 | dir=os.path.dirname(path) 71 | name=os.path.basename(path) 72 | tmp_path="%s.pth"%(ttime()) 73 | torch.save(fea,tmp_path) 74 | shutil.move(tmp_path,"%s/%s"%(dir,name)) 75 | 76 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 77 | logger.info( 78 | "Saving model and optimizer state at iteration {} to {}".format( 79 | iteration, checkpoint_path 80 | ) 81 | ) 82 | if hasattr(model, "module"): 83 | state_dict = model.module.state_dict() 84 | else: 85 | state_dict = model.state_dict() 86 | # torch.save( 87 | my_save( 88 | { 89 | "model": state_dict, 90 | "iteration": iteration, 91 | "optimizer": optimizer.state_dict(), 92 | "learning_rate": learning_rate, 93 | }, 94 | checkpoint_path, 
95 | ) 96 | 97 | 98 | def summarize( 99 | writer, 100 | global_step, 101 | scalars={}, 102 | histograms={}, 103 | images={}, 104 | audios={}, 105 | audio_sampling_rate=22050, 106 | ): 107 | for k, v in scalars.items(): 108 | writer.add_scalar(k, v, global_step) 109 | for k, v in histograms.items(): 110 | writer.add_histogram(k, v, global_step) 111 | for k, v in images.items(): 112 | writer.add_image(k, v, global_step, dataformats="HWC") 113 | for k, v in audios.items(): 114 | writer.add_audio(k, v, global_step, audio_sampling_rate) 115 | 116 | 117 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 118 | f_list = glob.glob(os.path.join(dir_path, regex)) 119 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 120 | x = f_list[-1] 121 | print(x) 122 | return x 123 | 124 | 125 | def plot_spectrogram_to_numpy(spectrogram): 126 | global MATPLOTLIB_FLAG 127 | if not MATPLOTLIB_FLAG: 128 | import matplotlib 129 | 130 | matplotlib.use("Agg") 131 | MATPLOTLIB_FLAG = True 132 | mpl_logger = logging.getLogger("matplotlib") 133 | mpl_logger.setLevel(logging.WARNING) 134 | import matplotlib.pylab as plt 135 | import numpy as np 136 | 137 | fig, ax = plt.subplots(figsize=(10, 2)) 138 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 139 | plt.colorbar(im, ax=ax) 140 | plt.xlabel("Frames") 141 | plt.ylabel("Channels") 142 | plt.tight_layout() 143 | 144 | fig.canvas.draw() 145 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 146 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 147 | plt.close() 148 | return data 149 | 150 | 151 | def plot_alignment_to_numpy(alignment, info=None): 152 | global MATPLOTLIB_FLAG 153 | if not MATPLOTLIB_FLAG: 154 | import matplotlib 155 | 156 | matplotlib.use("Agg") 157 | MATPLOTLIB_FLAG = True 158 | mpl_logger = logging.getLogger("matplotlib") 159 | mpl_logger.setLevel(logging.WARNING) 160 | import matplotlib.pylab as plt 161 | import numpy as np 162 | 163 | fig, ax = plt.subplots(figsize=(6, 4)) 164 | im = ax.imshow( 165 | alignment.transpose(), aspect="auto", origin="lower", interpolation="none" 166 | ) 167 | fig.colorbar(im, ax=ax) 168 | xlabel = "Decoder timestep" 169 | if info is not None: 170 | xlabel += "\n\n" + info 171 | plt.xlabel(xlabel) 172 | plt.ylabel("Encoder timestep") 173 | plt.tight_layout() 174 | 175 | fig.canvas.draw() 176 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 177 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 178 | plt.close() 179 | return data 180 | 181 | 182 | def load_wav_to_torch(full_path): 183 | data, sampling_rate = librosa.load(full_path, sr=None) 184 | return torch.FloatTensor(data), sampling_rate 185 | 186 | 187 | def load_filepaths_and_text(filename, split="|"): 188 | with open(filename, encoding="utf-8") as f: 189 | filepaths_and_text = [line.strip().split(split) for line in f] 190 | return filepaths_and_text 191 | 192 | 193 | def get_hparams(init=True, stage=1): 194 | parser = argparse.ArgumentParser() 195 | parser.add_argument( 196 | "-c", 197 | "--config", 198 | type=str, 199 | default="./configs/s2.json", 200 | help="JSON file for configuration", 201 | ) 202 | parser.add_argument( 203 | "-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir" 204 | ) 205 | parser.add_argument( 206 | "-rs", 207 | "--resume_step", 208 | type=int, 209 | required=False, 210 | default=None, 211 | help="resume step", 212 | ) 213 | # parser.add_argument('-e', '--exp_dir', type=str, 
required=False,default=None,help='experiment directory') 214 | # parser.add_argument('-g', '--pretrained_s2G', type=str, required=False,default=None,help='pretrained sovits generator weights') 215 | # parser.add_argument('-d', '--pretrained_s2D', type=str, required=False,default=None,help='pretrained sovits discriminator weights') 216 | 217 | args = parser.parse_args() 218 | 219 | config_path = args.config 220 | with open(config_path, "r") as f: 221 | data = f.read() 222 | config = json.loads(data) 223 | 224 | hparams = HParams(**config) 225 | hparams.pretrain = args.pretrain 226 | hparams.resume_step = args.resume_step 227 | # hparams.data.exp_dir = args.exp_dir 228 | if stage == 1: 229 | model_dir = hparams.s1_ckpt_dir 230 | else: 231 | model_dir = hparams.s2_ckpt_dir 232 | config_save_path = os.path.join(model_dir, "config.json") 233 | 234 | if not os.path.exists(model_dir): 235 | os.makedirs(model_dir) 236 | 237 | with open(config_save_path, "w") as f: 238 | f.write(data) 239 | return hparams 240 | 241 | 242 | def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True): 243 | """Freeing up space by deleting saved ckpts 244 | 245 | Arguments: 246 | path_to_models -- Path to the model directory 247 | n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth 248 | sort_by_time -- True -> chronologically delete ckpts 249 | False -> lexicographically delete ckpts 250 | """ 251 | import re 252 | 253 | ckpts_files = [ 254 | f 255 | for f in os.listdir(path_to_models) 256 | if os.path.isfile(os.path.join(path_to_models, f)) 257 | ] 258 | name_key = lambda _f: int(re.compile(r"._(\d+)\.pth").match(_f).group(1)) 259 | time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)) 260 | sort_key = time_key if sort_by_time else name_key 261 | x_sorted = lambda _x: sorted( 262 | [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")], 263 | key=sort_key, 264 | ) 265 | to_del = [ 266 | os.path.join(path_to_models, fn) 267 | for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep]) 268 | ] 269 | del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}") 270 | del_routine = lambda x: [os.remove(x), del_info(x)] 271 | rs = [del_routine(fn) for fn in to_del] 272 | 273 | 274 | def get_hparams_from_dir(model_dir): 275 | config_save_path = os.path.join(model_dir, "config.json") 276 | with open(config_save_path, "r") as f: 277 | data = f.read() 278 | config = json.loads(data) 279 | 280 | hparams = HParams(**config) 281 | hparams.model_dir = model_dir 282 | return hparams 283 | 284 | 285 | def get_hparams_from_file(config_path): 286 | with open(config_path, "r") as f: 287 | data = f.read() 288 | config = json.loads(data) 289 | 290 | hparams = HParams(**config) 291 | return hparams 292 | 293 | 294 | def check_git_hash(model_dir): 295 | source_dir = os.path.dirname(os.path.realpath(__file__)) 296 | if not os.path.exists(os.path.join(source_dir, ".git")): 297 | logger.warning( 298 | "{} is not a git repository, therefore hash value comparison will be ignored.".format( 299 | source_dir 300 | ) 301 | ) 302 | return 303 | 304 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 305 | 306 | path = os.path.join(model_dir, "githash") 307 | if os.path.exists(path): 308 | saved_hash = open(path).read() 309 | if saved_hash != cur_hash: 310 | logger.warning( 311 | "git hash values are different.
{}(saved) != {}(current)".format( 312 | saved_hash[:8], cur_hash[:8] 313 | ) 314 | ) 315 | else: 316 | open(path, "w").write(cur_hash) 317 | 318 | 319 | def get_logger(model_dir, filename="train.log"): 320 | global logger 321 | logger = logging.getLogger(os.path.basename(model_dir)) 322 | logger.setLevel(logging.DEBUG) 323 | 324 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 325 | if not os.path.exists(model_dir): 326 | os.makedirs(model_dir) 327 | h = logging.FileHandler(os.path.join(model_dir, filename)) 328 | h.setLevel(logging.DEBUG) 329 | h.setFormatter(formatter) 330 | logger.addHandler(h) 331 | return logger 332 | 333 | 334 | class HParams: 335 | def __init__(self, **kwargs): 336 | for k, v in kwargs.items(): 337 | if type(v) == dict: 338 | v = HParams(**v) 339 | self[k] = v 340 | 341 | def keys(self): 342 | return self.__dict__.keys() 343 | 344 | def items(self): 345 | return self.__dict__.items() 346 | 347 | def values(self): 348 | return self.__dict__.values() 349 | 350 | def __len__(self): 351 | return len(self.__dict__) 352 | 353 | def __getitem__(self, key): 354 | return getattr(self, key) 355 | 356 | def __setitem__(self, key, value): 357 | return setattr(self, key, value) 358 | 359 | def __contains__(self, key): 360 | return key in self.__dict__ 361 | 362 | def __repr__(self): 363 | return self.__dict__.__repr__() 364 | 365 | 366 | if __name__ == "__main__": 367 | print( 368 | load_wav_to_torch( 369 | "/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac" 370 | ) 371 | ) 372 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## 👏 Project Description 3 | 4 | The original [GPT_SoVITS](https://github.com/RVC-Boss/GPT-SoVITS) relies heavily on its Gradio-based WebUI for trying out the models and serving inference. To make GPT_SoVITS inference easier to run, this project extracts and exposes the inference code and supports one-click inference deployment. 5 | 6 | ## 🔥 Model List 7 | 8 | | Model Name | Download | Voice Style | Language | 9 | | :----: | :----: | :----: | :----: | 10 | | TTS-GPT_SoVITS-sunshine_girl | [🤗]() / [🤖](https://modelscope.cn/models/X-D-Lab/TTS-GPT_SoVITS-sunshine_girl/summary) | Sunshine girl | zh | 11 | | TTS-GPT_SoVITS-heartful_sister | [🤗]() / [🤖](https://modelscope.cn/models/X-D-Lab/TTS-GPT_SoVITS-heartful_sister/summary) | Warm, intellectual elder sister | zh | 12 | 13 | - Pretrained models 14 | 15 | | Model Name | Download | 16 | | :----: | :----: | 17 | | GPT-SoVITS | [🤗]() / [🤖](https://modelscope.cn/models/X-D-Lab/TTS-GPT_SoVITS-pretrained_models/summary) | 18 | 19 | 20 | ## ⚒️ Installation 21 | 22 | Python >=3.9, <=3.10 is recommended. 23 | 24 | ``` 25 | conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=11.8 -c pytorch -c nvidia 26 | 27 | git clone https://github.com/X-D-Lab/GPT_SoVITS_Inference.git 28 | cd GPT_SoVITS_Inference 29 | pip install -r requirements.txt 30 | ``` 31 | If you are on Windows, please download [ffmpeg.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffmpeg.exe) and [ffprobe.exe](https://huggingface.co/lj1995/VoiceConversionWebUI/blob/main/ffprobe.exe) and place them in the root directory of this project. 32 | 33 | 34 | ## 😇 How to Use 35 | 36 | See [example.py](./example.py) for details; a sketch of a call with the optional synthesis parameters is included below, after the Contributors note. 37 | 38 | ```Python 39 | 40 | import os 41 | import sys 42 | 43 | project_root = os.path.abspath('.') 44 | sys.path.append(project_root) 45 | 46 | 47 | from get_tts_wav import GPT_SoVITS_TTS_inference 48 | 49 | text = """我是MindChat漫谈心理大模型""" 50 | 51 | inference = GPT_SoVITS_TTS_inference(prompt_language='zh', base_model_id='X-D-Lab/TTS-GPT_SoVITS-pretrained_models', audio_model_id='X-D-Lab/TTS-GPT_SoVITS-sunshine_girl') 52 | 53 | inference.get_tts_wav(text=text, wav_save_path="./temp/output1.wav") 54 | 55 | ``` 56 | ## 👏 Contributors 57 | This project is still at a very early stage; developers are welcome to join!
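As referenced in the How to Use section above, here is a minimal sketch of a call that also passes the optional synthesis parameters documented in [example.py](./example.py) (`text_language`, `how_to_cut`, `top_k`, `top_p`, `temperature`, `ref_free`). The keyword names follow the example.py docstring; the specific values, the chosen voice model, and the output path are illustrative assumptions, not tuned defaults.

```Python
import os
import sys

project_root = os.path.abspath('.')
sys.path.append(project_root)

from get_tts_wav import GPT_SoVITS_TTS_inference

# Same setup as in the How to Use example above; the voice model here is
# the other ModelScope checkpoint listed in the Model List (an assumption).
inference = GPT_SoVITS_TTS_inference(
    prompt_language='zh',
    base_model_id='X-D-Lab/TTS-GPT_SoVITS-pretrained_models',
    audio_model_id='X-D-Lab/TTS-GPT_SoVITS-heartful_sister',
)

# Optional synthesis controls as described in example.py; values are
# illustrative assumptions, not recommended settings.
inference.get_tts_wav(
    text="我是MindChat漫谈心理大模型",
    text_language="zh",               # language of the synthesized speech
    how_to_cut="按标点符号切",         # split on punctuation: slowest, most accurate
    top_k=20,
    top_p=0.6,
    temperature=0.6,                  # sampling controls, see GPT-SoVITS PR #457
    ref_free=False,                   # keep the reference-audio transcript enabled
    wav_save_path="./temp/output2.wav",
)
```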
58 | 59 | 60 | 61 | 62 | 63 | ### 🙇‍ Acknowledgements 64 | 65 | This project is built on [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS); many thanks to its authors for their open-source contribution. 66 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import sys,os 2 | # Root directory of gpt_sovits_tts; executed whenever a module is imported from gpt_sovits_tts so that import paths stay consistent 3 | root_dir = os.path.dirname(__file__) 4 | sys.path.append(root_dir) 5 | sys.path.append("%s/GPT_SoVITS"%(root_dir)) -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | # get_tts_wav() can be called from anywhere, across directories 2 | """ 3 | # === Reference: language options for the inference text === 4 | # For the prompt_language and text_language parameters in the config and in get_tts_wav() 5 | "all_zh" # treat everything as Chinese 6 | "en" # treat everything as English ####### unchanged 7 | "all_ja" # treat everything as Japanese 8 | "zh" # mixed Chinese/English #### unchanged 9 | "ja" # mixed Japanese/English #### unchanged 10 | "auto" # multilingual mix; the language is detected per split segment 11 | 12 | """ 13 | """ 14 | def get_tts_wav( 15 | text: str, # Text to synthesize. get_tts_wav() splits the text on punctuation internally. 16 | text_language: str = "zh", # Language of the synthesized speech 17 | wav_save_path: str = "temp/output.wav" # Path and file name of the result; a complete wav file is produced 18 | == Other optional parameters == 19 | how_to_cut: str = "凑四句一切", # How to split the inference text; there are 5 options. 20 | # "凑四句一切" (every four sentences) and "按标点符号切" (split on punctuation) are recommended; splitting on punctuation is the slowest but most accurate 21 | # "凑四句一切","凑50字一切","按中文句号。切","按英文句号.切","按标点符号切" 22 | top_k: int = 20, 23 | top_p: float = 0.6, 24 | temperature: float = 0.6, 25 | # For the three parameters above, see https://github.com/RVC-Boss/GPT-SoVITS/pull/457 26 | ref_free: bool = False # Run inference without the transcript of the reference audio. Off by default 27 | ) -> None 28 | """ 29 | import os 30 | import sys 31 | 32 | project_root = os.path.abspath('.') 33 | sys.path.append(project_root) 34 | 35 | # from gpt_sovits_tts.get_tts_wav import GPT_SoVITS_TTS_inference 36 | from get_tts_wav import GPT_SoVITS_TTS_inference 37 | 38 | text = """我是MindChat漫谈心理大模型""" 39 | 40 | """ 41 | # Voice models (audio_model_id) currently [2024-02-27] available on ModelScope 42 | X-D-Lab/TTS-GPT_SoVITS-sunshine_girl 43 | X-D-Lab/TTS-GPT_SoVITS-heartful_sister 44 | """ 45 | 46 | inference = GPT_SoVITS_TTS_inference(prompt_language='zh', base_model_id='X-D-Lab/TTS-GPT_SoVITS-pretrained_models', audio_model_id='X-D-Lab/TTS-GPT_SoVITS-sunshine_girl') 47 | 48 | inference.get_tts_wav(text=text, wav_save_path="./temp/output1.wav") 49 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda install -c conda-forge gcc 3 | conda install -c conda-forge gxx 4 | conda install ffmpeg cmake 5 | conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=11.8 -c pytorch -c nvidia 6 | pip install -r requirements.txt 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | tensorboard 4 | librosa==0.9.2 5 | numba==0.56.4 6 | pytorch-lightning 7 | ffmpeg-python 8 | onnxruntime 9 | tqdm 10 | funasr 11 | cn2an 12 | pypinyin 13 | pyopenjtalk 14 | g2p_en 15 | torchaudio 16 | modelscope 17 | sentencepiece 18 | transformers 19 | chardet 20 | PyYAML 21 | psutil 22 | jieba_fast 23 | LangSegment -------------------------------------------------------------------------------- /requirements_win.txt: -------------------------------------------------------------------------------- 1 | # 20240205 python 3.9 2 | # check on win 3 | cn2an 4 | einops 5 | ffmpeg_python==0.2.0 6 | g2p_en 7 | gruut 8 | jieba_fast 9 |
librosa==0.9.2 10 | numpy==1.23.5 11 | pyopenjtalk 12 | pypinyin 13 | pytorch_lightning==2.1.4 14 | PyYAML 15 | regex 16 | Requests 17 | scipy 18 | soundfile 19 | tools 20 | --find-links https://download.pytorch.org/whl/torch/ 21 | torch==2.1.1+cu118 22 | torchmetrics==1.3.0.post0 23 | tqdm 24 | transformers 25 | typeguard 26 | LangSegment 27 | modelscope 28 | -------------------------------------------------------------------------------- /temp/nltk_data/corpora/cmudict.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-D-Lab/GPT_SoVITS_Inference/aab032273341da78c44aabf4fb1c132e65f6f29c/temp/nltk_data/corpora/cmudict.zip -------------------------------------------------------------------------------- /temp/nltk_data/corpora/cmudict/README: -------------------------------------------------------------------------------- 1 | The Carnegie Mellon Pronouncing Dictionary [cmudict.0.7a] 2 | 3 | ftp://ftp.cs.cmu.edu/project/speech/dict/ 4 | https://cmusphinx.svn.sourceforge.net/svnroot/cmusphinx/trunk/cmudict/cmudict.0.7a 5 | 6 | Copyright (C) 1993-2008 Carnegie Mellon University. All rights reserved. 7 | 8 | File Format: Each line consists of an uppercased word, 9 | a counter (for alternative pronunciations), and a transcription. 10 | Vowels are marked for stress (1=primary, 2=secondary, 0=no stress). 11 | E.g.: NATURAL 1 N AE1 CH ER0 AH0 L 12 | 13 | The dictionary contains 127069 entries. Of these, 119400 words are assigned 14 | a unique pronunciation, 6830 words have two pronunciations, and 839 words have 15 | three or more pronunciations. Many of these are fast-speech variants. 16 | 17 | Phonemes: There are 39 phonemes, as shown below: 18 | 19 | Phoneme Example Translation Phoneme Example Translation 20 | ------- ------- ----------- ------- ------- ----------- 21 | AA odd AA D AE at AE T 22 | AH hut HH AH T AO ought AO T 23 | AW cow K AW AY hide HH AY D 24 | B be B IY CH cheese CH IY Z 25 | D dee D IY DH thee DH IY 26 | EH Ed EH D ER hurt HH ER T 27 | EY ate EY T F fee F IY 28 | G green G R IY N HH he HH IY 29 | IH it IH T IY eat IY T 30 | JH gee JH IY K key K IY 31 | L lee L IY M me M IY 32 | N knee N IY NG ping P IH NG 33 | OW oat OW T OY toy T OY 34 | P pee P IY R read R IY D 35 | S sea S IY SH she SH IY 36 | T tea T IY TH theta TH EY T AH 37 | UH hood HH UH D UW two T UW 38 | V vee V IY W we W IY 39 | Y yield Y IY L D Z zee Z IY 40 | ZH seizure S IY ZH ER 41 | 42 | (For NLTK, entries have been sorted so that, e.g. FIRE 1 and FIRE 2 43 | are contiguous, and not separated by FIRE'S 1.) 44 | 45 | Redistribution and use in source and binary forms, with or without 46 | modification, are permitted provided that the following conditions 47 | are met: 48 | 49 | 1. Redistributions of source code must retain the above copyright 50 | notice, this list of conditions and the following disclaimer. 51 | The contents of this file are deemed to be source code. 52 | 53 | 2. Redistributions in binary form must reproduce the above copyright 54 | notice, this list of conditions and the following disclaimer in 55 | the documentation and/or other materials provided with the 56 | distribution. 57 | 58 | This work was supported in part by funding from the Defense Advanced 59 | Research Projects Agency, the Office of Naval Research and the National 60 | Science Foundation of the United States of America, and by member 61 | companies of the Carnegie Mellon Sphinx Speech Consortium. 
We acknowledge 62 | the contributions of many volunteers to the expansion and improvement of 63 | this dictionary. 64 | 65 | THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 66 | ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 67 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 68 | PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY 69 | NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 70 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 71 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 72 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 73 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 74 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 75 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 76 | 77 | -------------------------------------------------------------------------------- /temp/nltk_data/taggers/averaged_perceptron_tagger.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-D-Lab/GPT_SoVITS_Inference/aab032273341da78c44aabf4fb1c132e65f6f29c/temp/nltk_data/taggers/averaged_perceptron_tagger.zip -------------------------------------------------------------------------------- /temp/nltk_data/taggers/averaged_perceptron_tagger/averaged_perceptron_tagger.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-D-Lab/GPT_SoVITS_Inference/aab032273341da78c44aabf4fb1c132e65f6f29c/temp/nltk_data/taggers/averaged_perceptron_tagger/averaged_perceptron_tagger.pickle --------------------------------------------------------------------------------