├── .gitignore
├── AR
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── bucket_sampler.py
│   │   ├── data_module.py
│   │   ├── data_module_librilight_6k.py
│   │   ├── dataset.py
│   │   └── dataset_librilight_6k.py
│   ├── exps
│   │   ├── __init__.py
│   │   ├── beats
│   │   │   ├── BEATs.py
│   │   │   ├── README.md
│   │   │   ├── Tokenizers.py
│   │   │   ├── __init__.py
│   │   │   ├── backbone.py
│   │   │   ├── config.py
│   │   │   ├── modules.py
│   │   │   ├── ontology.json
│   │   │   └── quantizer.py
│   │   ├── get_beats_librilight.py
│   │   ├── get_phones.py
│   │   ├── get_phones_librilight.py
│   │   ├── get_txt_librilight.py
│   │   ├── split_train_val.py
│   │   ├── t2s.py
│   │   ├── test.py
│   │   ├── text.txt
│   │   ├── train.py
│   │   └── train_librilight_6k.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── t2s_lightning_module.py
│   │   ├── t2s_model.py
│   │   └── utils.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── activation.py
│   │   ├── embedding.py
│   │   ├── lr_schedulers.py
│   │   ├── optim.py
│   │   ├── scaling.py
│   │   └── transformer.py
│   ├── text_processing
│   │   ├── __init__.py
│   │   ├── phonemizer.py
│   │   └── symbols.py
│   └── utils
│       ├── __init__.py
│       ├── initialize.py
│       └── io.py
├── LICENSE
├── README.md
├── configs
│   ├── s1.yaml
│   └── s2.json
├── data_conf.py
├── extract_ssl_s2.py
├── extract_vq_s1.py
├── feature_extractor
│   ├── __init__.py
│   ├── cnhubert.py
│   └── whisper_enc.py
├── gen_filelist_s1.py
├── gen_filelist_s2.py
├── gen_phonemes.py
├── module
│   ├── __init__.py
│   ├── attentions.py
│   ├── commons.py
│   ├── core_vq.py
│   ├── data_utils.py
│   ├── losses.py
│   ├── mel_processing.py
│   ├── models.py
│   ├── modules.py
│   ├── mrte_model.py
│   ├── quantize.py
│   └── transforms.py
├── requirements.txt
├── resample.py
├── resources
│   └── structure.png
├── s1_infer.py
├── s1_train.py
├── s2_infer.py
├── s2_train.py
├── text
│   ├── __init__.py
│   ├── chinese.py
│   ├── cleaner.py
│   ├── cmudict.rep
│   ├── cmudict_cache.pickle
│   ├── english.py
│   ├── japanese.py
│   ├── opencpop-strict.txt
│   ├── symbols.py
│   └── tone_sandhi.py
└── utils.py

/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /AR/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innnky/ar-vits/aeca3848c6d34946b71348c9f59e7e008cd838ce/AR/__init__.py -------------------------------------------------------------------------------- /AR/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innnky/ar-vits/aeca3848c6d34946b71348c9f59e7e008cd838ce/AR/data/__init__.py -------------------------------------------------------------------------------- /AR/data/bucket_sampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/bucketsampler.py 2 | import itertools 3 | import math 4 | import random 5 | from random import shuffle 6 | from typing import Iterator 7 | from typing import Optional 8 | from typing import TypeVar 9 | 10 | import torch 11 | import torch.distributed as dist 12 | from torch.utils.data import Dataset 13 | from torch.utils.data import Sampler 14 | 15 | __all__ = [ 16 | "DistributedBucketSampler", 17 | ] 18 | 19 | T_co = TypeVar('T_co', covariant=True) 20 | 21 | 22 | class DistributedBucketSampler(Sampler[T_co]): 23 | r""" 24 | sort the dataset wrt. input length 25 | divide samples into buckets 26 | sort within buckets 27 | divide buckets into batches 28 | sort batches 29 | """ 30 | 31 | def __init__(self, 32 | dataset: Dataset, 33 | num_replicas: Optional[int]=None, 34 | rank: Optional[int]=None, 35 | shuffle: bool=True, 36 | seed: int=0, 37 | drop_last: bool=False, 38 | batch_size: int=32) -> None: 39 | if num_replicas is None: 40 | if not dist.is_available(): 41 | raise RuntimeError( 42 | "Requires distributed package to be available") 43 | num_replicas = dist.get_world_size() 44 | if rank is None: 45 | if not dist.is_available(): 46 | raise RuntimeError( 47 | "Requires distributed package to be available") 48 | rank = dist.get_rank() 49 | torch.cuda.set_device(rank) 50 | if rank >= num_replicas or rank < 0: 51 | raise ValueError("Invalid rank {}, rank should be in the interval" 52 | " [0, {}]".format(rank, num_replicas - 1)) 53 | self.dataset = dataset 54 | self.num_replicas = num_replicas 55 | self.rank = rank 56 | self.epoch = 0 57 | self.drop_last = drop_last 58 | # If the dataset length is evenly divisible by # of replicas, then there 59 | # is no need to drop any data, since the dataset will be split equally. 60 | if self.drop_last and len( 61 | self. 62 | dataset) % self.num_replicas != 0: # type: ignore[arg-type] 63 | # Split to nearest available length that is evenly divisible. 64 | # This is to ensure each rank receives the same amount of data when 65 | # using this Sampler. 
66 | self.num_samples = math.ceil( 67 | (len(self.dataset) - self.num_replicas) / 68 | self.num_replicas # type: ignore[arg-type] 69 | ) 70 | else: 71 | self.num_samples = math.ceil( 72 | len(self.dataset) / self.num_replicas) # type: ignore[arg-type] 73 | self.total_size = self.num_samples * self.num_replicas 74 | self.shuffle = shuffle 75 | self.seed = seed 76 | self.batch_size = batch_size 77 | self.id_with_length = self._get_sample_lengths() 78 | self.id_buckets = self.make_buckets(bucket_width=2.0) 79 | 80 | def _get_sample_lengths(self): 81 | id_with_lengths = [] 82 | for i in range(len(self.dataset)): 83 | id_with_lengths.append((i, self.dataset.get_sample_length(i))) 84 | id_with_lengths.sort(key=lambda x: x[1]) 85 | return id_with_lengths 86 | 87 | def make_buckets(self, bucket_width: float=2.0): 88 | buckets = [] 89 | cur = [] 90 | max_sec = bucket_width 91 | for id, sec in self.id_with_length: 92 | if sec < max_sec: 93 | cur.append(id) 94 | else: 95 | buckets.append(cur) 96 | cur = [id] 97 | max_sec += bucket_width 98 | if len(cur) > 0: 99 | buckets.append(cur) 100 | return buckets 101 | 102 | def __iter__(self) -> Iterator[T_co]: 103 | if self.shuffle: 104 | # deterministically shuffle based on epoch and seed 105 | g = torch.Generator() 106 | g.manual_seed(self.seed + self.epoch) 107 | random.seed(self.epoch + self.seed) 108 | shuffled_bucket = [] 109 | for buc in self.id_buckets: 110 | buc_copy = buc.copy() 111 | shuffle(buc_copy) 112 | shuffled_bucket.append(buc_copy) 113 | grouped_batch_size = self.batch_size * self.num_replicas 114 | shuffled_bucket = list(itertools.chain(*shuffled_bucket)) 115 | n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size)) 116 | batches = [ 117 | shuffled_bucket[b * grouped_batch_size:(b + 1) * 118 | grouped_batch_size] for b in range(n_batch) 119 | ] 120 | shuffle(batches) 121 | indices = list(itertools.chain(*batches)) 122 | else: 123 | # type: ignore[arg-type] 124 | indices = list(range(len(self.dataset))) 125 | 126 | if not self.drop_last: 127 | # add extra samples to make it evenly divisible 128 | padding_size = self.total_size - len(indices) 129 | if padding_size <= len(indices): 130 | indices += indices[:padding_size] 131 | else: 132 | indices += (indices * math.ceil(padding_size / 133 | len(indices)))[:padding_size] 134 | else: 135 | # remove tail of data to make it evenly divisible. 136 | indices = indices[:self.total_size] 137 | assert len(indices) == self.total_size 138 | 139 | # subsample 140 | indices = indices[self.rank:self.total_size:self.num_replicas] 141 | assert len(indices) == self.num_samples 142 | 143 | return iter(indices) 144 | 145 | def __len__(self) -> int: 146 | return self.num_samples 147 | 148 | def set_epoch(self, epoch: int) -> None: 149 | r""" 150 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas 151 | use a different random ordering for each epoch. Otherwise, the next iteration of this 152 | sampler will yield the same ordering. 153 | 154 | Args: 155 | epoch (int): Epoch number. 
156 | """ 157 | self.epoch = epoch 158 | -------------------------------------------------------------------------------- /AR/data/data_module.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py 2 | from pytorch_lightning import LightningDataModule 3 | from AR.data.bucket_sampler import DistributedBucketSampler 4 | from AR.data.dataset import Text2SemanticDataset 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | class Text2SemanticDataModule(LightningDataModule): 9 | def __init__(self, config, train_semantic_path, train_phoneme_path, 10 | dev_semantic_path, dev_phoneme_path): 11 | super().__init__() 12 | self.config = config 13 | self.train_semantic_path = train_semantic_path 14 | self.train_phoneme_path = train_phoneme_path 15 | self.dev_semantic_path = dev_semantic_path 16 | self.dev_phoneme_path = dev_phoneme_path 17 | self.num_workers = self.config['data']['num_workers'] 18 | 19 | def prepare_data(self): 20 | pass 21 | 22 | def setup(self, stage=None, output_logs=False): 23 | self._train_dataset = Text2SemanticDataset( 24 | phoneme_path=self.train_phoneme_path, 25 | semantic_path=self.train_semantic_path, 26 | max_sec=self.config['data']['max_sec'], 27 | pad_val=self.config['data']['pad_val']) 28 | self._dev_dataset = Text2SemanticDataset( 29 | phoneme_path=self.dev_phoneme_path, 30 | semantic_path=self.dev_semantic_path, 31 | max_sample=self.config['data']['max_eval_sample'], 32 | max_sec=self.config['data']['max_sec'], 33 | pad_val=self.config['data']['pad_val']) 34 | 35 | def train_dataloader(self): 36 | batch_size = self.config['train']['batch_size'] 37 | sampler = DistributedBucketSampler( 38 | self._train_dataset, batch_size=batch_size) 39 | return DataLoader( 40 | self._train_dataset, 41 | batch_size=batch_size, 42 | sampler=sampler, 43 | collate_fn=self._train_dataset.collate, 44 | num_workers=self.num_workers,) 45 | 46 | def val_dataloader(self): 47 | return DataLoader( 48 | self._dev_dataset, 49 | batch_size=1, 50 | shuffle=False, 51 | collate_fn=self._train_dataset.collate, 52 | num_workers=self.num_workers,) 53 | 54 | # Is this ever actually used?
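# Note on the question above: PyTorch Lightning only invokes test_dataloader() when
# Trainer.test() is called, so this hook stays unused during ordinary fit()/validate() runs.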
55 | def test_dataloader(self): 56 | return DataLoader( 57 | self._dev_dataset, 58 | batch_size=1, 59 | shuffle=False, 60 | collate_fn=self._train_dataset.collate) 61 | -------------------------------------------------------------------------------- /AR/data/data_module_librilight_6k.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py 2 | from pytorch_lightning import LightningDataModule 3 | from AR.data.bucket_sampler import DistributedBucketSampler 4 | from AR.data.dataset_librilight_6k import Text2SemanticDataset 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | class Text2SemanticDataModule(LightningDataModule): 9 | def __init__(self, 10 | config, 11 | train_semantic_dirs, 12 | train_phoneme_dirs, 13 | dev_semantic_dirs, 14 | dev_phoneme_dirs, 15 | train_non_speech_dirs=None, 16 | dev_non_speech_dirs=None): 17 | super().__init__() 18 | self.config = config 19 | self.train_semantic_dirs = train_semantic_dirs 20 | self.train_phoneme_dirs = train_phoneme_dirs 21 | self.dev_semantic_dirs = dev_semantic_dirs 22 | self.dev_phoneme_dirs = dev_phoneme_dirs 23 | self.train_non_speech_dirs = train_non_speech_dirs 24 | self.dev_non_speech_dirs = dev_non_speech_dirs 25 | self.num_workers = self.config['data']['num_workers'] 26 | 27 | def prepare_data(self): 28 | pass 29 | 30 | def setup(self, stage=None, output_logs=False): 31 | self._train_dataset = Text2SemanticDataset( 32 | phoneme_dirs=self.train_phoneme_dirs, 33 | semantic_dirs=self.train_semantic_dirs, 34 | non_speech_dirs=self.train_non_speech_dirs, 35 | max_sec=self.config['data']['max_sec'], 36 | pad_val=self.config['data']['pad_val']) 37 | self._dev_dataset = Text2SemanticDataset( 38 | phoneme_dirs=self.dev_phoneme_dirs, 39 | semantic_dirs=self.dev_semantic_dirs, 40 | non_speech_dirs=self.dev_non_speech_dirs, 41 | max_sample=self.config['data']['max_eval_sample'], 42 | max_sec=self.config['data']['max_sec'], 43 | pad_val=self.config['data']['pad_val']) 44 | 45 | def train_dataloader(self): 46 | batch_size = self.config['train']['batch_size'] 47 | sampler = DistributedBucketSampler( 48 | self._train_dataset, batch_size=batch_size) 49 | return DataLoader( 50 | self._train_dataset, 51 | batch_size=batch_size, 52 | sampler=sampler, 53 | collate_fn=self._train_dataset.collate, 54 | num_workers=self.num_workers, ) 55 | 56 | def val_dataloader(self): 57 | return DataLoader( 58 | self._dev_dataset, 59 | batch_size=1, 60 | shuffle=False, 61 | collate_fn=self._train_dataset.collate, 62 | num_workers=self.num_workers, ) 63 | 64 | # Is this ever actually used?
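# A minimal, untested usage sketch for this DataModule (the directory paths are placeholders,
# and the config keys mirror the accesses made above: config['data']['num_workers'],
# config['data']['max_sec'], config['data']['pad_val'], config['data']['max_eval_sample'],
# config['train']['batch_size']):
#
#     import pytorch_lightning as pl
#
#     dm = Text2SemanticDataModule(
#         config=config,
#         train_semantic_dirs=['dump/train/semantic'],
#         train_phoneme_dirs=['dump/train/phoneme'],
#         dev_semantic_dirs=['dump/dev/semantic'],
#         dev_phoneme_dirs=['dump/dev/phoneme'])
#     trainer = pl.Trainer()
#     # `model` would be the LightningModule defined in AR/models/t2s_lightning_module.py
#     trainer.fit(model, datamodule=dm)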
65 | def test_dataloader(self): 66 | return DataLoader( 67 | self._dev_dataset, 68 | batch_size=1, 69 | shuffle=False, 70 | collate_fn=self._train_dataset.collate) 71 | -------------------------------------------------------------------------------- /AR/data/dataset.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/t2s_dataset.py 2 | from typing import Dict 3 | from typing import List 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from torch.utils.data import Dataset 10 | 11 | from text import cleaned_text_to_sequence 12 | 13 | 14 | def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0): 15 | seq = sequences[0] 16 | ndim = seq.ndim 17 | if axis < 0: 18 | axis += ndim 19 | dtype = seq.dtype 20 | pad_value = dtype.type(pad_value) 21 | seq_lengths = [seq.shape[axis] for seq in sequences] 22 | max_length = np.max(seq_lengths) 23 | 24 | padded_sequences = [] 25 | for seq, length in zip(sequences, seq_lengths): 26 | padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * ( 27 | ndim - axis - 1) 28 | padded_seq = np.pad( 29 | seq, padding, mode='constant', constant_values=pad_value) 30 | padded_sequences.append(padded_seq) 31 | batch = np.stack(padded_sequences) 32 | return batch 33 | 34 | 35 | class Text2SemanticDataset(Dataset): 36 | """dataset class for text tokens to semantic model training.""" 37 | 38 | def __init__(self, 39 | phoneme_path: str, 40 | semantic_path: str, 41 | max_sample: int = None, 42 | max_sec: int = 100, 43 | pad_val: int = 1024, 44 | # min value of phoneme/sec 45 | min_ps_ratio: int = 6, 46 | # max value of phoneme/sec 47 | max_ps_ratio: int = 25) -> None: 48 | super().__init__() 49 | 50 | self.semantic_data = pd.read_csv(semantic_path, delimiter='\t') 51 | # get dict 52 | self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item() 53 | # pad for semantic tokens 54 | self.PAD: int = pad_val 55 | self.hz = 50 56 | # max seconds of semantic token 57 | self.max_sec = max_sec 58 | self.min_ps_ratio = min_ps_ratio 59 | self.max_ps_ratio = max_ps_ratio 60 | 61 | if max_sample is not None: 62 | self.semantic_data = self.semantic_data[:max_sample] 63 | 64 | # {idx: (semantic, phoneme)} 65 | # semantic list, phoneme list 66 | self.semantic_phoneme = [] 67 | self.item_names = [] 68 | 69 | self.inited = False 70 | 71 | if not self.inited: 72 | # call the initialization routine 73 | self.init_batch() 74 | self.inited = True 75 | del self.semantic_data 76 | del self.phoneme_data 77 | 78 | 79 | def init_batch(self): 80 | semantic_data_len = len(self.semantic_data) 81 | phoneme_data_len = len(self.phoneme_data.keys()) 82 | print("semantic_data_len:", semantic_data_len) 83 | print("phoneme_data_len:", phoneme_data_len) 84 | idx = 0 85 | num_not_in = 0 86 | num_deleted_bigger = 0 87 | num_deleted_ps = 0 88 | for i in range(semantic_data_len): 89 | # iterate over the entries one by one 90 | # get str 91 | item_name = self.semantic_data['item_name'][i] 92 | try: 93 | phoneme = self.phoneme_data[item_name] 94 | except Exception: 95 | # print(f"{item_name} not in self.phoneme_data !") 96 | num_not_in += 1 97 | continue 98 | 99 | semantic_str = self.semantic_data['semantic_audio'][i] 100 | # get token list 101 | semantic_ids = [int(idx) for idx in semantic_str.split(' ')] 102 | # (T); does it need to become (1, T)? -> no, because we need to take its len 103 | # filter out samples that are too long 104 | if len(semantic_ids) > self.max_sec * self.hz: 105 | num_deleted_bigger += 1 106
| continue 107 | 108 | # (T, ); this is not slow, so it can all be done up front instead of per item in __getitem__ 109 | phoneme = phoneme.split(' ') 110 | phoneme_ids = cleaned_text_to_sequence(phoneme) 111 | if len(phoneme_ids) > 400: 112 | num_deleted_ps += 1 113 | continue 114 | if len(semantic_ids) > 1024: 115 | num_deleted_bigger += 1 116 | continue 117 | 118 | ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz) 119 | 120 | if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio: 121 | num_deleted_ps += 1 122 | continue 123 | 124 | self.semantic_phoneme.append((semantic_ids, phoneme_ids)) 125 | idx += 1 126 | self.item_names.append(item_name) 127 | if num_not_in > 0: 128 | print(f"there are {num_not_in} semantic items not in the phoneme data") 129 | if num_deleted_bigger > 0: 130 | print( 131 | f"deleted {num_deleted_bigger} audios whose duration is longer than {self.max_sec} seconds" 132 | ) 133 | if num_deleted_ps > 0: 134 | # 4702 for LibriTTS; LibriTTS is annotated data, so is this filter still needed? => yes, there are extreme values of 100 135 | print( 136 | f"deleted {num_deleted_ps} audios whose phoneme/sec ratio is greater than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}" 137 | ) 138 | # 345410 for LibriTTS 139 | print("dataset.__len__():", self.__len__()) 140 | 141 | def __get_item_names__(self) -> List[str]: 142 | return self.item_names 143 | 144 | def __len__(self) -> int: 145 | return len(self.semantic_phoneme) 146 | 147 | def __getitem__(self, idx: int) -> Dict: 148 | semantic_ids, phoneme_ids = self.semantic_phoneme[idx] 149 | phoneme_ids_len = len(phoneme_ids) 150 | # semantic tokens target 151 | semantic_ids_len = len(semantic_ids) 152 | return { 153 | 'idx': idx, 154 | 'phoneme_ids': phoneme_ids, 155 | 'phoneme_ids_len': phoneme_ids_len, 156 | 'semantic_ids': semantic_ids, 157 | 'semantic_ids_len': semantic_ids_len 158 | } 159 | 160 | def get_sample_length(self, idx: int): 161 | semantic_ids = self.semantic_phoneme[idx][0] 162 | sec = 1.0 * len(semantic_ids) / self.hz 163 | return sec 164 | 165 | def collate(self, examples: List[Dict]) -> Dict: 166 | sample_index: List[int] = [] 167 | phoneme_ids: List[torch.Tensor] = [] 168 | phoneme_ids_lens: List[int] = [] 169 | semantic_ids: List[torch.Tensor] = [] 170 | semantic_ids_lens: List[int] = [] 171 | 172 | for item in examples: 173 | sample_index.append(item["idx"]) 174 | phoneme_ids.append(np.array(item["phoneme_ids"], dtype=np.int64)) 175 | semantic_ids.append(np.array(item["semantic_ids"], dtype=np.int64)) 176 | phoneme_ids_lens.append(item["phoneme_ids_len"]) 177 | semantic_ids_lens.append(item["semantic_ids_len"]) 178 | 179 | # pad 0 180 | phoneme_ids = batch_sequences(phoneme_ids) 181 | semantic_ids = batch_sequences(semantic_ids, pad_value=self.PAD) 182 | 183 | # # convert each batch to torch.tensor 184 | phoneme_ids = torch.tensor(phoneme_ids) 185 | semantic_ids = torch.tensor(semantic_ids) 186 | phoneme_ids_lens = torch.tensor(phoneme_ids_lens) 187 | semantic_ids_lens = torch.tensor(semantic_ids_lens) 188 | 189 | return { 190 | # List[int] 191 | "ids": sample_index, 192 | # torch.Tensor (B, max_phoneme_length) 193 | "phoneme_ids": phoneme_ids, 194 | # torch.Tensor (B) 195 | "phoneme_ids_len": phoneme_ids_lens, 196 | # torch.Tensor (B, max_semantic_ids_length) 197 | "semantic_ids": semantic_ids, 198 | # torch.Tensor (B) 199 | "semantic_ids_len": semantic_ids_lens, 200 | } 201 | 202 | 203 | if __name__ == '__main__': 204 | root_dir = '/nfs-speech-cpfs/dev/yuantian04/Vivid_TTS/SoundStorm/SoundStorm/ar_s1/SoundStorm/dump_libritts/train/' 205 | dataset = Text2SemanticDataset( 206 |
phoneme_path=root_dir + 'phonemes.npy', 207 | semantic_path=root_dir + 'semantic_token.tsv') 208 | 209 | batch_size = 12 210 | dataloader = DataLoader( 211 | dataset, 212 | batch_size=batch_size, 213 | collate_fn=dataset.collate, 214 | shuffle=False) 215 | # for i, batch in enumerate(dataloader): 216 | # if i == 0: 217 | # print('batch["ids"]:', batch["ids"]) 218 | # print('batch["phoneme_ids"]:', batch["phoneme_ids"], 219 | # batch["phoneme_ids"].shape) 220 | # print('batch["phoneme_ids_len"]:', batch["phoneme_ids_len"], 221 | # batch["phoneme_ids_len"].shape) 222 | # print('batch["semantic_ids"]:', batch["semantic_ids"], 223 | # batch["semantic_ids"].shape) 224 | # print('batch["semantic_ids_len"]:', batch["semantic_ids_len"], 225 | # batch["semantic_ids_len"].shape) 226 | -------------------------------------------------------------------------------- /AR/exps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innnky/ar-vits/aeca3848c6d34946b71348c9f59e7e008cd838ce/AR/exps/__init__.py -------------------------------------------------------------------------------- /AR/exps/beats/BEATs.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beats 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Based on fairseq code bases 7 | # https://github.com/pytorch/fairseq 8 | # -------------------------------------------------------- 9 | import logging 10 | from typing import Optional 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torchaudio.compliance.kaldi as ta_kaldi 15 | from torch.nn import LayerNorm 16 | 17 | from .backbone import TransformerEncoder 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class BEATsConfig: 23 | def __init__(self, cfg=None): 24 | self.input_patch_size: int = -1 # path size of patch embedding 25 | self.embed_dim: int = 512 # patch embedding dimension 26 | self.conv_bias: bool = False # include bias in conv encoder 27 | 28 | self.encoder_layers: int = 12 # num encoder layers in the transformer 29 | self.encoder_embed_dim: int = 768 # encoder embedding dimension 30 | self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN 31 | self.encoder_attention_heads: int = 12 # num encoder attention heads 32 | self.activation_fn: str = "gelu" # activation function to use 33 | 34 | self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay 35 | self.layer_norm_first: bool = False # apply layernorm first in the transformer 36 | self.deep_norm: bool = False # apply deep_norm first in the transformer 37 | 38 | # dropouts 39 | self.dropout: float = 0.1 # dropout probability for the transformer 40 | self.attention_dropout: float = 0.1 # dropout probability for attention weights 41 | self.activation_dropout: float = 0.0 # dropout probability after activation in FFN 42 | self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer 43 | self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) 44 | 45 | # positional embeddings 46 | self.conv_pos: int = 128 # number of filters for convolutional positional embeddings 47 | self.conv_pos_groups: int = 16 # number of groups for convolutional 
positional embedding 48 | 49 | # relative position embedding 50 | self.relative_position_embedding: bool = False # apply relative position embedding 51 | self.num_buckets: int = 320 # number of buckets for relative position embedding 52 | self.max_distance: int = 1280 # maximum distance for relative position embedding 53 | self.gru_rel_pos: bool = False # apply gated relative position embedding 54 | 55 | # label predictor 56 | self.finetuned_model: bool = False # whether the model is a fine-tuned model. 57 | self.predictor_dropout: float = 0.1 # dropout probability for the predictor 58 | self.predictor_class: int = 527 # target class number for the predictor 59 | 60 | if cfg is not None: 61 | self.update(cfg) 62 | 63 | def update(self, cfg: dict): 64 | self.__dict__.update(cfg) 65 | 66 | 67 | class BEATs(nn.Module): 68 | def __init__( 69 | self, 70 | cfg: BEATsConfig, ) -> None: 71 | super().__init__() 72 | logger.info(f"BEATs Config: {cfg.__dict__}") 73 | 74 | self.cfg = cfg 75 | 76 | self.embed = cfg.embed_dim 77 | self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim) 78 | if self.embed != cfg.encoder_embed_dim else 79 | None) 80 | 81 | self.input_patch_size = cfg.input_patch_size 82 | self.patch_embedding = nn.Conv2d( 83 | 1, 84 | self.embed, 85 | kernel_size=self.input_patch_size, 86 | stride=self.input_patch_size, 87 | bias=cfg.conv_bias) 88 | 89 | self.dropout_input = nn.Dropout(cfg.dropout_input) 90 | 91 | assert not cfg.deep_norm or not cfg.layer_norm_first 92 | self.encoder = TransformerEncoder(cfg) 93 | self.layer_norm = LayerNorm(self.embed) 94 | 95 | if cfg.finetuned_model: 96 | self.predictor_dropout = nn.Dropout(cfg.predictor_dropout) 97 | self.predictor = nn.Linear(cfg.encoder_embed_dim, 98 | cfg.predictor_class) 99 | else: 100 | self.predictor = None 101 | 102 | def forward_padding_mask( 103 | self, 104 | features: torch.Tensor, 105 | padding_mask: torch.Tensor, ) -> torch.Tensor: 106 | extra = padding_mask.size(1) % features.size(1) 107 | if extra > 0: 108 | padding_mask = padding_mask[:, :-extra] 109 | padding_mask = padding_mask.view( 110 | padding_mask.size(0), features.size(1), -1) 111 | padding_mask = padding_mask.all(-1) 112 | return padding_mask 113 | 114 | def preprocess( 115 | self, 116 | source: torch.Tensor, 117 | fbank_mean: float=15.41663, 118 | fbank_std: float=6.55582, ) -> torch.Tensor: 119 | fbanks = [] 120 | for waveform in source: 121 | waveform = waveform.unsqueeze(0) * 2**15 122 | fbank = ta_kaldi.fbank( 123 | waveform, 124 | num_mel_bins=128, 125 | sample_frequency=16000, 126 | frame_length=25, 127 | frame_shift=10) 128 | fbanks.append(fbank) 129 | fbank = torch.stack(fbanks, dim=0) 130 | fbank = (fbank - fbank_mean) / (2 * fbank_std) 131 | return fbank 132 | 133 | def extract_features( 134 | self, 135 | source: torch.Tensor, 136 | padding_mask: Optional[torch.Tensor]=None, 137 | fbank_mean: float=15.41663, 138 | fbank_std: float=6.55582, ): 139 | fbank = self.preprocess( 140 | source, fbank_mean=fbank_mean, fbank_std=fbank_std) 141 | 142 | if padding_mask is not None: 143 | padding_mask = self.forward_padding_mask(fbank, padding_mask) 144 | 145 | fbank = fbank.unsqueeze(1) 146 | features = self.patch_embedding(fbank) 147 | features = features.reshape(features.shape[0], features.shape[1], -1) 148 | features = features.transpose(1, 2) 149 | features = self.layer_norm(features) 150 | 151 | if padding_mask is not None: 152 | padding_mask = self.forward_padding_mask(features, padding_mask) 153 | 154 | if self.post_extract_proj is not 
None: 155 | features = self.post_extract_proj(features) 156 | 157 | x = self.dropout_input(features) 158 | 159 | x, layer_results = self.encoder( 160 | x, 161 | padding_mask=padding_mask, ) 162 | 163 | if self.predictor is not None: 164 | x = self.predictor_dropout(x) 165 | logits = self.predictor(x) 166 | 167 | if padding_mask is not None and padding_mask.any(): 168 | logits[padding_mask] = 0 169 | logits = logits.sum(dim=1) 170 | logits = logits / (~padding_mask).sum( 171 | dim=1).unsqueeze(-1).expand_as(logits) 172 | else: 173 | logits = logits.mean(dim=1) 174 | 175 | lprobs = torch.sigmoid(logits) 176 | 177 | return lprobs, padding_mask 178 | else: 179 | return x, padding_mask 180 | -------------------------------------------------------------------------------- /AR/exps/beats/README.md: -------------------------------------------------------------------------------- 1 | 2 | # BEATs 3 | 4 | [**BEATs**](https://arxiv.org/abs/2212.09058): **Audio Pre-Training with Acoustic Tokenizers** 5 | 6 | Official PyTorch implementation and pretrained models of BEATs 7 | 8 | ## Pre-Trained and Fine-Tuned Tokenizers and Models 9 | Iterations | Tokenizer | Pre-Trained Model | AudioSet Fine-Tuned Model 1 | AudioSet Fine-Tuned Model 2 10 | |---|---|---|---|--- 11 | Iter1 | Random Projection | [BEATs_iter1](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter1 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter1 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter1_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | 12 | Iter2 | [Tokenizer_iter2](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter2](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter2 (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter2 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter2_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | 13 | Iter3 | [Tokenizer_iter3](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3 
(cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3 (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | 14 | Iter3+ | [Tokenizer_iter3+ (AS20K)](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3_plus_AS20K.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3+ (AS20K)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS20K) (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS20K) (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS20K_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | 15 | Iter3+ | [Tokenizer_iter3+ (AS2M)](https://valle.blob.core.windows.net/share/BEATs/Tokenizer_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D)| [BEATs_iter3+ (AS2M)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS2M) (cpt1)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt1.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | [Fine-tuned BEATs_iter3+ (AS2M) (cpt2)](https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) | 16 | 17 | 18 | ### Load Tokenizers 19 | 20 | ```python 21 | import torch 22 | from Tokenizers import TokenizersConfig, Tokenizers 23 | 24 | # load the pre-trained checkpoints 25 | checkpoint = torch.load('/path/to/tokenizer.pt') 26 | 27 | cfg = TokenizersConfig(checkpoint['cfg']) 28 | BEATs_tokenizer = Tokenizers(cfg) 29 | BEATs_tokenizer.load_state_dict(checkpoint['model']) 30 | BEATs_tokenizer.eval() 31 | 32 | # tokenize the audio and generate the labels 33 | audio_input_16khz = torch.randn(1, 10000) 34 | padding_mask = torch.zeros(1, 10000).bool() 35 | 36 | labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=padding_mask) 37 | ``` 38 | 39 | 40 | ### Load Pre-Trained Models 41 | 42 | ```python 43 | import torch 44 | from BEATs import BEATs, BEATsConfig 45 | 46 | # load the pre-trained checkpoints 47 | checkpoint = torch.load('/path/to/model.pt') 48 | 49 | cfg = BEATsConfig(checkpoint['cfg']) 50 | BEATs_model = BEATs(cfg) 51 | BEATs_model.load_state_dict(checkpoint['model']) 52 | 
BEATs_model.eval() 53 | 54 | # extract the the audio representation 55 | audio_input_16khz = torch.randn(1, 10000) 56 | padding_mask = torch.zeros(1, 10000).bool() 57 | 58 | representation = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0] 59 | ``` 60 | 61 | 62 | ### Load Fine-tuned Models 63 | 64 | ```python 65 | import torch 66 | from BEATs import BEATs, BEATsConfig 67 | 68 | # load the fine-tuned checkpoints 69 | checkpoint = torch.load('/path/to/model.pt') 70 | 71 | cfg = BEATsConfig(checkpoint['cfg']) 72 | BEATs_model = BEATs(cfg) 73 | BEATs_model.load_state_dict(checkpoint['model']) 74 | BEATs_model.eval() 75 | 76 | # predict the classification probability of each class 77 | audio_input_16khz = torch.randn(3, 10000) 78 | padding_mask = torch.zeros(3, 10000).bool() 79 | 80 | probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0] 81 | 82 | for i, (top5_label_prob, top5_label_idx) in enumerate(zip(*probs.topk(k=5))): 83 | top5_label = [checkpoint['label_dict'][label_idx.item()] for label_idx in top5_label_idx] 84 | print(f'Top 5 predicted labels of the {i}th audio are {top5_label} with probability of {top5_label_prob}') 85 | ``` 86 | 87 | ## Evaluation Results 88 | 89 | ### Comparing with the SOTA Single Models 90 | ![alt text](Evaluation_Results/Comparing_with_the_SOTA_Single_Models.png) 91 | 92 | 93 | ### Comparing with the SOTA Ensemble Models 94 | ![alt text](Evaluation_Results/Comparing_with_the_SOTA_Ensemble_Models.png) 95 | 96 | 97 | ### Comparing Different BEATS Tokenizers 98 | ![alt text](Evaluation_Results/Comparing_Different_BEATS_Tokenizers.png) 99 | 100 | 101 | ### Comparing Different Pre-Training Targets 102 | ![alt text](Evaluation_Results/Comparing_Different_Pre-Training_Targets.png) 103 | 104 | 105 | ## License 106 | This project is licensed under the license found in the LICENSE file in the root directory of this source tree. 107 | Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) and [VQGAN](https://github.com/CompVis/taming-transformers) project. 108 | 109 | [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct) 110 | 111 | 112 | ### Reference 113 | If you find our work is useful in your research, please cite the following paper: 114 | ``` latex 115 | @article{Chen2022beats, 116 | title = {BEATs: Audio Pre-Training with Acoustic Tokenizers}, 117 | author = {Sanyuan Chen and Yu Wu and Chengyi Wang and Shujie Liu and Daniel Tompkins and Zhuo Chen and Furu Wei}, 118 | eprint={2212.09058}, 119 | archivePrefix={arXiv}, 120 | year={2022} 121 | } 122 | ``` 123 | ### Contact Information 124 | 125 | For help or issues using BEATs models, please submit a GitHub issue. 126 | 127 | For other communications related to BEATs, please contact Yu Wu (`yuwu1@microsoft.com`). 
128 | -------------------------------------------------------------------------------- /AR/exps/beats/Tokenizers.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beats 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Based on fairseq code bases 7 | # https://github.com/pytorch/fairseq 8 | # -------------------------------------------------------- 9 | import logging 10 | from typing import Optional 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torchaudio.compliance.kaldi as ta_kaldi 15 | from backbone import ( 16 | TransformerEncoder, ) 17 | from quantizer import ( 18 | NormEMAVectorQuantizer, ) 19 | from torch.nn import LayerNorm 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class TokenizersConfig: 25 | def __init__(self, cfg=None): 26 | self.input_patch_size: int = -1 # path size of patch embedding 27 | self.embed_dim: int = 512 # patch embedding dimension 28 | self.conv_bias: bool = False # include bias in conv encoder 29 | 30 | self.encoder_layers: int = 12 # num encoder layers in the transformer 31 | self.encoder_embed_dim: int = 768 # encoder embedding dimension 32 | self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN 33 | self.encoder_attention_heads: int = 12 # num encoder attention heads 34 | self.activation_fn: str = "gelu" # activation function to use 35 | 36 | self.layer_norm_first: bool = False # apply layernorm first in the transformer 37 | self.deep_norm: bool = False # apply deep_norm first in the transformer 38 | 39 | # dropouts 40 | self.dropout: float = 0.1 # dropout probability for the transformer 41 | self.attention_dropout: float = 0.1 # dropout probability for attention weights 42 | self.activation_dropout: float = 0.0 # dropout probability after activation in FFN 43 | self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer 44 | self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) 45 | 46 | # positional embeddings 47 | self.conv_pos: int = 128 # number of filters for convolutional positional embeddings 48 | self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding 49 | 50 | # relative position embedding 51 | self.relative_position_embedding: bool = False # apply relative position embedding 52 | self.num_buckets: int = 320 # number of buckets for relative position embedding 53 | self.max_distance: int = 1280 # maximum distance for relative position embedding 54 | self.gru_rel_pos: bool = False # apply gated relative position embedding 55 | 56 | # quantizer 57 | self.quant_n: int = 1024 # codebook number in quantizer 58 | self.quant_dim: int = 256 # codebook dimension in quantizer 59 | 60 | if cfg is not None: 61 | self.update(cfg) 62 | 63 | def update(self, cfg: dict): 64 | self.__dict__.update(cfg) 65 | 66 | 67 | class Tokenizers(nn.Module): 68 | def __init__( 69 | self, 70 | cfg: TokenizersConfig, ) -> None: 71 | super().__init__() 72 | logger.info(f"Tokenizers Config: {cfg.__dict__}") 73 | 74 | self.cfg = cfg 75 | 76 | self.embed = cfg.embed_dim 77 | self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim) 78 | if self.embed != cfg.encoder_embed_dim else 79 | None) 80 | 81 | self.input_patch_size = 
cfg.input_patch_size 82 | self.patch_embedding = nn.Conv2d( 83 | 1, 84 | self.embed, 85 | kernel_size=self.input_patch_size, 86 | stride=self.input_patch_size, 87 | bias=cfg.conv_bias) 88 | 89 | self.dropout_input = nn.Dropout(cfg.dropout_input) 90 | 91 | assert not cfg.deep_norm or not cfg.layer_norm_first 92 | self.encoder = TransformerEncoder(cfg) 93 | self.layer_norm = LayerNorm(self.embed) 94 | 95 | self.quantize = NormEMAVectorQuantizer( 96 | n_embed=cfg.quant_n, 97 | embedding_dim=cfg.quant_dim, 98 | beta=1.0, 99 | kmeans_init=True, 100 | decay=0.99, ) 101 | self.quant_n = cfg.quant_n 102 | self.quantize_layer = nn.Sequential( 103 | nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim), 104 | nn.Tanh(), 105 | nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize 106 | ) 107 | 108 | def forward_padding_mask( 109 | self, 110 | features: torch.Tensor, 111 | padding_mask: torch.Tensor, ) -> torch.Tensor: 112 | extra = padding_mask.size(1) % features.size(1) 113 | if extra > 0: 114 | padding_mask = padding_mask[:, :-extra] 115 | padding_mask = padding_mask.view( 116 | padding_mask.size(0), features.size(1), -1) 117 | padding_mask = padding_mask.all(-1) 118 | return padding_mask 119 | 120 | def preprocess( 121 | self, 122 | source: torch.Tensor, 123 | fbank_mean: float=15.41663, 124 | fbank_std: float=6.55582, ) -> torch.Tensor: 125 | fbanks = [] 126 | for waveform in source: 127 | waveform = waveform.unsqueeze(0) * 2**15 128 | fbank = ta_kaldi.fbank( 129 | waveform, 130 | num_mel_bins=128, 131 | sample_frequency=16000, 132 | frame_length=25, 133 | frame_shift=10) 134 | fbanks.append(fbank) 135 | fbank = torch.stack(fbanks, dim=0) 136 | fbank = (fbank - fbank_mean) / (2 * fbank_std) 137 | return fbank 138 | 139 | def extract_labels( 140 | self, 141 | source: torch.Tensor, 142 | padding_mask: Optional[torch.Tensor]=None, 143 | fbank_mean: float=15.41663, 144 | fbank_std: float=6.55582, ): 145 | fbank = self.preprocess( 146 | source, fbank_mean=fbank_mean, fbank_std=fbank_std) 147 | 148 | if padding_mask is not None: 149 | padding_mask = self.forward_padding_mask(fbank, padding_mask) 150 | 151 | fbank = fbank.unsqueeze(1) 152 | features = self.patch_embedding(fbank) 153 | features = features.reshape(features.shape[0], features.shape[1], -1) 154 | features = features.transpose(1, 2) 155 | features = self.layer_norm(features) 156 | 157 | if padding_mask is not None: 158 | padding_mask = self.forward_padding_mask(features, padding_mask) 159 | 160 | if self.post_extract_proj is not None: 161 | features = self.post_extract_proj(features) 162 | 163 | x = self.dropout_input(features) 164 | 165 | x, layer_results = self.encoder( 166 | x, 167 | padding_mask=padding_mask, ) 168 | 169 | quantize_input = self.quantize_layer(x) 170 | quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input) 171 | 172 | return embed_ind 173 | -------------------------------------------------------------------------------- /AR/exps/beats/__init__.py: -------------------------------------------------------------------------------- 1 | # this folder is modified from https://github.com/microsoft/unilm/tree/master/beats 2 | # ontology.json is from https://github.com/audioset/ontology/ 3 | -------------------------------------------------------------------------------- /AR/exps/beats/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | # 获取当前脚本的所在目录 5 | script_dir = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | # JSON 
文件的文件名 8 | json_filename = "ontology.json" 9 | 10 | # 构建 JSON 文件的完整路径 11 | json_path = os.path.join(script_dir, json_filename) 12 | 13 | id_name_dict = {} 14 | 15 | with open(json_path, 'r') as f: 16 | json_items = json.load(f) 17 | # '/m/0dgw9r' -> 'Human sounds' and etc. 18 | for item in json_items: 19 | id_name_dict[item['id']] = item['name'] 20 | -------------------------------------------------------------------------------- /AR/exps/beats/modules.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beats 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Based on fairseq code bases 7 | # https://github.com/pytorch/fairseq 8 | # -------------------------------------------------------- 9 | import math 10 | import warnings 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch import nn 15 | 16 | 17 | class GradMultiply(torch.autograd.Function): 18 | @staticmethod 19 | def forward(ctx, x, scale): 20 | ctx.scale = scale 21 | res = x.new(x) 22 | return res 23 | 24 | @staticmethod 25 | def backward(ctx, grad): 26 | return grad * ctx.scale, None 27 | 28 | 29 | class SamePad(nn.Module): 30 | def __init__(self, kernel_size, causal=False): 31 | super().__init__() 32 | if causal: 33 | self.remove = kernel_size - 1 34 | else: 35 | self.remove = 1 if kernel_size % 2 == 0 else 0 36 | 37 | def forward(self, x): 38 | if self.remove > 0: 39 | x = x[:, :, :-self.remove] 40 | return x 41 | 42 | 43 | class Swish(nn.Module): 44 | def __init__(self): 45 | super(Swish, self).__init__() 46 | self.act = torch.nn.Sigmoid() 47 | 48 | def forward(self, x): 49 | return x * self.act(x) 50 | 51 | 52 | class GLU_Linear(nn.Module): 53 | def __init__(self, 54 | input_dim, 55 | output_dim, 56 | glu_type="sigmoid", 57 | bias_in_glu=True): 58 | super(GLU_Linear, self).__init__() 59 | 60 | self.glu_type = glu_type 61 | self.output_dim = output_dim 62 | 63 | if glu_type == "sigmoid": 64 | self.glu_act = torch.nn.Sigmoid() 65 | elif glu_type == "swish": 66 | self.glu_act = Swish() 67 | elif glu_type == "relu": 68 | self.glu_act = torch.nn.ReLU() 69 | elif glu_type == "gelu": 70 | self.glu_act = torch.nn.GELU() 71 | 72 | if bias_in_glu: 73 | self.linear = nn.Linear(input_dim, output_dim * 2, True) 74 | else: 75 | self.linear = nn.Linear(input_dim, output_dim * 2, False) 76 | 77 | def forward(self, x): 78 | # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case 79 | x = self.linear(x) 80 | 81 | if self.glu_type == "bilinear": 82 | x = (x[:, :, 0:self.output_dim] * 83 | x[:, :, self.output_dim:self.output_dim * 2]) 84 | else: 85 | x = (x[:, :, 0:self.output_dim] * 86 | self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) 87 | 88 | return x 89 | 90 | 91 | def gelu_accurate(x): 92 | if not hasattr(gelu_accurate, "_a"): 93 | gelu_accurate._a = math.sqrt(2 / math.pi) 94 | return (0.5 * x * (1 + torch.tanh(gelu_accurate._a * 95 | (x + 0.044715 * torch.pow(x, 3))))) 96 | 97 | 98 | def gelu(x: torch.Tensor) -> torch.Tensor: 99 | return torch.nn.functional.gelu(x.float()).type_as(x) 100 | 101 | 102 | def get_activation_fn(activation: str): 103 | """Returns the activation function 
corresponding to `activation`""" 104 | 105 | if activation == "relu": 106 | return F.relu 107 | elif activation == "gelu": 108 | return gelu 109 | elif activation == "gelu_fast": 110 | warnings.warn( 111 | "--activation-fn=gelu_fast has been renamed to gelu_accurate") 112 | return gelu_accurate 113 | elif activation == "gelu_accurate": 114 | return gelu_accurate 115 | elif activation == "tanh": 116 | return torch.tanh 117 | elif activation == "linear": 118 | return lambda x: x 119 | elif activation == "glu": 120 | return lambda x: x 121 | else: 122 | raise RuntimeError( 123 | "--activation-fn {} not supported".format(activation)) 124 | 125 | 126 | def quant_noise(module, p, block_size): 127 | """ 128 | Wraps modules and applies quantization noise to the weights for 129 | subsequent quantization with Iterative Product Quantization as 130 | described in "Training with Quantization Noise for Extreme Model Compression" 131 | 132 | Args: 133 | - module: nn.Module 134 | - p: amount of Quantization Noise 135 | - block_size: size of the blocks for subsequent quantization with iPQ 136 | 137 | Remarks: 138 | - Module weights must have the right sizes wrt the block size 139 | - Only Linear, Embedding and Conv2d modules are supported for the moment 140 | - For more detail on how to quantize by blocks with convolutional weights, 141 | see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" 142 | - We implement the simplest form of noise here as stated in the paper 143 | which consists in randomly dropping blocks 144 | """ 145 | 146 | # if no quantization noise, don't register hook 147 | if p <= 0: 148 | return module 149 | 150 | # supported modules 151 | assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) 152 | 153 | # test whether module.weight has the right sizes wrt block_size 154 | is_conv = module.weight.ndim == 4 155 | 156 | # 2D matrix 157 | if not is_conv: 158 | assert ( 159 | module.weight.size(1) % 160 | block_size == 0), "Input features must be a multiple of block sizes" 161 | 162 | # 4D matrix 163 | else: 164 | # 1x1 convolutions 165 | if module.kernel_size == (1, 1): 166 | assert (module.in_channels % block_size == 0 167 | ), "Input channels must be a multiple of block sizes" 168 | # regular convolutions 169 | else: 170 | k = module.kernel_size[0] * module.kernel_size[1] 171 | assert k % block_size == 0, "Kernel size must be a multiple of block size" 172 | 173 | def _forward_pre_hook(mod, input): 174 | # no noise for evaluation 175 | if mod.training: 176 | if not is_conv: 177 | # gather weight and sizes 178 | weight = mod.weight 179 | in_features = weight.size(1) 180 | out_features = weight.size(0) 181 | 182 | # split weight matrix into blocks and randomly drop selected blocks 183 | mask = torch.zeros( 184 | in_features // block_size * out_features, 185 | device=weight.device) 186 | mask.bernoulli_(p) 187 | mask = mask.repeat_interleave(block_size, -1).view(-1, 188 | in_features) 189 | 190 | else: 191 | # gather weight and sizes 192 | weight = mod.weight 193 | in_channels = mod.in_channels 194 | out_channels = mod.out_channels 195 | 196 | # split weight matrix into blocks and randomly drop selected blocks 197 | if mod.kernel_size == (1, 1): 198 | mask = torch.zeros( 199 | int(in_channels // block_size * out_channels), 200 | device=weight.device, ) 201 | mask.bernoulli_(p) 202 | mask = mask.repeat_interleave(block_size, -1).view( 203 | -1, in_channels) 204 | else: 205 | mask = torch.zeros( 206 | weight.size(0), weight.size(1), device=weight.device) 207 | 
mask.bernoulli_(p) 208 | mask = ( 209 | mask.unsqueeze(2).unsqueeze(3) 210 | .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])) 211 | 212 | # scale weights and apply mask 213 | mask = mask.to( 214 | torch. 215 | bool) # x.bool() is not currently supported in TorchScript 216 | s = 1 / (1 - p) 217 | mod.weight.data = s * weight.masked_fill(mask, 0) 218 | 219 | module.register_forward_pre_hook(_forward_pre_hook) 220 | return module 221 | -------------------------------------------------------------------------------- /AR/exps/beats/quantizer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/beats 4 | # Copyright (c) 2022 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Based on VQGAN code bases 7 | # https://github.com/CompVis/taming-transformers 8 | # --------------------------------------------------------' 9 | import torch 10 | import torch.distributed as distributed 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | try: 15 | from einops import rearrange, repeat 16 | except ImportError: 17 | pass 18 | 19 | 20 | def l2norm(t): 21 | return F.normalize(t, p=2, dim=-1) 22 | 23 | 24 | def ema_inplace(moving_avg, new, decay): 25 | moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) 26 | 27 | 28 | def sample_vectors(samples, num): 29 | num_samples, device = samples.shape[0], samples.device 30 | 31 | if num_samples >= num: 32 | indices = torch.randperm(num_samples, device=device)[:num] 33 | else: 34 | indices = torch.randint(0, num_samples, (num, ), device=device) 35 | 36 | return samples[indices] 37 | 38 | 39 | def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False): 40 | dim, dtype, device = samples.shape[-1], samples.dtype, samples.device 41 | 42 | means = sample_vectors(samples, num_clusters) 43 | 44 | for _ in range(num_iters): 45 | if use_cosine_sim: 46 | dists = samples @ means.t() 47 | else: 48 | diffs = rearrange(samples, 'n d -> n () d') \ 49 | - rearrange(means, 'c d -> () c d') 50 | dists = -(diffs**2).sum(dim=-1) 51 | 52 | buckets = dists.max(dim=-1).indices 53 | bins = torch.bincount(buckets, minlength=num_clusters) 54 | zero_mask = bins == 0 55 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 56 | 57 | new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) 58 | new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples) 59 | new_means = new_means / bins_min_clamped[..., None] 60 | 61 | if use_cosine_sim: 62 | new_means = l2norm(new_means) 63 | 64 | means = torch.where(zero_mask[..., None], means, new_means) 65 | 66 | return means, bins 67 | 68 | 69 | class EmbeddingEMA(nn.Module): 70 | def __init__(self, 71 | num_tokens, 72 | codebook_dim, 73 | decay=0.99, 74 | eps=1e-5, 75 | kmeans_init=True, 76 | codebook_init_path=''): 77 | super().__init__() 78 | self.num_tokens = num_tokens 79 | self.codebook_dim = codebook_dim 80 | self.decay = decay 81 | self.eps = eps 82 | if codebook_init_path == '': 83 | if not kmeans_init: 84 | weight = torch.randn(num_tokens, codebook_dim) 85 | weight = l2norm(weight) 86 | else: 87 | weight = torch.zeros(num_tokens, codebook_dim) 88 | self.register_buffer('initted', torch.Tensor([not kmeans_init])) 89 | else: 90 | print(f"load init codebook weight from {codebook_init_path}") 91 | codebook_ckpt_weight = 
torch.load( 92 | codebook_init_path, map_location='cpu') 93 | weight = codebook_ckpt_weight.clone() 94 | self.register_buffer('initted', torch.Tensor([True])) 95 | 96 | self.weight = nn.Parameter(weight, requires_grad=False) 97 | self.cluster_size = nn.Parameter( 98 | torch.zeros(num_tokens), requires_grad=False) 99 | self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) 100 | # self.register_buffer('initted', torch.Tensor([not kmeans_init])) 101 | self.update = True 102 | 103 | @torch.jit.ignore 104 | def init_embed_(self, data): 105 | if self.initted: 106 | return 107 | print("Performing Kemans init for codebook") 108 | embed, cluster_size = kmeans( 109 | data, self.num_tokens, 10, use_cosine_sim=True) 110 | self.weight.data.copy_(embed) 111 | self.cluster_size.data.copy_(cluster_size) 112 | self.initted.data.copy_(torch.Tensor([True])) 113 | 114 | def forward(self, embed_id): 115 | return F.embedding(embed_id, self.weight) 116 | 117 | def cluster_size_ema_update(self, new_cluster_size): 118 | self.cluster_size.data.mul_(self.decay).add_( 119 | new_cluster_size, alpha=1 - self.decay) 120 | 121 | def embed_avg_ema_update(self, new_embed_avg): 122 | self.embed_avg.data.mul_(self.decay).add_( 123 | new_embed_avg, alpha=1 - self.decay) 124 | 125 | def weight_update(self, num_tokens): 126 | n = self.cluster_size.sum() 127 | smoothed_cluster_size = ( 128 | (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n) 129 | # normalize embedding average with smoothed cluster size 130 | embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) 131 | # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1)) 132 | self.weight.data.copy_(embed_normalized) 133 | 134 | 135 | def norm_ema_inplace(moving_avg, new, decay): 136 | moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) 137 | moving_avg.data.copy_(l2norm(moving_avg.data)) 138 | 139 | 140 | class NormEMAVectorQuantizer(nn.Module): 141 | def __init__(self, 142 | n_embed, 143 | embedding_dim, 144 | beta, 145 | decay=0.99, 146 | eps=1e-5, 147 | statistic_code_usage=True, 148 | kmeans_init=False, 149 | codebook_init_path=''): 150 | super().__init__() 151 | self.codebook_dim = embedding_dim 152 | self.num_tokens = n_embed 153 | self.beta = beta 154 | self.decay = decay 155 | 156 | # learnable = True if orthogonal_reg_weight > 0 else False 157 | self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, 158 | eps, kmeans_init, codebook_init_path) 159 | 160 | self.statistic_code_usage = statistic_code_usage 161 | if statistic_code_usage: 162 | self.register_buffer('cluster_size', torch.zeros(n_embed)) 163 | if distributed.is_available() and distributed.is_initialized(): 164 | print( 165 | "ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!" 
166 | ) 167 | self.all_reduce_fn = distributed.all_reduce 168 | else: 169 | self.all_reduce_fn = nn.Identity() 170 | 171 | def reset_cluster_size(self, device): 172 | if self.statistic_code_usage: 173 | self.register_buffer('cluster_size', torch.zeros(self.num_tokens)) 174 | self.cluster_size = self.cluster_size.to(device) 175 | 176 | def forward(self, z): 177 | # reshape z -> (batch, height, width, channel) and flatten 178 | # z, 'b c h w -> b h w c' 179 | # z = rearrange(z, 'b c h w -> b h w c') 180 | # z = z.transpose(1, 2) 181 | z = l2norm(z) 182 | z_flattened = z.reshape(-1, self.codebook_dim) 183 | 184 | self.embedding.init_embed_(z_flattened) 185 | 186 | d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ 187 | self.embedding.weight.pow(2).sum(dim=1) - 2 * \ 188 | torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' 189 | 190 | encoding_indices = torch.argmin(d, dim=1) 191 | 192 | z_q = self.embedding(encoding_indices).view(z.shape) 193 | 194 | encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) 195 | 196 | if not self.training: 197 | with torch.no_grad(): 198 | cluster_size = encodings.sum(0) 199 | self.all_reduce_fn(cluster_size) 200 | ema_inplace(self.cluster_size, cluster_size, self.decay) 201 | 202 | if self.training and self.embedding.update: 203 | # EMA cluster size 204 | 205 | bins = encodings.sum(0) 206 | self.all_reduce_fn(bins) 207 | 208 | # self.embedding.cluster_size_ema_update(bins) 209 | ema_inplace(self.cluster_size, bins, self.decay) 210 | 211 | zero_mask = (bins == 0) 212 | bins = bins.masked_fill(zero_mask, 1.) 213 | 214 | embed_sum = z_flattened.t() @ encodings 215 | self.all_reduce_fn(embed_sum) 216 | 217 | embed_normalized = (embed_sum / bins.unsqueeze(0)).t() 218 | embed_normalized = l2norm(embed_normalized) 219 | 220 | embed_normalized = torch.where( 221 | zero_mask[..., None], self.embedding.weight, embed_normalized) 222 | norm_ema_inplace(self.embedding.weight, embed_normalized, 223 | self.decay) 224 | 225 | # compute loss for embedding 226 | loss = self.beta * F.mse_loss(z_q.detach(), z) 227 | 228 | # preserve gradients 229 | z_q = z + (z_q - z).detach() 230 | 231 | # reshape back to match original input shape 232 | # z_q, 'b h w c -> b c h w' 233 | # z_q = rearrange(z_q, 'b h w c -> b c h w') 234 | # z_q = z_q.transpose(1, 2) 235 | return z_q, loss, encoding_indices 236 | -------------------------------------------------------------------------------- /AR/exps/get_phones.py: -------------------------------------------------------------------------------- 1 | """ 2 | 1. read text of dataset 3 | 2. text -> IPA by GruutPhonemizer 4 | 3. 
save out a *.npy dict for all text 5 | my_dict = {"utt_id1": text1, "utt_id2": text2} 6 | np.save(output_filename, my_dict) 7 | my_dict = np.load(output_filename, allow_pickle=True).item() 8 | """ 9 | import argparse 10 | import os 11 | from concurrent.futures import ThreadPoolExecutor 12 | from operator import itemgetter 13 | from pathlib import Path 14 | from typing import List 15 | 16 | import numpy as np 17 | import tqdm 18 | from AR.text_processing.phonemizer import GruutPhonemizer 19 | 20 | 21 | def read_txt(txt_file): 22 | utt_name = txt_file.stem 23 | utt_id = utt_name.split('.')[0] 24 | try: 25 | with open(txt_file, 'r') as file: 26 | txt = file.readline() 27 | record = {"utt_id": utt_id, "txt": txt} 28 | except Exception: 29 | print("occur Exception") 30 | traceback.print_exc() 31 | return None 32 | return record 33 | 34 | 35 | def read_txts(txt_files: List[Path], nprocs: int=1): 36 | if nprocs == 1: 37 | results = [] 38 | for txt_file in tqdm.tqdm(txt_files, total=len(txt_files)): 39 | record = read_txt(txt_file=txt_file) 40 | if record: 41 | results.append(record) 42 | else: 43 | with ThreadPoolExecutor(nprocs) as pool: 44 | futures = [] 45 | with tqdm.tqdm(total=len(txt_files)) as progress: 46 | for txt_file in txt_files: 47 | future = pool.submit(read_txt, txt_file) 48 | future.add_done_callback(lambda p: progress.update()) 49 | futures.append(future) 50 | 51 | results = [] 52 | for ft in futures: 53 | record = ft.result() 54 | if record: 55 | results.append(record) 56 | 57 | results.sort(key=itemgetter("utt_id")) 58 | return_list = [] 59 | for item in results: 60 | return_list.append((item["utt_id"], item["txt"])) 61 | return return_list 62 | 63 | 64 | def process_sentence(item, phonemizer): 65 | utt_id, text = item 66 | try: 67 | phonemes = phonemizer.phonemize(text, espeak=False) 68 | record = {"utt_id": utt_id, "phonemes": phonemes} 69 | except Exception: 70 | print("occur Exception") 71 | traceback.print_exc() 72 | return None 73 | return record 74 | 75 | 76 | def process_sentences(items, phonemizer, output_dir, nprocs: int=1): 77 | if nprocs == 1: 78 | results = [] 79 | for item in tqdm.tqdm(items, total=len(items)): 80 | record = process_sentence(item=item, phonemizer=phonemizer) 81 | if record: 82 | results.append(record) 83 | else: 84 | with ThreadPoolExecutor(nprocs) as pool: 85 | futures = [] 86 | with tqdm.tqdm(total=len(items)) as progress: 87 | for item in items: 88 | future = pool.submit(process_sentence, item, phonemizer) 89 | future.add_done_callback(lambda p: progress.update()) 90 | futures.append(future) 91 | 92 | results = [] 93 | for ft in futures: 94 | record = ft.result() 95 | if record: 96 | results.append(record) 97 | results.sort(key=itemgetter("utt_id")) 98 | npy_dict = {} 99 | for item in results: 100 | utt_id = item["utt_id"] 101 | phonemes = item["phonemes"] 102 | npy_dict[utt_id] = phonemes 103 | filename = output_dir / 'phonemes.npy' 104 | np.save(filename, npy_dict) 105 | print(f"npy file '{filename}' write down") 106 | 107 | 108 | def main(): 109 | # parse config and args 110 | parser = argparse.ArgumentParser(description="Get phones for datasets") 111 | 112 | parser.add_argument( 113 | "--dataset", 114 | default="ljspeech", 115 | type=str, 116 | help="name of dataset, should in {ljspeech, libritts} now") 117 | 118 | parser.add_argument( 119 | "--data_dir", default=None, type=str, help="directory to dataset.") 120 | 121 | parser.add_argument( 122 | "--dump_dir", 123 | type=str, 124 | required=True, 125 | help="directory to dump feature 
files.") 126 | parser.add_argument( 127 | "--num-cpu", type=int, default=1, help="number of process.") 128 | 129 | args = parser.parse_args() 130 | 131 | data_dir = Path(args.data_dir).expanduser() 132 | dump_dir = Path(args.dump_dir).expanduser() 133 | # use absolute path 134 | dump_dir = dump_dir.resolve() 135 | dump_dir.mkdir(parents=True, exist_ok=True) 136 | 137 | assert data_dir.is_dir() 138 | 139 | if args.dataset == "ljspeech": 140 | data_dict = {} 141 | text_path = data_dir / 'metadata.csv' 142 | with open(text_path, 'r') as rf: 143 | for line in rf: 144 | line_list = line.strip().split('|') 145 | utt_id = line_list[0] 146 | raw_text = line_list[-1] 147 | data_dict[utt_id] = raw_text 148 | 149 | sorted_dict = sorted(data_dict.items()) 150 | 151 | num_train = 12900 152 | num_dev = 100 153 | # (utt_id, txt) 154 | train_txts = sorted_dict[:num_train] 155 | dev_txts = sorted_dict[num_train:num_train + num_dev] 156 | test_txts = sorted_dict[num_train + num_dev:] 157 | 158 | elif args.dataset == "libritts": 159 | ''' 160 | we use train-clean-100、train-clean-360、train-other-500 here 161 | and split dev and test from them, don't use test-* and dev-* cause the speakers are disjoint 162 | the file structure is LibriTTS_R/train-clean-100/spkid/*/*.wav 163 | there are about 2311 in these subsets, we split 1 dev and 1 test wav out from each speaker 164 | ''' 165 | txt_files = [] 166 | train_txt_files = [] 167 | dev_txt_files = [] 168 | test_txt_files = [] 169 | sub_num_dev = 1 170 | for sub_dataset_name in { 171 | "train-clean-100", "train-clean-360", "train-other-500" 172 | }: 173 | sub_dataset_dir = data_dir / sub_dataset_name 174 | # filter out hidden files 175 | speaker_list = [ 176 | file for file in os.listdir(sub_dataset_dir) 177 | if not file.startswith('.') 178 | ] 179 | for speaker in speaker_list: 180 | txt_files = sorted( 181 | list((sub_dataset_dir / speaker).rglob( 182 | "*/*.normalized.txt"))) 183 | # filter out ._*.wav 184 | txt_files = [ 185 | file for file in txt_files if not file.name.startswith('._') 186 | ] 187 | train_txt_files += txt_files[:-sub_num_dev * 2] 188 | dev_txt_files += txt_files[-sub_num_dev * 2:-sub_num_dev] 189 | test_txt_files += txt_files[-sub_num_dev:] 190 | print("len(train_txt_files):", len(train_txt_files)) 191 | print("len(dev_txt_files):", len(dev_txt_files)) 192 | print("len(test_txt_files):", len(test_txt_files)) 193 | 194 | train_txts = read_txts(train_txt_files) 195 | dev_txts = read_txts(dev_txt_files) 196 | test_txts = read_txts(test_txt_files) 197 | 198 | else: 199 | print("dataset should in {ljspeech, libritts} now!") 200 | 201 | train_dump_dir = dump_dir / "train" 202 | train_dump_dir.mkdir(parents=True, exist_ok=True) 203 | dev_dump_dir = dump_dir / "dev" 204 | dev_dump_dir.mkdir(parents=True, exist_ok=True) 205 | test_dump_dir = dump_dir / "test" 206 | test_dump_dir.mkdir(parents=True, exist_ok=True) 207 | 208 | phonemizer = GruutPhonemizer(language='en-us') 209 | 210 | # process for the 3 sections 211 | if train_txts: 212 | process_sentences( 213 | items=train_txts, 214 | output_dir=train_dump_dir, 215 | phonemizer=phonemizer, 216 | nprocs=args.num_cpu) 217 | if dev_txts: 218 | process_sentences( 219 | items=dev_txts, 220 | output_dir=dev_dump_dir, 221 | phonemizer=phonemizer, 222 | nprocs=args.num_cpu) 223 | if test_txts: 224 | process_sentences( 225 | items=test_txts, 226 | output_dir=test_dump_dir, 227 | phonemizer=phonemizer, 228 | nprocs=args.num_cpu) 229 | 230 | 231 | if __name__ == "__main__": 232 | main() 233 | 
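For reference, the phoneme dump produced above is just a pickled {utt_id: phoneme_string} dict, and it can be read back exactly as the module docstring describes. A minimal sketch (the dump path is illustrative, not part of the repo):

import numpy as np
from pathlib import Path

# get_phones.py writes one phonemes.npy per split, e.g. <dump_dir>/train/phonemes.npy
dump_dir = Path("dump/train")  # hypothetical location
phoneme_dict = np.load(dump_dir / "phonemes.npy", allow_pickle=True).item()
print(len(phoneme_dict), "utterances")
utt_id, phonemes = next(iter(phoneme_dict.items()))  # (utt_id, IPA phoneme string)
print(utt_id, phonemes)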
-------------------------------------------------------------------------------- /AR/exps/get_phones_librilight.py: -------------------------------------------------------------------------------- 1 | """ 2 | 1. read text of dataset, for LibriLight read txt_*.npy -> 需要整理成 list(utt_id, txt) 的形式 3 | 2. text -> IPA by GruutPhonemizer 4 | 3. save out a *.npy dict for all text 5 | 4. LibriLight 每个 split 分开处理 6 | my_dict = {"utt_id1": text1, "utt_id2": text2} 7 | np.save(output_filename, my_dict) 8 | my_dict = np.load(output_filename, allow_pickle=True).item() 9 | """ 10 | import argparse 11 | import os 12 | import time 13 | import traceback 14 | from concurrent.futures import ThreadPoolExecutor 15 | from operator import itemgetter 16 | from pathlib import Path 17 | 18 | import numpy as np 19 | import tqdm 20 | from AR.text_processing.phonemizer import GruutPhonemizer 21 | from soundstorm.utils import check_txt_file 22 | 23 | 24 | def read_txts(txt_file: Path, nprocs: int=1): 25 | ''' 26 | txt_file: path of npy dict, {"utt_id1": text1, "utt_id2": text2} 27 | ''' 28 | txt_dict = np.load(txt_file, allow_pickle=True).item() 29 | #[(utt_id, txt), ...] 30 | return_list = list(txt_dict.items()) 31 | return return_list 32 | 33 | 34 | def process_sentence(item, phonemizer, output_dir): 35 | utt_id, text = item 36 | phonemes_dir = output_dir / "phonemes" 37 | phonemes_dir.mkdir(parents=True, exist_ok=True) 38 | phonemes_path = phonemes_dir / (utt_id + ".txt") 39 | try: 40 | if os.path.exists(phonemes_path) and check_txt_file(phonemes_path): 41 | # print(phonemes_path, 'exits!') 42 | pass 43 | else: 44 | phonemes = phonemizer.phonemize(text, espeak=False) 45 | with open(phonemes_path, 'w') as f: 46 | f.write(phonemes) 47 | record = {"utt_id": utt_id, "phonemes_path": phonemes_path} 48 | except Exception: 49 | print("occur Exception") 50 | traceback.print_exc() 51 | return None 52 | return record 53 | 54 | 55 | def process_sentences(args, items, phonemizer, output_dir, nprocs: int=1): 56 | print("nprocs:", nprocs) 57 | if nprocs == 1: 58 | results = [] 59 | for item in tqdm.tqdm(items, total=len(items)): 60 | record = process_sentence( 61 | item=item, phonemizer=phonemizer, output_dir=output_dir) 62 | if record: 63 | results.append(record) 64 | else: 65 | with ThreadPoolExecutor(nprocs) as pool: 66 | futures = [] 67 | with tqdm.tqdm(total=len(items)) as progress: 68 | for item in items: 69 | future = pool.submit(process_sentence, item, phonemizer, 70 | output_dir) 71 | future.add_done_callback(lambda p: progress.update()) 72 | futures.append(future) 73 | 74 | results = [] 75 | for ft in futures: 76 | record = ft.result() 77 | if record: 78 | results.append(record) 79 | 80 | results.sort(key=itemgetter("utt_id")) 81 | 82 | npy_dict = {} 83 | print(f"start to save {args.rank}_{args.nshard}.npy ...") 84 | save_start_time = time.time() 85 | for item in tqdm.tqdm(results, total=len(results), colour='green'): 86 | # 这里加 try, 因为 txt 文件可能损坏 87 | try: 88 | utt_id = item["utt_id"] 89 | phonemes = check_txt_file(item["phonemes_path"]) 90 | if phonemes is not False: 91 | npy_dict[utt_id] = phonemes 92 | else: 93 | print(f'phonemes of {utt_id} is False') 94 | except Exception: 95 | print(f"{utt_id} occur Exception") 96 | traceback.print_exc() 97 | continue 98 | 99 | filename = output_dir / f'phonemes_{args.rank}_{args.nshard}.npy' 100 | np.save(filename, npy_dict) 101 | print(f"npy file '{filename}' write down") 102 | print('time of save stage:', time.time() - save_start_time) 103 | 104 | 105 | def main(): 106 | # 
parse config and args 107 | parser = argparse.ArgumentParser( 108 | description="Get phones for LibriLight dataset from txt_*.npy") 109 | 110 | parser.add_argument( 111 | "--dump_dir", 112 | type=str, 113 | required=True, 114 | help="directory to dump feature files.") 115 | parser.add_argument( 116 | "--num-cpu", type=int, default=1, help="number of process.") 117 | 118 | parser.add_argument( 119 | '--train_txt_dir', 120 | type=str, 121 | default='dump/small/train/', 122 | help='dir of train txt files') 123 | parser.add_argument( 124 | '--dev_txt_dir', 125 | type=str, 126 | default='dump/small/dev/', 127 | help='dir of dev txt files') 128 | parser.add_argument( 129 | '--test_txt_dir', 130 | type=str, 131 | default='dump/small/test/', 132 | help='dir of test txt files') 133 | 134 | parser.add_argument( 135 | "--sub_dataset", 136 | default="small", 137 | type=str, 138 | help="name of sub dataset of LibriLight", 139 | choices=['small', 'medium', 'large', 'duplicate'], ) 140 | parser.add_argument("--nshard", type=int, default=3) 141 | parser.add_argument("--rank", type=int, default=0) 142 | 143 | args = parser.parse_args() 144 | print(f"nshard: {args.nshard}, rank: {args.rank}") 145 | 146 | train_txt_dir = Path(args.train_txt_dir) 147 | dev_txt_dir = Path(args.dev_txt_dir) 148 | test_txt_dir = Path(args.test_txt_dir) 149 | 150 | dump_dir = Path(args.dump_dir).expanduser() 151 | # use absolute path 152 | dump_dir = dump_dir.resolve() 153 | dump_dir.mkdir(parents=True, exist_ok=True) 154 | 155 | train_txt_file = train_txt_dir / f'txt_{args.rank}_{args.nshard}.npy' 156 | dev_txt_file = dev_txt_dir / f'txt_{args.rank}_{args.nshard}.npy' 157 | test_txt_file = test_txt_dir / f'txt_{args.rank}_{args.nshard}.npy' 158 | 159 | train_txts = read_txts(train_txt_file) 160 | dev_txts = read_txts(dev_txt_file) 161 | test_txts = read_txts(test_txt_file) 162 | 163 | sub_dataset_dump_dir = dump_dir / args.sub_dataset 164 | sub_dataset_dump_dir.mkdir(parents=True, exist_ok=True) 165 | train_dump_dir = sub_dataset_dump_dir / "train" 166 | train_dump_dir.mkdir(parents=True, exist_ok=True) 167 | dev_dump_dir = sub_dataset_dump_dir / "dev" 168 | dev_dump_dir.mkdir(parents=True, exist_ok=True) 169 | test_dump_dir = sub_dataset_dump_dir / "test" 170 | test_dump_dir.mkdir(parents=True, exist_ok=True) 171 | phonemizer = GruutPhonemizer(language='en-us') 172 | 173 | # process for the 3 sections 174 | if train_txts: 175 | process_sentences( 176 | args=args, 177 | items=train_txts, 178 | output_dir=train_dump_dir, 179 | phonemizer=phonemizer, 180 | nprocs=args.num_cpu) 181 | if dev_txts: 182 | process_sentences( 183 | args=args, 184 | items=dev_txts, 185 | output_dir=dev_dump_dir, 186 | phonemizer=phonemizer, 187 | nprocs=args.num_cpu) 188 | if test_txts: 189 | process_sentences( 190 | args=args, 191 | items=test_txts, 192 | output_dir=test_dump_dir, 193 | phonemizer=phonemizer, 194 | nprocs=args.num_cpu) 195 | 196 | 197 | if __name__ == "__main__": 198 | main() 199 | -------------------------------------------------------------------------------- /AR/exps/split_train_val.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import pandas 3 | 4 | semantic_path = 'dump/semantic.tsv' 5 | phoneme_path = 'dump/phoneme.npy' 6 | train_semantic_path = 'dump/semantic_train.tsv' 7 | train_phoneme_path = 'dump/phoneme_train.npy' 8 | dev_semantic_path = 'dump/semantic_dev.tsv' 9 | dev_phoneme_path = 'dump/phoneme_dev.npy' 10 | 11 | # 读取dump/semantic.tsv 12 | semantic_df = 
pandas.read_csv(semantic_path, sep='\t') 13 | # pd.DataFrame(columns=["item_name", "semantic_audio"]) 14 | # # 读取dump/phoneme.npy 15 | phoneme_dict = numpy.load(phoneme_path, allow_pickle=True).item() 16 | 17 | dev_num = 20 18 | # 随机从semantic_df中选取dev_num个 19 | dev_df = semantic_df.sample(n=dev_num) 20 | # 剩下的是train 21 | train_df = semantic_df.drop(dev_df.index) 22 | # 保存 23 | dev_df.to_csv(dev_semantic_path, sep='\t', index=False) 24 | train_df.to_csv(train_semantic_path, sep='\t', index=False) 25 | 26 | # 将dev_df中的item_name取出来 作为dev_phoneme_dict的key 27 | dev_item_names = dev_df['item_name'].tolist() 28 | dev_phoneme_dict = {k: phoneme_dict[k] for k in dev_item_names if k in phoneme_dict} 29 | train_phoneme_dict = {k: phoneme_dict[k] for k in phoneme_dict.keys() if k not in dev_item_names} 30 | 31 | numpy.save(dev_phoneme_path, dev_phoneme_dict) 32 | numpy.save(train_phoneme_path, train_phoneme_dict) 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /AR/exps/t2s.py: -------------------------------------------------------------------------------- 1 | # text to semantic 2 | import argparse 3 | import os 4 | import re 5 | import time 6 | from pathlib import Path 7 | 8 | import librosa 9 | import numpy as np 10 | import torch 11 | import whisper 12 | from AR.models.t2s_lightning_module import Text2SemanticLightningModule 13 | from AR.text_processing.phonemizer import GruutPhonemizer 14 | from AR.utils.io import load_yaml_config 15 | 16 | 17 | def get_batch(text, phonemizer): 18 | # phoneme_ids 和 phoneme_ids_len 是需要的 19 | phoneme = phonemizer.phonemize(text, espeak=False) 20 | phoneme_ids = phonemizer.transform(phoneme) 21 | phoneme_ids_len = len(phoneme_ids) 22 | phoneme_ids = np.array(phoneme_ids) 23 | # add batch axis here 24 | phoneme_ids = torch.tensor(phoneme_ids).unsqueeze(0) 25 | phoneme_ids_len = torch.tensor([phoneme_ids_len]) 26 | print("phoneme:", phoneme) 27 | batch = { 28 | # torch.Tensor (B, max_phoneme_length) 29 | "phoneme_ids": phoneme_ids, 30 | # torch.Tensor (B) 31 | "phoneme_ids_len": phoneme_ids_len 32 | } 33 | return batch 34 | 35 | 36 | def get_prompt(prompt_wav_path, asr_model, phonemizer, semantic_tokenizer): 37 | sample_rate = 16000 38 | # to get prompt 39 | prompt_name = os.path.basename(prompt_wav_path).split('.')[0] 40 | wav, _ = librosa.load(prompt_wav_path, sr=sample_rate) 41 | # 取末尾 3s, 但是不包含最后 0.1s 防止 AR S1 infer 提前停止 42 | wav = wav[-sample_rate * 3:-int(sample_rate * 0.1)] 43 | # wav 需要挪出末尾的静音否则也可能提前停住 44 | prompt_text = asr_model.transcribe(wav)["text"] 45 | # 移除最后的句点, 防止 AR S1 infer 提前停止, 加了句点可能会有停顿 46 | prompt_text = prompt_text.replace(".", "") 47 | prompt_phoneme = phonemizer.phonemize(prompt_text, espeak=False) 48 | prompt_phoneme_ids = phonemizer.transform(prompt_phoneme) 49 | prompt_phoneme_ids_len = len(prompt_phoneme_ids) 50 | # get prompt_semantic 51 | # (T) -> (1, T) 52 | wav = torch.tensor(wav).unsqueeze(0) 53 | wav = wav.cuda() 54 | # (1, T) 55 | prompt_semantic_tokens = semantic_tokenizer.tokenize(wav).to(torch.int32) 56 | prompt_phoneme_ids = torch.tensor(prompt_phoneme_ids).unsqueeze(0) 57 | prompt_phoneme_ids_len = torch.tensor([prompt_phoneme_ids_len]) 58 | 59 | result = { 60 | 'prompt_name': prompt_name, 61 | 'prompt_phoneme_ids': prompt_phoneme_ids, 62 | 'prompt_semantic_tokens': prompt_semantic_tokens, 63 | 'prompt_phoneme_ids_len': prompt_phoneme_ids_len 64 | } 65 | 66 | return result 67 | 68 | 69 | def parse_args(): 70 | # parse args and config 71 | parser = argparse.ArgumentParser( 72 
| description="Run SoundStorm AR S1 model for input text file") 73 | 74 | parser.add_argument( 75 | '--config_file', 76 | type=str, 77 | default='conf/default.yaml', 78 | help='path of config file') 79 | 80 | parser.add_argument( 81 | "--text_file", 82 | type=str, 83 | help="text file to be convert to semantic tokens, a 'utt_id sentence' pair per line." 84 | ) 85 | 86 | parser.add_argument( 87 | '--ckpt_path', 88 | type=str, 89 | default='exp/default/ckpt/epoch=99-step=49000.ckpt', 90 | help='Checkpoint file of SoundStorm AR S1 model.') 91 | 92 | parser.add_argument( 93 | '--prompt_wav_path', 94 | type=str, 95 | default=None, 96 | help='extract prompt semantic and prompt phonemes from prompt wav') 97 | 98 | # to get semantic tokens from prompt_wav 99 | parser.add_argument("--hubert_path", type=str, default=None) 100 | parser.add_argument("--quantizer_path", type=str, default=None) 101 | 102 | parser.add_argument("--output_dir", type=str, help="output dir.") 103 | 104 | args = parser.parse_args() 105 | return args 106 | 107 | 108 | def main(): 109 | args = parse_args() 110 | config = load_yaml_config(args.config_file) 111 | 112 | output_dir = Path(args.output_dir) 113 | output_dir.mkdir(parents=True, exist_ok=True) 114 | 115 | hz = 50 116 | max_sec = config['data']['max_sec'] 117 | 118 | # get models 119 | t2s_model = Text2SemanticLightningModule.load_from_checkpoint( 120 | checkpoint_path=args.ckpt_path, config=config) 121 | t2s_model.cuda() 122 | t2s_model.eval() 123 | 124 | phonemizer: GruutPhonemizer = GruutPhonemizer(language='en-us') 125 | 126 | # models for prompt 127 | asr_model = whisper.load_model("tiny.en") 128 | 129 | semantic_tokenizer = SemanticTokenizer( 130 | hubert_path=args.hubert_path, 131 | quantizer_path=args.quantizer_path, 132 | duplicate=True) 133 | 134 | prompt_result = get_prompt( 135 | prompt_wav_path=args.prompt_wav_path, 136 | asr_model=asr_model, 137 | phonemizer=phonemizer, 138 | semantic_tokenizer=semantic_tokenizer) 139 | 140 | # zero prompt => 输出的 semantic 包含的内容是对的但是音色是乱的 141 | # (B, 1) 142 | # prompt = torch.ones( 143 | # batch['phoneme_ids'].size(0), 1, dtype=torch.int32) * 0 144 | 145 | prompt = prompt_result['prompt_semantic_tokens'] 146 | prompt_phoneme_ids_len = prompt_result['prompt_phoneme_ids_len'] 147 | prompt_phoneme_ids = prompt_result['prompt_phoneme_ids'] 148 | 149 | sentences = [] 150 | with open(args.text_file, 'rt', encoding='utf-8') as f: 151 | for line in f: 152 | if line.strip() != "": 153 | items = re.split(r"\s+", line.strip(), 1) 154 | utt_id = items[0] 155 | sentence = " ".join(items[1:]) 156 | sentences.append((utt_id, sentence)) 157 | semantic_data = [['item_name', 'semantic_audio']] 158 | for utt_id, sentence in sentences[1:]: 159 | # 需要自己构造伪 batch 输入给模型 160 | batch = get_batch(sentence, phonemizer) 161 | # prompt 和真正的输入拼接 162 | all_phoneme_ids = torch.cat( 163 | [prompt_phoneme_ids, batch['phoneme_ids']], dim=1) 164 | # 或者可以直接求 all_phoneme_ids 的 shape[-1] 165 | all_phoneme_len = prompt_phoneme_ids_len + batch['phoneme_ids_len'] 166 | st = time.time() 167 | with torch.no_grad(): 168 | pred_semantic = t2s_model.model.infer( 169 | all_phoneme_ids.cuda(), 170 | all_phoneme_len.cuda(), 171 | prompt.cuda(), 172 | top_k=config['inference']['top_k'], 173 | early_stop_num=hz * max_sec) 174 | print(f'{time.time() - st} sec used in T2S') 175 | 176 | # 删除 prompt 对应的部分 177 | prompt_len = prompt.shape[-1] 178 | pred_semantic = pred_semantic[:, prompt_len:] 179 | 180 | # bs = 1 181 | pred_semantic = pred_semantic[0] 182 | semantic_token = 
pred_semantic.detach().cpu().numpy().tolist() 183 | semantic_token_str = ' '.join(str(x) for x in semantic_token) 184 | semantic_data.append([utt_id, semantic_token_str]) 185 | 186 | delimiter = '\t' 187 | filename = output_dir / f'{utt_id}_p_{prompt_result["prompt_name"]}_semantic_token.tsv' 188 | with open(filename, 'w', encoding='utf-8') as writer: 189 | for row in semantic_data: 190 | line = delimiter.join(row) 191 | writer.write(line + '\n') 192 | # clean semantic token for next setence 193 | semantic_data = [['item_name', 'semantic_audio']] 194 | 195 | 196 | if __name__ == "__main__": 197 | main() 198 | -------------------------------------------------------------------------------- /AR/exps/test.py: -------------------------------------------------------------------------------- 1 | # test from dump file 2 | import argparse 3 | import time 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | from AR.data.dataset import Text2SemanticDataset 9 | from AR.models.t2s_lightning_module import Text2SemanticLightningModule 10 | from AR.utils.io import load_yaml_config 11 | from torch.utils.data import DataLoader 12 | 13 | 14 | def parse_args(): 15 | # parse args and config 16 | parser = argparse.ArgumentParser( 17 | description="Run SoundStorm AR S1 model for test set.") 18 | 19 | parser.add_argument( 20 | '--config_file', 21 | type=str, 22 | default='conf/default.yaml', 23 | help='path of config file') 24 | 25 | # args for dataset 26 | parser.add_argument( 27 | '--test_semantic_path', 28 | type=str, 29 | default='dump/test/semantic_token.tsv') 30 | parser.add_argument( 31 | '--test_phoneme_path', type=str, default='dump/test/phonemes.npy') 32 | 33 | parser.add_argument( 34 | '--ckpt_path', 35 | type=str, 36 | default='exp/default/ckpt/epoch=99-step=49000.ckpt', 37 | help='Checkpoint file of SoundStorm AR S1 model.') 38 | 39 | parser.add_argument("--output_dir", type=str, help="output dir.") 40 | 41 | args = parser.parse_args() 42 | return args 43 | 44 | 45 | def main(): 46 | args = parse_args() 47 | 48 | config = load_yaml_config(args.config_file) 49 | 50 | output_dir = Path(args.output_dir) 51 | output_dir.mkdir(parents=True, exist_ok=True) 52 | 53 | batch_size = 1 54 | hz = 50 55 | max_sec = config['data']['max_sec'] 56 | 57 | # get dataset 58 | test_dataset = Text2SemanticDataset( 59 | phoneme_path=args.test_phoneme_path, 60 | semantic_path=args.test_semantic_path, 61 | # max_sec 需要与训练时保持一致,不然可能会效果不好,重复漏字等 62 | # 但是这里设置太短又会直接过滤掉太长的样本,为了防止被过滤掉,可以在 infer 的时候截断 63 | max_sec=100, 64 | max_sample=8, 65 | pad_val=config['data']['pad_val']) 66 | # get model 67 | t2s_model = Text2SemanticLightningModule.load_from_checkpoint( 68 | checkpoint_path=args.ckpt_path, config=config) 69 | t2s_model.cuda() 70 | t2s_model.eval() 71 | 72 | # 获取 batch_size 条 73 | # 创建 DataLoader,并指定 collate_fn 函数 74 | dataloader = DataLoader( 75 | test_dataset, 76 | batch_size=batch_size, 77 | shuffle=False, 78 | collate_fn=test_dataset.collate) 79 | 80 | item_names = test_dataset.__get_item_names__() 81 | 82 | # 逐批次读取数据, bs=1、shuffle=False 时可以用 __get_item_names__ 对应 83 | semantic_data = [['item_name', 'semantic_audio']] 84 | for i, batch in enumerate(dataloader): 85 | # 要保证 bs = 1 86 | utt_id = item_names[i] 87 | if i == 0: 88 | print("utt_id:", utt_id) 89 | # bs > 1 时会补零 90 | # 与 validation_step() 保持一致 91 | semantic_len = batch['semantic_ids'].size(1) 92 | # 以 batch['semantic_ids'] 的前 150 个为 prompt 93 | # 多次合成,前 prompt_len 个是一样的,而且和 prompt 一样 94 | prompt_len = min(int(semantic_len * 0.5), 150) 95 | 
# 输入纯文本时 prompt 该输入什么?=> see t2s.py 96 | prompt = batch['semantic_ids'][:, :prompt_len] 97 | # # zero prompt => 也可以输出文本内容正确的 semantic token, 但是音色是乱的 98 | # 证明 semantic token 中还是包含了音色信息 99 | # prompt = torch.ones( 100 | # batch['semantic_ids'].size(0), 1, dtype=torch.int32) * 0 101 | # print("prompt:", prompt) 102 | # print("prompt.shape:", prompt.shape) 103 | np.save(output_dir / 'prompt.npy', prompt.detach().cpu().numpy()) 104 | 105 | st = time.time() 106 | with torch.no_grad(): 107 | # calculate acc for test 108 | loss, acc = t2s_model.model.forward( 109 | batch['phoneme_ids'].cuda(), 110 | batch['phoneme_ids_len'].cuda(), 111 | batch['semantic_ids'].cuda(), 112 | batch['semantic_ids_len'].cuda()) 113 | print("top_3_acc of this batch:", acc) 114 | pred_semantic = t2s_model.model.infer( 115 | batch['phoneme_ids'].cuda(), 116 | batch['phoneme_ids_len'].cuda(), 117 | prompt.cuda(), 118 | top_k=config['inference']['top_k'], 119 | # hz * max_sec in train dataloader 120 | # 生成的长度是 1002 应该是有一些 pad 121 | early_stop_num=hz * max_sec) 122 | # bs = 1 123 | pred_semantic = pred_semantic[0] 124 | print(f'{time.time() - st} sec used in T2S') 125 | semantic_token = pred_semantic.detach().cpu().numpy().tolist() 126 | semantic_token_str = ' '.join(str(x) for x in semantic_token) 127 | semantic_data.append([utt_id, semantic_token_str]) 128 | else: 129 | break 130 | delimiter = '\t' 131 | filename = output_dir / "semantic_token.tsv" 132 | with open(filename, 'w', encoding='utf-8') as writer: 133 | for row in semantic_data: 134 | line = delimiter.join(row) 135 | writer.write(line + '\n') 136 | 137 | 138 | if __name__ == "__main__": 139 | main() 140 | -------------------------------------------------------------------------------- /AR/exps/text.txt: -------------------------------------------------------------------------------- 1 | 001 Life was like a box of chocolates, you never know what you're gonna get. 2 | 002 With great power there must come great responsibility. 3 | 003 To be or not to be, that’s a question. 4 | 004 A man can be destroyed but not defeated 5 | 005 Do not, for one repulse, give up the purpose that you resolved to effort. 6 | 006 Death is just a part of life, something we're all destined to do. 7 | 007 I think it's hard winning a war with words. 8 | 008 Don’t argue with the people of strong determination, because they may change the fact! 9 | 009 Love you three thousand times. 10 | 010 tidy tiger tied a tie tighter to tidy her tiny tall. 
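The text.txt listed above is the input format that AR/exps/t2s.py expects for --text_file: one "utt_id sentence" pair per line, split on the first run of whitespace. A minimal sketch of that parsing step (the sample line is taken from the file above):

import re

line = "001 Life was like a box of chocolates, you never know what you're gonna get."
utt_id, sentence = re.split(r"\s+", line.strip(), maxsplit=1)
# utt_id == "001"; t2s.py feeds `sentence` to the phonemizer and later writes a
# per-utterance '<utt_id>_p_<prompt_name>_semantic_token.tsv' file whose rows are
# ['item_name', 'semantic_audio']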
-------------------------------------------------------------------------------- /AR/exps/train.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py 2 | import argparse 3 | import logging 4 | import os 5 | from pathlib import Path 6 | 7 | import torch 8 | from pytorch_lightning import seed_everything 9 | from pytorch_lightning import Trainer 10 | from pytorch_lightning.callbacks import ModelCheckpoint 11 | from pytorch_lightning.loggers import WandbLogger 12 | from pytorch_lightning.strategies import DDPStrategy 13 | from AR.data.data_module import Text2SemanticDataModule 14 | from AR.models.t2s_lightning_module import Text2SemanticLightningModule 15 | from soundstorm.utils.io import load_yaml_config 16 | logging.getLogger('numba').setLevel(logging.WARNING) 17 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 18 | torch.set_float32_matmul_precision('high') 19 | from soundstorm.utils import get_newest_ckpt 20 | 21 | 22 | def main(args): 23 | output_dir = Path(args.output_dir) 24 | output_dir.mkdir(parents=True, exist_ok=True) 25 | 26 | ckpt_dir = output_dir / 'ckpt' 27 | ckpt_dir.mkdir(parents=True, exist_ok=True) 28 | 29 | config = load_yaml_config(args.config_file) 30 | 31 | seed_everything(config["train"]["seed"], workers=True) 32 | ckpt_callback: ModelCheckpoint = ModelCheckpoint( 33 | save_top_k=-1, 34 | save_on_train_epoch_end=False, 35 | every_n_epochs=config["train"]["save_every_n_epoch"], 36 | dirpath=ckpt_dir) 37 | logger = WandbLogger( 38 | project="AR_S1", 39 | name=output_dir.stem, 40 | save_dir=output_dir, 41 | # resume the loss curve 42 | resume=True, 43 | # id='k19kvsq8' 44 | ) 45 | trainer: Trainer = Trainer( 46 | max_epochs=config["train"]["epochs"], 47 | accelerator='gpu', 48 | devices=-1, 49 | benchmark=False, 50 | fast_dev_run=False, 51 | strategy=DDPStrategy(find_unused_parameters=True), 52 | precision=config["train"]["precision"], 53 | logger=logger, 54 | callbacks=[ckpt_callback]) 55 | 56 | model: Text2SemanticLightningModule = Text2SemanticLightningModule( 57 | config, output_dir) 58 | 59 | data_module: Text2SemanticDataModule = Text2SemanticDataModule( 60 | config, 61 | train_semantic_path=args.train_semantic_path, 62 | train_phoneme_path=args.train_phoneme_path, 63 | dev_semantic_path=args.dev_semantic_path, 64 | dev_phoneme_path=args.dev_phoneme_path) 65 | 66 | try: 67 | # 使用正则表达式匹配文件名中的数字部分,并按数字大小进行排序 68 | newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir)) 69 | ckpt_path = ckpt_dir / newest_ckpt_name 70 | except Exception: 71 | ckpt_path = None 72 | print("ckpt_path:", ckpt_path) 73 | trainer.fit(model, data_module, ckpt_path=ckpt_path) 74 | 75 | 76 | # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml 77 | if __name__ == '__main__': 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument( 80 | '--config_file', 81 | type=str, 82 | default='conf/default.yaml', 83 | help='path of config file') 84 | # args for dataset 85 | parser.add_argument( 86 | '--train_semantic_path', 87 | type=str, 88 | default='dump/train/semantic_token.tsv') 89 | parser.add_argument( 90 | '--train_phoneme_path', type=str, default='dump/train/phonemes.npy') 91 | parser.add_argument( 92 | '--dev_semantic_path', type=str, default='dump/dev/semantic_token.tsv') 93 | parser.add_argument( 94 | '--dev_phoneme_path', type=str, default='dump/dev/phonemes.npy') 95 | parser.add_argument( 96 | 
'--output_dir', 97 | type=str, 98 | default='exp/default', 99 | help='directory to save the results') 100 | 101 | args = parser.parse_args() 102 | logging.info(str(args)) 103 | main(args) 104 | -------------------------------------------------------------------------------- /AR/exps/train_librilight_6k.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py 2 | import argparse 3 | import logging 4 | import os 5 | from pathlib import Path 6 | 7 | import torch 8 | from pytorch_lightning import seed_everything 9 | from pytorch_lightning import Trainer 10 | from pytorch_lightning.callbacks import ModelCheckpoint 11 | from pytorch_lightning.loggers import WandbLogger 12 | from pytorch_lightning.strategies import DDPStrategy 13 | from AR.data.data_module_librilight_6k import Text2SemanticDataModule 14 | from AR.models.t2s_lightning_module import Text2SemanticLightningModule 15 | from soundstorm.utils import get_newest_ckpt 16 | from soundstorm.utils.io import load_yaml_config 17 | 18 | logging.getLogger('numba').setLevel(logging.WARNING) 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | torch.set_float32_matmul_precision('high') 21 | 22 | 23 | def main(args): 24 | output_dir = Path(args.output_dir) 25 | output_dir.mkdir(parents=True, exist_ok=True) 26 | 27 | ckpt_dir = output_dir / 'ckpt' 28 | ckpt_dir.mkdir(parents=True, exist_ok=True) 29 | 30 | config = load_yaml_config(args.config_file) 31 | 32 | seed_everything(config["train"]["seed"], workers=True) 33 | 34 | ckpt_callback: ModelCheckpoint = ModelCheckpoint( 35 | save_top_k=-1, 36 | save_on_train_epoch_end=False, 37 | every_n_train_steps=config["train"]["every_n_train_steps"], 38 | dirpath=ckpt_dir) 39 | logger = WandbLogger( 40 | project="AR_S1_LibriLight", 41 | name=output_dir.stem, 42 | save_dir=output_dir, 43 | # resume the loss curve 44 | resume=True, 45 | # id='k19kvsq8' 46 | ) 47 | trainer: Trainer = Trainer( 48 | max_epochs=config["train"]["epochs"], 49 | accelerator='gpu', 50 | devices=-1, 51 | benchmark=False, 52 | fast_dev_run=False, 53 | strategy=DDPStrategy(find_unused_parameters=True), 54 | precision=config["train"]["precision"], 55 | logger=logger, 56 | callbacks=[ckpt_callback]) 57 | 58 | model: Text2SemanticLightningModule = Text2SemanticLightningModule( 59 | config, output_dir) 60 | 61 | data_module: Text2SemanticDataModule = Text2SemanticDataModule( 62 | config, 63 | train_semantic_dirs=args.train_semantic_dirs, 64 | train_phoneme_dirs=args.train_phoneme_dirs, 65 | dev_semantic_dirs=args.dev_semantic_dirs, 66 | dev_phoneme_dirs=args.dev_phoneme_dirs, 67 | train_non_speech_dirs=args.train_non_speech_dirs, 68 | dev_non_speech_dirs=args.dev_non_speech_dirs) 69 | try: 70 | newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir)) 71 | ckpt_path = ckpt_dir / newest_ckpt_name 72 | except Exception: 73 | ckpt_path = None 74 | 75 | print("ckpt_path:", ckpt_path) 76 | trainer.fit(model, data_module, ckpt_path=ckpt_path) 77 | 78 | 79 | # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml 80 | if __name__ == '__main__': 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument( 83 | '--config_file', 84 | type=str, 85 | default='conf/default.yaml', 86 | help='path of config file') 87 | # args for dataset 88 | parser.add_argument( 89 | '--train_semantic_dirs', 90 | type=list, 91 | nargs='+', 92 | default=["dump/small/train/"], 93 | 
help='dirs of train semantic') 94 | parser.add_argument( 95 | '--train_phoneme_dirs', 96 | type=list, 97 | nargs='+', 98 | default=["dump/small/train/"], 99 | help='dirs of train phoneme') 100 | parser.add_argument( 101 | '--dev_semantic_dirs', 102 | type=list, 103 | nargs='+', 104 | default=["dump/small/dev/"], 105 | help='dirs of dev semantic') 106 | parser.add_argument( 107 | '--dev_phoneme_dirs', 108 | type=list, 109 | nargs='+', 110 | default=["dump/small/dev/"], 111 | help='dirs of dev phoneme') 112 | parser.add_argument( 113 | '--output_dir', 114 | type=str, 115 | default='exp/default', 116 | help='directory to save the results') 117 | 118 | parser.add_argument( 119 | '--train_non_speech_dirs', 120 | type=list, 121 | nargs='+', 122 | default=None, 123 | help='dirs of train non_speech data') 124 | 125 | parser.add_argument( 126 | '--dev_non_speech_dirs', 127 | type=list, 128 | nargs='+', 129 | default=None, 130 | help='dirs of dev non_speech data') 131 | 132 | args = parser.parse_args() 133 | 134 | new_train_semantic_dirs = [] 135 | new_train_phoneme_dirs = [] 136 | new_dev_semantic_dirs = [] 137 | new_dev_phoneme_dirs = [] 138 | 139 | new_train_non_speech_dirs = [] 140 | new_dev_non_speech_dirs = [] 141 | 142 | # format dataset dirs 143 | for item in args.train_semantic_dirs: 144 | new_train_semantic_dirs.append(''.join(item)) 145 | args.train_semantic_dirs = new_train_semantic_dirs 146 | 147 | for item in args.train_phoneme_dirs: 148 | new_train_phoneme_dirs.append(''.join(item)) 149 | args.train_phoneme_dirs = new_train_phoneme_dirs 150 | 151 | for item in args.dev_semantic_dirs: 152 | new_dev_semantic_dirs.append(''.join(item)) 153 | args.dev_semantic_dirs = new_dev_semantic_dirs 154 | 155 | for item in args.dev_phoneme_dirs: 156 | new_dev_phoneme_dirs.append(''.join(item)) 157 | args.dev_phoneme_dirs = new_dev_phoneme_dirs 158 | 159 | if args.train_non_speech_dirs is not None: 160 | for item in args.train_non_speech_dirs: 161 | new_train_non_speech_dirs.append(''.join(item)) 162 | args.train_non_speech_dirs = new_train_non_speech_dirs 163 | 164 | if args.dev_non_speech_dirs is not None: 165 | for item in args.dev_non_speech_dirs: 166 | new_dev_non_speech_dirs.append(''.join(item)) 167 | args.dev_non_speech_dirs = new_dev_non_speech_dirs 168 | 169 | logging.info(str(args)) 170 | main(args) 171 | -------------------------------------------------------------------------------- /AR/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innnky/ar-vits/aeca3848c6d34946b71348c9f59e7e008cd838ce/AR/models/__init__.py -------------------------------------------------------------------------------- /AR/models/t2s_lightning_module.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py 2 | import os 3 | from typing import Dict 4 | 5 | import torch 6 | from pytorch_lightning import LightningModule 7 | from AR.models.t2s_model import Text2SemanticDecoder 8 | from AR.modules.lr_schedulers import WarmupCosineLRSchedule 9 | from AR.modules.optim import ScaledAdam 10 | 11 | 12 | class Text2SemanticLightningModule(LightningModule): 13 | def __init__(self, config, output_dir): 14 | super().__init__() 15 | self.config = config 16 | self.top_k = 3 17 | self.model = Text2SemanticDecoder(config=config, top_k=self.top_k) 18 | self.automatic_optimization = False 19 | 
self.save_hyperparameters() 20 | self.eval_dir = output_dir / 'eval' 21 | self.eval_dir.mkdir(parents=True, exist_ok=True) 22 | 23 | def training_step(self, batch: Dict, batch_idx: int): 24 | 25 | opt = self.optimizers() 26 | scheduler = self.lr_schedulers() 27 | loss, acc = self.model.forward( 28 | batch['phoneme_ids'], batch['phoneme_ids_len'], 29 | batch['semantic_ids'], batch['semantic_ids_len']) 30 | self.manual_backward(loss) 31 | 32 | if batch_idx > 0 and batch_idx % 4 == 0: 33 | opt.step() 34 | opt.zero_grad() 35 | scheduler.step() 36 | 37 | self.log( 38 | "total_loss", 39 | loss, 40 | on_step=True, 41 | on_epoch=True, 42 | prog_bar=True, 43 | sync_dist=True) 44 | self.log( 45 | "lr", 46 | scheduler.get_last_lr()[0], 47 | on_epoch=True, 48 | prog_bar=True, 49 | sync_dist=True) 50 | self.log( 51 | f"top_{self.top_k}_acc", 52 | acc, 53 | on_step=True, 54 | on_epoch=True, 55 | prog_bar=True, 56 | sync_dist=True) 57 | 58 | def validation_step(self, batch: Dict, batch_idx: int): 59 | # get loss 60 | loss, acc = self.model.forward( 61 | batch['phoneme_ids'], batch['phoneme_ids_len'], 62 | batch['semantic_ids'], batch['semantic_ids_len']) 63 | 64 | self.log( 65 | "val_total_loss", 66 | loss, 67 | on_step=True, 68 | on_epoch=True, 69 | prog_bar=True, 70 | sync_dist=True) 71 | self.log( 72 | f"val_top_{self.top_k}_acc", 73 | acc, 74 | on_step=True, 75 | on_epoch=True, 76 | prog_bar=True, 77 | sync_dist=True) 78 | 79 | # get infer output 80 | semantic_len = batch['semantic_ids'].size(1) 81 | prompt_len = min(int(semantic_len * 0.5), 150) 82 | prompt = batch['semantic_ids'][:, :prompt_len] 83 | pred_semantic = self.model.infer(batch['phoneme_ids'], 84 | batch['phoneme_ids_len'], prompt) 85 | save_name = f'semantic_toks_{batch_idx}.pt' 86 | save_path = os.path.join(self.eval_dir, save_name) 87 | torch.save(pred_semantic.detach().cpu(), save_path) 88 | 89 | def configure_optimizers(self): 90 | model_parameters = self.model.parameters() 91 | parameters_names = [] 92 | parameters_names.append([ 93 | name_param_pair[0] 94 | for name_param_pair in self.model.named_parameters() 95 | ]) 96 | lm_opt = ScaledAdam( 97 | model_parameters, 98 | lr=0.01, 99 | betas=(0.9, 0.95), 100 | clipping_scale=2.0, 101 | parameters_names=parameters_names, 102 | show_dominant_parameters=False, 103 | clipping_update_period=1000, ) 104 | 105 | return { 106 | "optimizer": lm_opt, 107 | "lr_scheduler": { 108 | "scheduler": 109 | WarmupCosineLRSchedule( 110 | lm_opt, 111 | init_lr=self.config['optimizer']['lr_init'], 112 | peak_lr=self.config['optimizer']['lr'], 113 | end_lr=self.config['optimizer']['lr_end'], 114 | warmup_steps=self.config['optimizer']['warmup_steps'], 115 | total_steps=self.config['optimizer']['decay_steps']) 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /AR/models/t2s_model.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py 2 | import torch 3 | from tqdm import tqdm 4 | 5 | from AR.models.utils import make_pad_mask 6 | from AR.models.utils import topk_sampling 7 | from AR.modules.embedding import SinePositionalEmbedding 8 | from AR.modules.embedding import TokenEmbedding 9 | from AR.modules.transformer import LayerNorm 10 | from AR.modules.transformer import TransformerEncoder 11 | from AR.modules.transformer import TransformerEncoderLayer 12 | from torch import nn 13 | from torch.nn import functional as F 14 | from 
torchmetrics.classification import MulticlassAccuracy 15 | 16 | default_config = { 17 | "embedding_dim": 512, 18 | "hidden_dim": 512, 19 | "num_head": 8, 20 | "num_layers": 12, 21 | "num_codebook": 8, 22 | "p_dropout": 0.0, 23 | "vocab_size": 1024 + 1, 24 | "phoneme_vocab_size": 512, 25 | "EOS": 1024 26 | } 27 | 28 | 29 | class Text2SemanticDecoder(nn.Module): 30 | def __init__(self, config, norm_first=False, top_k=3): 31 | super(Text2SemanticDecoder, self).__init__() 32 | self.model_dim = config['model']["hidden_dim"] 33 | self.embedding_dim = config['model']["embedding_dim"] 34 | self.num_head = config['model']["head"] 35 | self.num_layers = config['model']["n_layer"] 36 | self.norm_first = norm_first 37 | self.vocab_size = config['model']["vocab_size"] 38 | self.phoneme_vocab_size = config['model']["phoneme_vocab_size"] 39 | self.p_dropout = config['model']["dropout"] 40 | self.EOS = config['model']["EOS"] 41 | self.norm_first = norm_first 42 | assert self.EOS == self.vocab_size - 1 43 | # should be same as num of kmeans bin 44 | # assert self.EOS == 1024 45 | 46 | self.ar_text_embedding = TokenEmbedding( 47 | self.embedding_dim, self.phoneme_vocab_size, self.p_dropout) 48 | self.ar_text_position = SinePositionalEmbedding( 49 | self.embedding_dim, dropout=0.1, scale=False, alpha=True) 50 | self.ar_audio_embedding = TokenEmbedding( 51 | self.embedding_dim, self.vocab_size, self.p_dropout) 52 | self.ar_audio_position = SinePositionalEmbedding( 53 | self.embedding_dim, dropout=0.1, scale=False, alpha=True) 54 | 55 | self.h = TransformerEncoder( 56 | TransformerEncoderLayer( 57 | d_model=self.model_dim, 58 | nhead=self.num_head, 59 | dim_feedforward=self.model_dim * 4, 60 | dropout=0.1, 61 | batch_first=True, 62 | norm_first=norm_first, ), 63 | num_layers=self.num_layers, 64 | norm=LayerNorm(self.model_dim) if norm_first else None, ) 65 | 66 | self.ar_predict_layer = nn.Linear( 67 | self.model_dim, self.vocab_size, bias=False) 68 | self.loss_fct = nn.CrossEntropyLoss(reduction='sum') 69 | 70 | self.ar_accuracy_metric = MulticlassAccuracy( 71 | self.vocab_size, 72 | top_k=top_k, 73 | average="micro", 74 | multidim_average="global", 75 | ignore_index=self.EOS, ) 76 | 77 | def forward(self, x, x_lens, y, y_lens): 78 | ''' 79 | x: phoneme_ids 80 | y: semantic_ids 81 | ''' 82 | x = self.ar_text_embedding(x) 83 | x = self.ar_text_position(x) 84 | x_mask = make_pad_mask(x_lens) 85 | 86 | y_mask = make_pad_mask(y_lens) 87 | y_mask_int = y_mask.type(torch.int64) 88 | codes = y.type(torch.int64) * (1 - y_mask_int) 89 | 90 | # Training 91 | # AR Decoder 92 | y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS) 93 | x_len = x_lens.max() 94 | y_len = y_lens.max() 95 | y_emb = self.ar_audio_embedding(y) 96 | y_pos = self.ar_audio_position(y_emb) 97 | 98 | xy_padding_mask = torch.concat([x_mask, y_mask], dim=1) 99 | ar_xy_padding_mask = xy_padding_mask 100 | 101 | x_attn_mask = F.pad( 102 | torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device), 103 | (0, y_len), 104 | value=True, ) 105 | y_attn_mask = F.pad( 106 | torch.triu( 107 | torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), 108 | diagonal=1, ), 109 | (x_len, 0), 110 | value=False, ) 111 | xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) 112 | bsz, src_len = x.shape[0], x_len + y_len 113 | _xy_padding_mask = (ar_xy_padding_mask.view(bsz, 1, 1, src_len) 114 | .expand(-1, self.num_head, -1, -1) 115 | .reshape(bsz * self.num_head, 1, src_len)) 116 | xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) 
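# xy_attn_mask now marks with True every position that must not be attended to
# (padding tokens and future semantic tokens); the following lines convert it into
# the additive 0 / -inf float mask expected by the transformer encoder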
117 | new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) 118 | new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) 119 | xy_attn_mask = new_attn_mask 120 | # x 和完整的 y 一次性输入模型 121 | xy_pos = torch.concat([x, y_pos], dim=1) 122 | xy_dec, _ = self.h( 123 | (xy_pos, None), 124 | mask=xy_attn_mask, ) 125 | logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1) 126 | # loss 127 | # from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum 128 | loss = F.cross_entropy(logits, targets, reduction='sum') 129 | acc = self.ar_accuracy_metric(logits.detach(), targets).item() 130 | return loss, acc 131 | 132 | # 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么 133 | def infer(self, 134 | x, 135 | x_lens, 136 | prompts, 137 | top_k: int=-100, 138 | early_stop_num: int=-1, 139 | temperature: float=1.0): 140 | 141 | x = self.ar_text_embedding(x) 142 | x = self.ar_text_position(x) 143 | 144 | # AR Decoder 145 | y = prompts 146 | prefix_len = y.shape[1] 147 | x_len = x.shape[1] 148 | x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) 149 | stop = False 150 | for _ in tqdm(range(1500)): 151 | y_emb = self.ar_audio_embedding(y) 152 | y_pos = self.ar_audio_position(y_emb) 153 | # x 和逐渐增长的 y 一起输入给模型 154 | xy_pos = torch.concat([x, y_pos], dim=1) 155 | y_len = y.shape[1] 156 | x_attn_mask_pad = F.pad( 157 | x_attn_mask, 158 | (0, y_len), 159 | value=True, ) 160 | y_attn_mask = F.pad( 161 | torch.triu( 162 | torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), 163 | (x_len, 0), 164 | value=False, ) 165 | xy_attn_mask = torch.concat( 166 | [x_attn_mask_pad, y_attn_mask], dim=0).to(y.device) 167 | 168 | xy_dec, _ = self.h( 169 | (xy_pos, None), 170 | mask=xy_attn_mask, ) 171 | logits = self.ar_predict_layer(xy_dec[:, -1]) 172 | samples = topk_sampling( 173 | logits, top_k=top_k, top_p=1.0, temperature=temperature) 174 | 175 | if early_stop_num != -1 and (y.shape[1] - prefix_len 176 | ) > early_stop_num: 177 | print("use early stop num:", early_stop_num) 178 | stop = True 179 | 180 | if torch.argmax( 181 | logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS: 182 | # print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS) 183 | stop = True 184 | if stop: 185 | if prompts.shape[1] == y.shape[1]: 186 | y = torch.concat([y, torch.zeros_like(samples)], dim=1) 187 | print('bad zero prediction') 188 | print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]") 189 | break 190 | # 本次生成的 semantic_ids 和之前的 y 构成新的 y 191 | y = torch.concat([y, samples], dim=1) 192 | return y 193 | 194 | def pad_y_eos(self, y, y_mask_int, eos_id): 195 | targets = F.pad( 196 | y, (0, 1), value=0) + eos_id * F.pad( 197 | y_mask_int, (0, 1), value=1) 198 | # 错位 199 | return targets[:, :-1], targets[:, 1:] 200 | -------------------------------------------------------------------------------- /AR/models/utils.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/utils.py\ 2 | import torch 3 | import torch.nn.functional as F 4 | import torchaudio 5 | 6 | 7 | def sequence_mask(length, max_length=None): 8 | if max_length is None: 9 | max_length = length.max() 10 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 11 | return x.unsqueeze(0) < length.unsqueeze(1) 12 | 13 | 14 | def make_pad_mask(lengths: torch.Tensor, max_len: int=0) -> torch.Tensor: 15 | """ 16 | Args: 17 | lengths: 18 | A 1-D tensor containing sentence lengths. 
19 | max_len: 20 | The length of masks. 21 | Returns: 22 | Return a 2-D bool tensor, where masked positions 23 | are filled with `True` and non-masked positions are 24 | filled with `False`. 25 | 26 | #>>> lengths = torch.tensor([1, 3, 2, 5]) 27 | #>>> make_pad_mask(lengths) 28 | tensor([[False, True, True, True, True], 29 | [False, False, False, True, True], 30 | [False, False, True, True, True], 31 | [False, False, False, False, False]]) 32 | """ 33 | assert lengths.ndim == 1, lengths.ndim 34 | max_len = max(max_len, lengths.max()) 35 | n = lengths.size(0) 36 | seq_range = torch.arange(0, max_len, device=lengths.device) 37 | expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) 38 | 39 | return expaned_lengths >= lengths.unsqueeze(-1) 40 | 41 | 42 | # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py 43 | def top_k_top_p_filtering(logits, 44 | top_k=0, 45 | top_p=1.0, 46 | filter_value=-float("Inf"), 47 | min_tokens_to_keep=1): 48 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 49 | Args: 50 | logits: logits distribution shape (batch size, vocabulary size) 51 | if top_k > 0: keep only top k tokens with highest probability (top-k filtering). 52 | if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 53 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 54 | Make sure we keep at least min_tokens_to_keep per batch example in the output 55 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 56 | """ 57 | if top_k > 0: 58 | top_k = min(max(top_k, min_tokens_to_keep), 59 | logits.size(-1)) # Safety check 60 | # Remove all tokens with a probability less than the last token of the top-k 61 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 62 | logits[indices_to_remove] = filter_value 63 | 64 | if top_p < 1.0: 65 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 66 | cumulative_probs = torch.cumsum( 67 | F.softmax(sorted_logits, dim=-1), dim=-1) 68 | 69 | # Remove tokens with cumulative probability above the threshold (token with 0 are kept) 70 | sorted_indices_to_remove = cumulative_probs > top_p 71 | if min_tokens_to_keep > 1: 72 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) 73 | sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 74 | # Shift the indices to the right to keep also the first token above the threshold 75 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ 76 | ..., :-1].clone() 77 | sorted_indices_to_remove[..., 0] = 0 78 | 79 | # scatter sorted tensors to original indexing 80 | indices_to_remove = sorted_indices_to_remove.scatter( 81 | 1, sorted_indices, sorted_indices_to_remove) 82 | logits[indices_to_remove] = filter_value 83 | return logits 84 | 85 | 86 | def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): 87 | # temperature: (`optional`) float 88 | # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. 89 | # top_k: (`optional`) int 90 | # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. 91 | # top_p: (`optional`) float 92 | # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. 
93 | 94 | # Temperature (higher temperature => more likely to sample low probability tokens) 95 | if temperature != 1.0: 96 | logits = logits / temperature 97 | # Top-p/top-k filtering 98 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) 99 | # Sample 100 | token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) 101 | return token 102 | -------------------------------------------------------------------------------- /AR/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innnky/ar-vits/aeca3848c6d34946b71348c9f59e7e008cd838ce/AR/modules/__init__.py -------------------------------------------------------------------------------- /AR/modules/embedding.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py 2 | import math 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class TokenEmbedding(nn.Module): 9 | def __init__( 10 | self, 11 | embedding_dim: int, 12 | vocab_size: int, 13 | dropout: float=0.0, ): 14 | super().__init__() 15 | 16 | self.vocab_size = vocab_size 17 | self.embedding_dim = embedding_dim 18 | 19 | self.dropout = torch.nn.Dropout(p=dropout) 20 | self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim) 21 | 22 | @property 23 | def weight(self) -> torch.Tensor: 24 | return self.word_embeddings.weight 25 | 26 | def embedding(self, index: int) -> torch.Tensor: 27 | return self.word_embeddings.weight[index:index + 1] 28 | 29 | def forward(self, x: torch.Tensor): 30 | x = self.word_embeddings(x) 31 | x = self.dropout(x) 32 | return x 33 | 34 | 35 | class SinePositionalEmbedding(nn.Module): 36 | def __init__( 37 | self, 38 | embedding_dim: int, 39 | dropout: float=0.0, 40 | scale: bool=False, 41 | alpha: bool=False, ): 42 | super().__init__() 43 | self.embedding_dim = embedding_dim 44 | self.x_scale = math.sqrt(embedding_dim) if scale else 1.0 45 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) 46 | self.dropout = torch.nn.Dropout(p=dropout) 47 | 48 | self.reverse = False 49 | self.pe = None 50 | self.extend_pe(torch.tensor(0.0).expand(1, 4000)) 51 | 52 | def extend_pe(self, x): 53 | """Reset the positional encodings.""" 54 | if self.pe is not None: 55 | if self.pe.size(1) >= x.size(1): 56 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 57 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 58 | return 59 | pe = torch.zeros(x.size(1), self.embedding_dim) 60 | if self.reverse: 61 | position = torch.arange( 62 | x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1) 63 | else: 64 | position = torch.arange( 65 | 0, x.size(1), dtype=torch.float32).unsqueeze(1) 66 | div_term = torch.exp( 67 | torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * 68 | -(math.log(10000.0) / self.embedding_dim)) 69 | pe[:, 0::2] = torch.sin(position * div_term) 70 | pe[:, 1::2] = torch.cos(position * div_term) 71 | pe = pe.unsqueeze(0) 72 | self.pe = pe.to(device=x.device, dtype=x.dtype).detach() 73 | 74 | def forward(self, x: torch.Tensor) -> torch.Tensor: 75 | self.extend_pe(x) 76 | output = x.unsqueeze(-1) if x.ndim == 2 else x 77 | output = output * self.x_scale + self.alpha * self.pe[:, :x.size(1)] 78 | return self.dropout(output) 79 | -------------------------------------------------------------------------------- /AR/modules/lr_schedulers.py: 
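# Note on the schedule defined in this file (added for clarity): the learning
# rate climbs linearly from init_lr to peak_lr over the first warmup_steps
# optimizer steps, then follows a half-cosine from peak_lr down to end_lr at
# total_steps, and stays at end_lr afterwards.  For step t:
#
#   t <  warmup_steps:                 lr = init_lr + (peak_lr - init_lr) * t / warmup_steps
#   warmup_steps <= t <= total_steps:  lr = end_lr + 0.5 * (1 + cos(pi * r)) * (peak_lr - end_lr),
#                                      with r = (t - warmup_steps) / (total_steps - warmup_steps)
#   t >  total_steps:                  lr = end_lr
#
# Presumably the optimizer block of configs/s1.yaml (lr_init / lr / lr_end,
# warmup_steps 2000, decay_steps 40000) feeds these parameters, i.e. a
# 2000-step warmup followed by a 40000-step cosine decay; that mapping is an
# assumption here, not something verified against the trainer code.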
-------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/lr_schedulers.py 2 | import math 3 | 4 | import torch 5 | from matplotlib import pyplot as plt 6 | from torch import nn 7 | from torch.optim import Adam 8 | 9 | 10 | class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler): 11 | """ 12 | Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers. 13 | """ 14 | 15 | def __init__(self, 16 | optimizer, 17 | init_lr, 18 | peak_lr, 19 | end_lr, 20 | warmup_steps=10000, 21 | total_steps=400000, 22 | current_step=0): 23 | self.init_lr = init_lr 24 | self.peak_lr = peak_lr 25 | self.end_lr = end_lr 26 | self.optimizer = optimizer 27 | self._warmup_rate = (peak_lr - init_lr) / warmup_steps 28 | self._decay_rate = (end_lr - peak_lr) / (total_steps - warmup_steps) 29 | self._current_step = current_step 30 | self.lr = init_lr 31 | self.warmup_steps = warmup_steps 32 | self.total_steps = total_steps 33 | self._last_lr = [self.lr] 34 | 35 | def set_lr(self, lr): 36 | self._last_lr = [g['lr'] for g in self.optimizer.param_groups] 37 | for g in self.optimizer.param_groups: 38 | g['lr'] = lr 39 | 40 | def step(self): 41 | if self._current_step < self.warmup_steps: 42 | lr = self.init_lr + self._warmup_rate * self._current_step 43 | 44 | elif self._current_step > self.total_steps: 45 | lr = self.end_lr 46 | 47 | else: 48 | decay_ratio = (self._current_step - self.warmup_steps) / ( 49 | self.total_steps - self.warmup_steps) 50 | if decay_ratio < 0.0 or decay_ratio > 1.0: 51 | raise RuntimeError( 52 | "Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings." 53 | ) 54 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) 55 | lr = self.end_lr + coeff * (self.peak_lr - self.end_lr) 56 | 57 | self.set_lr(lr) 58 | self.lr = lr 59 | self._current_step += 1 60 | return self.lr 61 | 62 | 63 | if __name__ == '__main__': 64 | m = nn.Linear(10, 10) 65 | opt = Adam(m.parameters(), lr=1e-4) 66 | s = WarmupCosineLRSchedule( 67 | opt, 68 | 1e-6, 69 | 2e-4, 70 | 1e-6, 71 | warmup_steps=2000, 72 | total_steps=20000, 73 | current_step=0) 74 | lrs = [] 75 | for i in range(25000): 76 | s.step() 77 | lrs.append(s.lr) 78 | print(s.lr) 79 | 80 | plt.plot(lrs) 81 | plt.plot(range(0, 25000), lrs) 82 | plt.show() 83 | -------------------------------------------------------------------------------- /AR/text_processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innnky/ar-vits/aeca3848c6d34946b71348c9f59e7e008cd838ce/AR/text_processing/__init__.py -------------------------------------------------------------------------------- /AR/text_processing/phonemizer.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/phonemizer.py 2 | import itertools 3 | import re 4 | from typing import Dict 5 | from typing import List 6 | 7 | import regex 8 | from gruut import sentences 9 | from gruut.const import Sentence 10 | from gruut.const import Word 11 | from AR.text_processing.symbols import SYMBOL_TO_ID 12 | 13 | 14 | class GruutPhonemizer: 15 | def __init__(self, language: str): 16 | self._phonemizer = sentences 17 | self.lang = language 18 | self.symbol_to_id = SYMBOL_TO_ID 19 | self._special_cases_dict: Dict[str] = { 20 | r"\.\.\.": "... 
", 21 | ";": "; ", 22 | ":": ": ", 23 | ",": ", ", 24 | r"\.": ". ", 25 | "!": "! ", 26 | r"\?": "? ", 27 | "—": "—", 28 | "…": "… ", 29 | "«": "«", 30 | "»": "»" 31 | } 32 | self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])" 33 | 34 | def _normalize_punctuation(self, text: str) -> str: 35 | text = regex.sub(fr"\pZ+{self._punctuation_regexp}", r"\1", text) 36 | text = regex.sub(fr"{self._punctuation_regexp}(\pL)", r"\1 \2", text) 37 | text = regex.sub(r"\pZ+", r" ", text) 38 | return text.strip() 39 | 40 | def _convert_punctuation(self, word: Word) -> str: 41 | if not word.phonemes: 42 | return '' 43 | if word.phonemes[0] in ['‖', '|']: 44 | return word.text.strip() 45 | 46 | phonemes = ''.join(word.phonemes) 47 | # remove modifier characters ˈˌː with regex 48 | phonemes = re.sub(r'[ˈˌː͡]', '', phonemes) 49 | return phonemes.strip() 50 | 51 | def phonemize(self, text: str, espeak: bool=False) -> str: 52 | text_to_phonemize: str = self._normalize_punctuation(text) 53 | sents: List[Sentence] = [ 54 | sent 55 | for sent in self._phonemizer( 56 | text_to_phonemize, lang="en-us", espeak=espeak) 57 | ] 58 | words: List[str] = [ 59 | self._convert_punctuation(word) for word in itertools.chain(*sents) 60 | ] 61 | return ' '.join(words) 62 | 63 | def transform(self, phonemes): 64 | # convert phonemes to ids 65 | # dictionary is in symbols.py 66 | return [ 67 | self.symbol_to_id[p] for p in phonemes 68 | if p in self.symbol_to_id.keys() 69 | ] 70 | 71 | 72 | if __name__ == "__main__": 73 | phonemizer = GruutPhonemizer("en-us") 74 | # text -> IPA 75 | phonemes = phonemizer.phonemize("Hello, wor-ld ?") 76 | print("phonemes:", phonemes) 77 | print("len(phonemes):", len(phonemes)) 78 | phoneme_ids = phonemizer.transform(phonemes) 79 | print("phoneme_ids:", phoneme_ids) 80 | print("len(phoneme_ids):", len(phoneme_ids)) 81 | -------------------------------------------------------------------------------- /AR/text_processing/symbols.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/text_processing/symbols.py 2 | PAD = '_' 3 | PUNCTUATION = ';:,.!?¡¿—…"«»“” ' 4 | LETTERS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 5 | IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 6 | SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS) 7 | SPACE_ID = SYMBOLS.index(" ") 8 | SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)} 9 | ID_TO_SYMBOL = {i: s for i, s in enumerate(SYMBOLS)} 10 | -------------------------------------------------------------------------------- /AR/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def str2bool(str): 5 | return True if str.lower() == 'true' else False 6 | 7 | 8 | def get_newest_ckpt(string_list): 9 | # 定义一个正则表达式模式,用于匹配字符串中的数字 10 | pattern = r'epoch=(\d+)-step=(\d+)\.ckpt' 11 | 12 | # 使用正则表达式提取每个字符串中的数字信息,并创建一个包含元组的列表 13 | extracted_info = [] 14 | for string in string_list: 15 | match = re.match(pattern, string) 16 | if match: 17 | epoch = int(match.group(1)) 18 | step = int(match.group(2)) 19 | extracted_info.append((epoch, step, string)) 20 | # 按照 epoch 后面的数字和 step 后面的数字进行排序 21 | sorted_info = sorted( 22 | extracted_info, key=lambda x: (x[0], x[1]), reverse=True) 23 | # 获取最新的 ckpt 文件名 24 | newest_ckpt = sorted_info[0][2] 25 | return newest_ckpt 26 | 27 | 
28 | # 文本存在且不为空时 return True 29 | def check_txt_file(file_path): 30 | try: 31 | with open(file_path, 'r') as file: 32 | text = file.readline().strip() 33 | assert text.strip() != '' 34 | return text 35 | except Exception: 36 | return False 37 | return False 38 | -------------------------------------------------------------------------------- /AR/utils/initialize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Initialize modules for espnet2 neural networks.""" 3 | import torch 4 | from typeguard import check_argument_types 5 | 6 | 7 | def initialize(model: torch.nn.Module, init: str): 8 | """Initialize weights of a neural network module. 9 | 10 | Parameters are initialized using the given method or distribution. 11 | 12 | Custom initialization routines can be implemented into submodules 13 | as function `espnet_initialization_fn` within the custom module. 14 | 15 | Args: 16 | model: Target. 17 | init: Method of initialization. 18 | """ 19 | assert check_argument_types() 20 | print("init with", init) 21 | 22 | # weight init 23 | for p in model.parameters(): 24 | if p.dim() > 1: 25 | if init == "xavier_uniform": 26 | torch.nn.init.xavier_uniform_(p.data) 27 | elif init == "xavier_normal": 28 | torch.nn.init.xavier_normal_(p.data) 29 | elif init == "kaiming_uniform": 30 | torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu") 31 | elif init == "kaiming_normal": 32 | torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu") 33 | else: 34 | raise ValueError("Unknown initialization: " + init) 35 | # bias init 36 | for name, p in model.named_parameters(): 37 | if ".bias" in name and p.dim() == 1: 38 | p.data.zero_() 39 | -------------------------------------------------------------------------------- /AR/utils/io.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | import yaml 5 | 6 | 7 | def load_yaml_config(path): 8 | with open(path) as f: 9 | config = yaml.full_load(f) 10 | return config 11 | 12 | 13 | def save_config_to_yaml(config, path): 14 | assert path.endswith('.yaml') 15 | with open(path, 'w') as f: 16 | f.write(yaml.dump(config)) 17 | f.close() 18 | 19 | 20 | def write_args(args, path): 21 | args_dict = dict((name, getattr(args, name)) for name in dir(args) 22 | if not name.startswith('_')) 23 | with open(path, 'a') as args_file: 24 | args_file.write('==> torch version: {}\n'.format(torch.__version__)) 25 | args_file.write( 26 | '==> cudnn version: {}\n'.format(torch.backends.cudnn.version())) 27 | args_file.write('==> Cmd:\n') 28 | args_file.write(str(sys.argv)) 29 | args_file.write('\n==> args:\n') 30 | for k, v in sorted(args_dict.items()): 31 | args_file.write(' %s: %s\n' % (str(k), str(v))) 32 | args_file.close() 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 rcell123 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoRegressive-VITS 2 | 3 | (WIP) text to speech using autoregressive transformer and VITS 4 | ## Note 5 | + 模型效果未完全验证,不一定会好,请谨慎踩坑,预训练模型还在练 6 | + 从零训练需要海量数据(至少上千小时?)(类似valle、speartts、soundstorm)数据量少一定不会有好效果。。 7 | + 由于vits+refenc在zeroshot方向局限性很大,因此本仓库不追求zeroshot,本仓库的目标是,在有一个大的lm的pretrain的情况下,借助自回归lm的力量,希望在对小数据finetune以后能有很好的韵律。 8 | + 简单更新了一些初步的 [合成samples](https://huggingface.co/innnky/ar-tts-models/tree/main/gpt-vits) 9 | ## Todo 10 | + [x] 在原神数据上训练 11 | + [x] 收集更多中文开源数据训练(预计600H左右)训练并放出pretrain(x) --> out-of-distribution文本效果很差,例如读文言文 并且长句效果不好, 会抽风 12 | + [ ] 添加word level bert 并repeat到phoneme level改善out-of-distribution效果 13 | + [ ] 将同一spk的数据多条合并为一条音频 提高平均数据时长 改善长句合成效果稳定性 14 | + [ ] 更换为RoPE相对位置编码改善长句合成效果稳定性? 15 | + [ ] 编写finetune相关代码,增加sid支持 16 | + [ ] 优化日语和英语文本前端,收集更多日、英数据(预计每种语言600H)训练并放出pretrain 17 | 18 | ## structure 19 | ![structure.png](resources%2Fstructure.png) 20 | 21 | + decoder only text2semantic from [SoundStorm](https://github.com/yangdongchao/SoundStorm/tree/master/soundstorm/s1/AR) 22 | + VITS from [VITS](https://github.com/jaywalnut310/vits) 23 | + reference encoder from [TransferTTS](https://github.com/hcy71o/TransferTTS) 24 | 25 | ## Training pipeline 26 | 1. jointly train S2 vits decoder and quantizer 27 | 2. extract semantic tokens 28 | 3. 
train S1 text to semantic 29 | 30 | ## vits S2 training 31 | + resample.py 32 | + gen_phonemes.py 33 | + extract_ssl_s2.py 34 | + gen_filelist_s2.py 35 | + train_s2.py 36 | 37 | ## gpt S1 training 38 | + extract_vq_s1.py 39 | + gen_filelist_s1.py 40 | + train_s1.py 41 | 42 | ## Inference 43 | + s1_infer.py/s2_infer.py (work in progress) 44 | 45 | ## Pretrained models 46 | + work in progress 47 | -------------------------------------------------------------------------------- /configs/s1.yaml: -------------------------------------------------------------------------------- 1 | train: 2 | seed: 1234 3 | epochs: 100 4 | batch_size: 6 5 | gradient_accumulation: 4 6 | save_every_n_epoch: 1 7 | precision: 32 8 | gradient_clip: 1.0 9 | optimizer: 10 | lr: 0.01 11 | lr_init: 0.00001 12 | lr_end: 0.0001 13 | warmup_steps: 2000 14 | decay_steps: 40000 15 | data: 16 | max_eval_sample: 8 17 | max_sec: 20 18 | num_workers: 1 19 | pad_val: 1024 # same with EOS in model 20 | model: 21 | vocab_size: 1025 22 | phoneme_vocab_size: 512 23 | embedding_dim: 512 24 | hidden_dim: 512 25 | head: 16 26 | linear_units: 2048 27 | n_layer: 12 28 | dropout: 0 29 | EOS: 1024 30 | inference: 31 | top_k: 5 -------------------------------------------------------------------------------- /configs/s2.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 0.0001, 8 | "betas": [ 9 | 0.8, 10 | 0.99 11 | ], 12 | "eps": 1e-09, 13 | "batch_size": 16, 14 | "fp16_run": true, 15 | "lr_decay": 0.999875, 16 | "segment_size": 20480, 17 | "init_lr_ratio": 1, 18 | "warmup_epochs": 0, 19 | "c_mel": 45, 20 | "c_kl": 1.0 21 | }, 22 | "data": { 23 | "training_files": "dump/s2_train_files.list", 24 | "validation_files": "dump/s2_val_files.list", 25 | "max_wav_value": 32768.0, 26 | "sampling_rate": 32000, 27 | "filter_length": 2048, 28 | "hop_length": 640, 29 | "win_length": 2048, 30 | "n_mel_channels": 128, 31 | "mel_fmin": 0.0, 32 | "mel_fmax": null, 33 | "add_blank": true, 34 | "n_speakers": 300, 35 | "cleaned_text": true 36 | }, 37 | "model": { 38 | "inter_channels": 192, 39 | "hidden_channels": 192, 40 | "filter_channels": 768, 41 | "n_heads": 2, 42 | "n_layers": 6, 43 | "kernel_size": 3, 44 | "p_dropout": 0.1, 45 | "resblock": "1", 46 | "resblock_kernel_sizes": [ 47 | 3, 48 | 7, 49 | 11 50 | ], 51 | "resblock_dilation_sizes": [ 52 | [ 53 | 1, 54 | 3, 55 | 5 56 | ], 57 | [ 58 | 1, 59 | 3, 60 | 5 61 | ], 62 | [ 63 | 1, 64 | 3, 65 | 5 66 | ] 67 | ], 68 | "upsample_rates": [ 69 | 10, 70 | 8, 71 | 2, 72 | 2, 73 | 2 74 | ], 75 | "upsample_initial_channel": 512, 76 | "upsample_kernel_sizes": [ 77 | 16, 78 | 16, 79 | 8, 80 | 2, 81 | 2 82 | ], 83 | "n_layers_q": 3, 84 | "use_spectral_norm": false, 85 | "gin_channels": 512, 86 | "semantic_frame_rate": "25hz", 87 | "freeze_quantizer": false 88 | }, 89 | "s2_ckpt_dir": "logs/s2", 90 | "content_module": "cnhubert" 91 | } -------------------------------------------------------------------------------- /data_conf.py: -------------------------------------------------------------------------------- 1 | data_root = 'dataset' 2 | -------------------------------------------------------------------------------- /extract_ssl_s2.py: -------------------------------------------------------------------------------- 1 | import math 2 | import multiprocessing 3 | import argparse 4 | from random import shuffle 5 | import torch.multiprocessing as mp 6 | 7 
| import torch 8 | from glob import glob 9 | from tqdm import tqdm 10 | 11 | import utils 12 | from data_conf import data_root 13 | from feature_extractor import content_module_map 14 | import logging 15 | 16 | logging.getLogger("numba").setLevel(logging.WARNING) 17 | import librosa 18 | 19 | 20 | def process_one(file_path, model, device, content_module): 21 | 22 | ssl_path = file_path.replace(".wav", ".ssl.pt") 23 | try: 24 | wav16k, sr = librosa.load(file_path, sr=16000) 25 | wav16k = torch.from_numpy(wav16k).to(device) 26 | ssl_content = content_module.get_content(model, wav_16k_tensor=wav16k) 27 | torch.save(ssl_content.cpu().half(), ssl_path) 28 | del ssl_content 29 | del wav16k 30 | except: 31 | print("skip", file_path) 32 | 33 | def process_batch(filenames, content_module): 34 | content_module = content_module_map[content_module] 35 | print("Loading content model...") 36 | rank = mp.current_process()._identity 37 | rank = rank[0] if len(rank) > 0 else 0 38 | gpu_id = rank % torch.cuda.device_count() 39 | device = torch.device(f"cuda:{gpu_id}") 40 | print(device) 41 | ssl_model = content_module.get_model().to(device) 42 | print("Loaded content model.") 43 | for filename in tqdm(filenames): 44 | process_one(filename, ssl_model, device, content_module) 45 | 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | 50 | parser.add_argument( 51 | "--config", type=str, default="configs/s2.json", help="path to config" 52 | ) 53 | args = parser.parse_args() 54 | filenames = glob(f"{data_root}/**/*.wav", recursive=True) # [:10] 55 | hps = utils.get_hparams_from_file(args.config) 56 | shuffle(filenames) 57 | multiprocessing.set_start_method("spawn", force=True) 58 | 59 | num_processes = 8 60 | chunk_size = int(math.ceil(len(filenames) / num_processes)) 61 | chunks = [ 62 | filenames[i : i + chunk_size] for i in range(0, len(filenames), chunk_size) 63 | ] 64 | print([len(c) for c in chunks]) 65 | processes = [ 66 | multiprocessing.Process(target=process_batch, args=(chunk,hps.content_module)) for chunk in chunks 67 | ] 68 | for p in processes: 69 | p.start() -------------------------------------------------------------------------------- /extract_vq_s1.py: -------------------------------------------------------------------------------- 1 | import math 2 | import multiprocessing 3 | import os 4 | from random import shuffle 5 | import torch.multiprocessing as mp 6 | 7 | import torch 8 | from glob import glob 9 | from tqdm import tqdm 10 | 11 | import utils 12 | import logging 13 | 14 | from data_conf import data_root 15 | from module.models import SynthesizerTrn 16 | 17 | logging.getLogger("numba").setLevel(logging.WARNING) 18 | import librosa 19 | 20 | 21 | def process_one(f, file_path, model,vq_model, device): 22 | 23 | try: 24 | # wav16k, sr = librosa.load(file_path, sr=16000) 25 | # wav16k = torch.from_numpy(wav16k).to(device) 26 | # ssl_content = content_module.get_content(model, wav_16k_tensor=wav16k) 27 | ssl_content = torch.load(file_path.replace(".wav", ".ssl.pt")).float().to(device) 28 | codes = vq_model.extract_latent(ssl_content) 29 | semantic = " ".join([str(i) for i in codes[0, 0, :].tolist()]) 30 | f.write(f"{file_path}\t{semantic}\n") 31 | f.flush() 32 | except: 33 | print("skip", file_path) 34 | 35 | def process_batch(filenames): 36 | print("Loading models ...") 37 | process_idx = mp.current_process()._identity 38 | rank = process_idx[0] if len(process_idx) > 0 else 0 39 | gpu_id = rank % torch.cuda.device_count() 40 | device = 
torch.device(f"cuda:{gpu_id}") 41 | print(device) 42 | # ssl_model = content_module.get_model().to(device) 43 | hps = utils.get_hparams_from_file("configs/s2.json") 44 | vq_model = SynthesizerTrn( 45 | hps.data.filter_length // 2 + 1, 46 | hps.train.segment_size // hps.data.hop_length, 47 | n_speakers=hps.data.n_speakers, 48 | **hps.model).to(device) 49 | vq_model.eval() 50 | utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), vq_model, 51 | None, True) 52 | 53 | print("Loaded .") 54 | with torch.no_grad(): 55 | with open(f"dump/semantic_{process_idx[0]}.tsv", "w") as f: 56 | for filename in tqdm(filenames): 57 | process_one(f, filename, None ,vq_model, device) 58 | 59 | in_dir = data_root 60 | 61 | if __name__ == "__main__": 62 | filenames = glob(f"{in_dir}/**/*.wav", recursive=True) # [:10] 63 | shuffle(filenames) 64 | multiprocessing.set_start_method("spawn", force=True) 65 | 66 | num_processes = 8 67 | chunk_size = int(math.ceil(len(filenames) / num_processes)) 68 | chunks = [ 69 | filenames[i : i + chunk_size] for i in range(0, len(filenames), chunk_size) 70 | ] 71 | print([len(c) for c in chunks]) 72 | processes = [ 73 | multiprocessing.Process(target=process_batch, args=(chunk,)) for chunk in chunks 74 | ] 75 | for p in processes: 76 | p.start() 77 | 78 | for p in processes: 79 | p.join() 80 | with open(f"dump/semantic.tsv", "w") as f: 81 | f.write("item_name\tsemantic_audio\n") 82 | for i in range(num_processes): 83 | with open(f"dump/semantic_{i+1}.tsv", "r") as f2: 84 | f.write(f2.read()) 85 | os.remove(f"dump/semantic_{i+1}.tsv") 86 | -------------------------------------------------------------------------------- /feature_extractor/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import cnhubert, whisper_enc 2 | 3 | content_module_map = { 4 | 'cnhubert': cnhubert, 5 | 'whisper': whisper_enc 6 | } -------------------------------------------------------------------------------- /feature_extractor/cnhubert.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import librosa 4 | import torch 5 | import torch.nn.functional as F 6 | import soundfile as sf 7 | import logging 8 | 9 | logging.getLogger("numba").setLevel(logging.WARNING) 10 | 11 | from transformers import ( 12 | Wav2Vec2FeatureExtractor, 13 | HubertModel, 14 | ) 15 | 16 | import utils 17 | import torch.nn as nn 18 | 19 | class CNHubert(nn.Module): 20 | def __init__(self): 21 | super().__init__() 22 | self.model = HubertModel.from_pretrained("TencentGameMate/chinese-hubert-base") 23 | self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("TencentGameMate/chinese-hubert-base") 24 | def forward(self, x): 25 | input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device) 26 | feats = self.model(input_values)["last_hidden_state"] 27 | return feats 28 | 29 | 30 | 31 | def get_model(): 32 | model = CNHubert() 33 | model.eval() 34 | return model 35 | 36 | def get_content(hmodel, wav_16k_tensor): 37 | with torch.no_grad(): 38 | feats = hmodel(wav_16k_tensor) 39 | return feats.transpose(1,2) 40 | 41 | 42 | if __name__ == '__main__': 43 | model = get_model() 44 | src_path = "/Users/Shared/原音频2.wav" 45 | wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000) 46 | model = model 47 | wav_16k_tensor = wav_16k_tensor 48 | feats = get_content(model,wav_16k_tensor) 49 | print(feats.shape) 50 | 51 | -------------------------------------------------------------------------------- /feature_extractor/whisper_enc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_model(): 5 | import whisper 6 | model = whisper.load_model("small", device='cpu') 7 | 8 | return model.encoder 9 | 10 | 11 | def get_content(model=None, wav_16k_tensor=None): 12 | from whisper import log_mel_spectrogram, pad_or_trim 13 | dev = next(model.parameters()).device 14 | mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000] 15 | # if torch.cuda.is_available(): 16 | # mel = mel.to(torch.float16) 17 | feature_len = mel.shape[-1] // 2 18 | assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频" 19 | with torch.no_grad(): 20 | feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1,2) 21 | return feature 22 | 23 | -------------------------------------------------------------------------------- /gen_filelist_s1.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import pandas 3 | 4 | semantic_path = 'dump/semantic.tsv' 5 | phoneme_path = 'dump/phoneme.npy' 6 | train_semantic_path = 'dump/semantic_train.tsv' 7 | train_phoneme_path = 'dump/phoneme_train.npy' 8 | dev_semantic_path = 'dump/semantic_dev.tsv' 9 | dev_phoneme_path = 'dump/phoneme_dev.npy' 10 | 11 | # 读取dump/semantic.tsv 12 | semantic_df = pandas.read_csv(semantic_path, sep='\t') 13 | # pd.DataFrame(columns=["item_name", "semantic_audio"]) 14 | # # 读取dump/phoneme.npy 15 | phoneme_dict = numpy.load(phoneme_path, allow_pickle=True).item() 16 | 17 | dev_num = 20 18 | # 随机从semantic_df中选取dev_num个 19 | dev_df = semantic_df.sample(n=dev_num) 20 | # 剩下的是train 21 | train_df = semantic_df.drop(dev_df.index) 22 | # 保存 23 | 
dev_df.to_csv(dev_semantic_path, sep='\t', index=False) 24 | train_df.to_csv(train_semantic_path, sep='\t', index=False) 25 | 26 | # 将dev_df中的item_name取出来 作为dev_phoneme_dict的key 27 | dev_item_names = dev_df['item_name'].tolist() 28 | dev_phoneme_dict = {k: phoneme_dict[k] for k in dev_item_names if k in phoneme_dict} 29 | train_phoneme_dict = {k: phoneme_dict[k] for k in phoneme_dict.keys() if k not in dev_item_names} 30 | 31 | numpy.save(dev_phoneme_path, dev_phoneme_dict) 32 | numpy.save(train_phoneme_path, train_phoneme_dict) 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /gen_filelist_s2.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from random import shuffle 3 | from data_conf import data_root 4 | 5 | filenames = glob(f"{data_root}/**/*.wav", recursive=True) # [:10] 6 | 7 | shuffle(filenames) 8 | val_num = 8 9 | train = filenames[:-val_num] 10 | val = filenames[-val_num:] 11 | train.sort() 12 | val.sort() 13 | 14 | with open('dump/s2_train_files.list', 'w') as f: 15 | f.write('\n'.join(train)) 16 | with open('dump/s2_val_files.list', 'w') as f: 17 | f.write('\n'.join(val)) 18 | 19 | 20 | -------------------------------------------------------------------------------- /gen_phonemes.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from glob import glob 3 | from tqdm import tqdm 4 | 5 | from data_conf import data_root 6 | from text.cleaner import clean_text 7 | import numpy as np 8 | from multiprocessing import Pool 9 | 10 | out_dir = "dump" 11 | os.makedirs(out_dir, exist_ok=True) 12 | phoneme_path = os.path.join(out_dir, "phoneme.npy") 13 | phone_dict = {} 14 | 15 | def process_file(data): 16 | wav_path, language = data 17 | lab_path = wav_path.replace(".wav", ".lab") 18 | if os.path.exists(lab_path): 19 | text = open(lab_path).readline().strip() 20 | phones = clean_text(text, language) 21 | phones = " ".join(phones) 22 | return (wav_path, phones) 23 | else: 24 | return None 25 | for language in ["zh", 'en', 'ja']: 26 | filenames = glob(f"{data_root}/{language}/**/*.wav", recursive=True) 27 | 28 | # Define the number of processes to use 29 | num_processes = 5 # You can adjust this as needed 30 | 31 | with Pool(num_processes) as pool: 32 | results = list(tqdm(pool.imap(process_file, [(f, language) for f in filenames]), total=len(filenames))) 33 | 34 | for result in results: 35 | if result is not None: 36 | phone_dict[result[0]] = result[1] 37 | 38 | np.save(phoneme_path, phone_dict) 39 | -------------------------------------------------------------------------------- /module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innnky/ar-vits/aeca3848c6d34946b71348c9f59e7e008cd838ce/module/__init__.py -------------------------------------------------------------------------------- /module/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size*dilation - dilation)/2) 16 | 17 | 18 | def convert_pad_shape(pad_shape): 19 | l = 
pad_shape[::-1] 20 | pad_shape = [item for sublist in l for item in sublist] 21 | return pad_shape 22 | 23 | 24 | def intersperse(lst, item): 25 | result = [item] * (len(lst) * 2 + 1) 26 | result[1::2] = lst 27 | return result 28 | 29 | 30 | def kl_divergence(m_p, logs_p, m_q, logs_q): 31 | """KL(P||Q)""" 32 | kl = (logs_q - logs_p) - 0.5 33 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | ret = torch.zeros_like(x[:, :, :segment_size]) 50 | for i in range(x.size(0)): 51 | idx_str = ids_str[i] 52 | idx_end = idx_str + segment_size 53 | ret[i] = x[i, :, idx_str:idx_end] 54 | return ret 55 | 56 | 57 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 58 | b, d, t = x.size() 59 | if x_lengths is None: 60 | x_lengths = t 61 | ids_str_max = x_lengths - segment_size + 1 62 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 63 | ret = slice_segments(x, ids_str, segment_size) 64 | return ret, ids_str 65 | 66 | 67 | def get_timing_signal_1d( 68 | length, channels, min_timescale=1.0, max_timescale=1.0e4): 69 | position = torch.arange(length, dtype=torch.float) 70 | num_timescales = channels // 2 71 | log_timescale_increment = ( 72 | math.log(float(max_timescale) / float(min_timescale)) / 73 | (num_timescales - 1)) 74 | inv_timescales = min_timescale * torch.exp( 75 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 76 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 77 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 78 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 79 | signal = signal.view(1, channels, length) 80 | return signal 81 | 82 | 83 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 84 | b, channels, length = x.size() 85 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 86 | return x + signal.to(dtype=x.dtype, device=x.device) 87 | 88 | 89 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 90 | b, channels, length = x.size() 91 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 92 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 93 | 94 | 95 | def subsequent_mask(length): 96 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 97 | return mask 98 | 99 | 100 | @torch.jit.script 101 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 102 | n_channels_int = n_channels[0] 103 | in_act = input_a + input_b 104 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 105 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 106 | acts = t_act * s_act 107 | return acts 108 | 109 | 110 | def convert_pad_shape(pad_shape): 111 | l = pad_shape[::-1] 112 | pad_shape = [item for sublist in l for item in sublist] 113 | return pad_shape 114 | 115 | 116 | def shift_1d(x): 117 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 118 | return x 119 | 120 | 121 | def sequence_mask(length, max_length=None): 122 | if max_length is None: 123 | max_length = 
length.max() 124 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 125 | return x.unsqueeze(0) < length.unsqueeze(1) 126 | 127 | 128 | def generate_path(duration, mask): 129 | """ 130 | duration: [b, 1, t_x] 131 | mask: [b, 1, t_y, t_x] 132 | """ 133 | device = duration.device 134 | 135 | b, _, t_y, t_x = mask.shape 136 | cum_duration = torch.cumsum(duration, -1) 137 | 138 | cum_duration_flat = cum_duration.view(b * t_x) 139 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 140 | path = path.view(b, t_x, t_y) 141 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 142 | path = path.unsqueeze(1).transpose(2,3) * mask 143 | return path 144 | 145 | 146 | def clip_grad_value_(parameters, clip_value, norm_type=2): 147 | if isinstance(parameters, torch.Tensor): 148 | parameters = [parameters] 149 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 150 | norm_type = float(norm_type) 151 | if clip_value is not None: 152 | clip_value = float(clip_value) 153 | 154 | total_norm = 0 155 | for p in parameters: 156 | param_norm = p.grad.data.norm(norm_type) 157 | total_norm += param_norm.item() ** norm_type 158 | if clip_value is not None: 159 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 160 | total_norm = total_norm ** (1. / norm_type) 161 | return total_norm 162 | 163 | 164 | def squeeze(x, x_mask=None, n_sqz=2): 165 | b, c, t = x.size() 166 | 167 | t = (t // n_sqz) * n_sqz 168 | x = x[:, :, :t] 169 | x_sqz = x.view(b, c, t // n_sqz, n_sqz) 170 | x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz) 171 | 172 | if x_mask is not None: 173 | x_mask = x_mask[:, :, n_sqz - 1::n_sqz] 174 | else: 175 | x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype) 176 | return x_sqz * x_mask, x_mask 177 | 178 | 179 | def unsqueeze(x, x_mask=None, n_sqz=2): 180 | b, c, t = x.size() 181 | 182 | x_unsqz = x.view(b, n_sqz, c // n_sqz, t) 183 | x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz) 184 | 185 | if x_mask is not None: 186 | x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz) 187 | else: 188 | x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype) 189 | return x_unsqz * x_mask, x_mask 190 | -------------------------------------------------------------------------------- /module/losses.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def feature_loss(fmap_r, fmap_g): 8 | loss = 0 9 | for dr, dg in zip(fmap_r, fmap_g): 10 | for rl, gl in zip(dr, dg): 11 | rl = rl.float().detach() 12 | gl = gl.float() 13 | loss += torch.mean(torch.abs(rl - gl)) 14 | 15 | return loss * 2 16 | 17 | 18 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 19 | loss = 0 20 | r_losses = [] 21 | g_losses = [] 22 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 23 | dr = dr.float() 24 | dg = dg.float() 25 | r_loss = torch.mean((1-dr)**2) 26 | g_loss = torch.mean(dg**2) 27 | loss += (r_loss + g_loss) 28 | r_losses.append(r_loss.item()) 29 | g_losses.append(g_loss.item()) 30 | 31 | return loss, r_losses, g_losses 32 | 33 | 34 | def generator_loss(disc_outputs): 35 | loss = 0 36 | gen_losses = [] 37 | for dg in disc_outputs: 38 | dg = dg.float() 39 | l = torch.mean((1-dg)**2) 40 | gen_losses.append(l) 41 | loss += l 42 | 43 | return loss, gen_losses 44 | 45 | 46 | def kl_loss(z_p, 
logs_q, m_p, logs_p, z_mask): 47 | """ 48 | z_p, logs_q: [b, h, t_t] 49 | m_p, logs_p: [b, h, t_t] 50 | """ 51 | z_p = z_p.float() 52 | logs_q = logs_q.float() 53 | m_p = m_p.float() 54 | logs_p = logs_p.float() 55 | z_mask = z_mask.float() 56 | 57 | kl = logs_p - logs_q - 0.5 58 | kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p) 59 | kl = torch.sum(kl * z_mask) 60 | l = kl / torch.sum(z_mask) 61 | return l 62 | 63 | def mle_loss(z, m, logs, logdet, mask): 64 | l = torch.sum(logs) + 0.5 * torch.sum(torch.exp(-2 * logs) * ((z - m)**2)) # neg normal likelihood w/o the constant term 65 | l = l - torch.sum(logdet) # log jacobian determinant 66 | l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes 67 | l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term 68 | return l -------------------------------------------------------------------------------- /module/mel_processing.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import torch.utils.data 8 | import numpy as np 9 | import librosa 10 | import librosa.util as librosa_util 11 | from librosa.util import normalize, pad_center, tiny 12 | from scipy.signal import get_window 13 | from scipy.io.wavfile import read 14 | from librosa.filters import mel as librosa_mel_fn 15 | 16 | MAX_WAV_VALUE = 32768.0 17 | 18 | 19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 20 | """ 21 | PARAMS 22 | ------ 23 | C: compression factor 24 | """ 25 | return torch.log(torch.clamp(x, min=clip_val) * C) 26 | 27 | 28 | def dynamic_range_decompression_torch(x, C=1): 29 | """ 30 | PARAMS 31 | ------ 32 | C: compression factor used to compress 33 | """ 34 | return torch.exp(x) / C 35 | 36 | 37 | def spectral_normalize_torch(magnitudes): 38 | output = dynamic_range_compression_torch(magnitudes) 39 | return output 40 | 41 | 42 | def spectral_de_normalize_torch(magnitudes): 43 | output = dynamic_range_decompression_torch(magnitudes) 44 | return output 45 | 46 | 47 | mel_basis = {} 48 | hann_window = {} 49 | 50 | 51 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 52 | if torch.min(y) < -1.: 53 | print('min value is ', torch.min(y)) 54 | if torch.max(y) > 1.: 55 | print('max value is ', torch.max(y)) 56 | 57 | global hann_window 58 | dtype_device = str(y.dtype) + '_' + str(y.device) 59 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 60 | if wnsize_dtype_device not in hann_window: 61 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 62 | 63 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 64 | y = y.squeeze(1) 65 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 66 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 67 | 68 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 69 | return spec 70 | 71 | 72 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 73 | global mel_basis 74 | dtype_device = str(spec.dtype) + '_' + str(spec.device) 75 | fmax_dtype_device = str(fmax) + '_' + dtype_device 76 | if fmax_dtype_device not in mel_basis: 77 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 78 | mel_basis[fmax_dtype_device] 
= torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) 79 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 80 | spec = spectral_normalize_torch(spec) 81 | return spec 82 | 83 | 84 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 85 | if torch.min(y) < -1.: 86 | print('min value is ', torch.min(y)) 87 | if torch.max(y) > 1.: 88 | print('max value is ', torch.max(y)) 89 | 90 | global mel_basis, hann_window 91 | dtype_device = str(y.dtype) + '_' + str(y.device) 92 | fmax_dtype_device = str(fmax) + '_' + dtype_device 93 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 94 | if fmax_dtype_device not in mel_basis: 95 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 96 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) 97 | if wnsize_dtype_device not in hann_window: 98 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 99 | 100 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 101 | y = y.squeeze(1) 102 | 103 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 104 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 105 | 106 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 107 | 108 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 109 | spec = spectral_normalize_torch(spec) 110 | 111 | return spec 112 | -------------------------------------------------------------------------------- /module/mrte_model.py: -------------------------------------------------------------------------------- 1 | # This is Multi-reference timbre encoder 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn.utils import remove_weight_norm, weight_norm 6 | from module.attentions import MultiHeadAttention 7 | 8 | class MRTE(nn.Module): 9 | def __init__(self, 10 | content_enc_channels=192, 11 | hidden_size=512, 12 | out_channels=192, 13 | kernel_size=5, 14 | n_heads=4, 15 | ge_layer = 2 16 | ): 17 | super(MRTE, self).__init__() 18 | self.cross_attention = MultiHeadAttention(hidden_size,hidden_size,n_heads) 19 | self.c_pre = nn.Conv1d(content_enc_channels,hidden_size, 1) 20 | self.text_pre = nn.Conv1d(content_enc_channels,hidden_size, 1) 21 | self.c_post = nn.Conv1d(hidden_size,out_channels, 1) 22 | 23 | def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None): 24 | attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1) 25 | 26 | ssl_enc = self.c_pre(ssl_enc * ssl_mask) 27 | text_enc = self.text_pre(text * text_mask) 28 | if test != None: 29 | if test == 0: 30 | x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge 31 | elif test == 1: 32 | x = ssl_enc + ge 33 | elif test ==2: 34 | x = self.cross_attention(ssl_enc*0 * ssl_mask, text_enc * text_mask, attn_mask) + ge 35 | else: 36 | raise ValueError("test should be 0,1,2") 37 | else: 38 | x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge 39 | x = self.c_post(x * ssl_mask) 40 | return x 41 | 42 | 43 | class SpeakerEncoder(torch.nn.Module): 44 | def __init__(self, mel_n_channels=80, model_num_layers=2, model_hidden_size=256, model_embedding_size=256): 45 | super(SpeakerEncoder, self).__init__() 46 | self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, 
batch_first=True) 47 | self.linear = nn.Linear(model_hidden_size, model_embedding_size) 48 | self.relu = nn.ReLU() 49 | 50 | def forward(self, mels): 51 | self.lstm.flatten_parameters() 52 | _, (hidden, _) = self.lstm(mels.transpose(-1, -2)) 53 | embeds_raw = self.relu(self.linear(hidden[-1])) 54 | return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) 55 | 56 | 57 | class MELEncoder(nn.Module): 58 | def __init__(self, 59 | in_channels, 60 | out_channels, 61 | hidden_channels, 62 | kernel_size, 63 | dilation_rate, 64 | n_layers): 65 | super().__init__() 66 | self.in_channels = in_channels 67 | self.out_channels = out_channels 68 | self.hidden_channels = hidden_channels 69 | self.kernel_size = kernel_size 70 | self.dilation_rate = dilation_rate 71 | self.n_layers = n_layers 72 | 73 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 74 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers) 75 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 76 | 77 | def forward(self, x): 78 | # print(x.shape,x_lengths.shape) 79 | x = self.pre(x) 80 | x = self.enc(x) 81 | x = self.proj(x) 82 | return x 83 | 84 | 85 | class WN(torch.nn.Module): 86 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers): 87 | super(WN, self).__init__() 88 | assert(kernel_size % 2 == 1) 89 | self.hidden_channels =hidden_channels 90 | self.kernel_size = kernel_size 91 | self.dilation_rate = dilation_rate 92 | self.n_layers = n_layers 93 | 94 | self.in_layers = torch.nn.ModuleList() 95 | self.res_skip_layers = torch.nn.ModuleList() 96 | 97 | for i in range(n_layers): 98 | dilation = dilation_rate ** i 99 | padding = int((kernel_size * dilation - dilation) / 2) 100 | in_layer = nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, 101 | dilation=dilation, padding=padding) 102 | in_layer = weight_norm(in_layer) 103 | self.in_layers.append(in_layer) 104 | 105 | # last one is not necessary 106 | if i < n_layers - 1: 107 | res_skip_channels = 2 * hidden_channels 108 | else: 109 | res_skip_channels = hidden_channels 110 | 111 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 112 | res_skip_layer = weight_norm(res_skip_layer, name='weight') 113 | self.res_skip_layers.append(res_skip_layer) 114 | 115 | def forward(self, x): 116 | output = torch.zeros_like(x) 117 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 118 | 119 | for i in range(self.n_layers): 120 | x_in = self.in_layers[i](x) 121 | 122 | acts = fused_add_tanh_sigmoid_multiply( 123 | x_in, 124 | n_channels_tensor) 125 | 126 | res_skip_acts = self.res_skip_layers[i](acts) 127 | if i < self.n_layers - 1: 128 | res_acts = res_skip_acts[:,:self.hidden_channels,:] 129 | x = (x + res_acts) 130 | output = output + res_skip_acts[:,self.hidden_channels:,:] 131 | else: 132 | output = output + res_skip_acts 133 | return output 134 | 135 | def remove_weight_norm(self): 136 | for l in self.in_layers: 137 | remove_weight_norm(l) 138 | for l in self.res_skip_layers: 139 | remove_weight_norm(l) 140 | 141 | 142 | @torch.jit.script 143 | def fused_add_tanh_sigmoid_multiply(input, n_channels): 144 | n_channels_int = n_channels[0] 145 | t_act = torch.tanh(input[:, :n_channels_int, :]) 146 | s_act = torch.sigmoid(input[:, n_channels_int:, :]) 147 | acts = t_act * s_act 148 | return acts 149 | 150 | 151 | 152 | if __name__ == '__main__': 153 | content_enc = torch.randn(3,192,100) 154 | content_mask = torch.ones(3,1,100) 155 | ref_mel = torch.randn(3,128,30) 156 | ref_mask = torch.ones(3,1,30) 157 | 
model = MRTE() 158 | out = model(content_enc,content_mask,ref_mel,ref_mask) 159 | print(out.shape) -------------------------------------------------------------------------------- /module/quantize.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Residual vector quantizer implementation.""" 8 | 9 | from dataclasses import dataclass, field 10 | import math 11 | import typing as tp 12 | 13 | import torch 14 | from torch import nn 15 | 16 | from module.core_vq import ResidualVectorQuantization 17 | 18 | 19 | @dataclass 20 | class QuantizedResult: 21 | quantized: torch.Tensor 22 | codes: torch.Tensor 23 | bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. 24 | penalty: tp.Optional[torch.Tensor] = None 25 | metrics: dict = field(default_factory=dict) 26 | 27 | 28 | class ResidualVectorQuantizer(nn.Module): 29 | """Residual Vector Quantizer. 30 | Args: 31 | dimension (int): Dimension of the codebooks. 32 | n_q (int): Number of residual vector quantizers used. 33 | bins (int): Codebook size. 34 | decay (float): Decay for exponential moving average over the codebooks. 35 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks. 36 | kmeans_iters (int): Number of iterations used for kmeans initialization. 37 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 38 | that have an exponential moving average cluster size less than the specified threshold with 39 | randomly selected vector from the current batch. 40 | """ 41 | def __init__( 42 | self, 43 | dimension: int = 256, 44 | n_q: int = 8, 45 | bins: int = 1024, 46 | decay: float = 0.99, 47 | kmeans_init: bool = True, 48 | kmeans_iters: int = 50, 49 | threshold_ema_dead_code: int = 2, 50 | ): 51 | super().__init__() 52 | self.n_q = n_q 53 | self.dimension = dimension 54 | self.bins = bins 55 | self.decay = decay 56 | self.kmeans_init = kmeans_init 57 | self.kmeans_iters = kmeans_iters 58 | self.threshold_ema_dead_code = threshold_ema_dead_code 59 | self.vq = ResidualVectorQuantization( 60 | dim=self.dimension, 61 | codebook_size=self.bins, 62 | num_quantizers=self.n_q, 63 | decay=self.decay, 64 | kmeans_init=self.kmeans_init, 65 | kmeans_iters=self.kmeans_iters, 66 | threshold_ema_dead_code=self.threshold_ema_dead_code, 67 | ) 68 | 69 | def forward(self, x: torch.Tensor, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None) -> QuantizedResult: 70 | """Residual vector quantization on the given input tensor. 71 | Args: 72 | x (torch.Tensor): Input tensor. 73 | n_q (int): Number of quantizer used to quantize. Default: All quantizers. 74 | layers (list): Layer that need to return quantized. Defalt: None. 75 | Returns: 76 | QuantizedResult: 77 | The quantized (or approximately quantized) representation with 78 | the associated numbert quantizers and layer quantized required to return. 79 | """ 80 | n_q = n_q if n_q else self.n_q 81 | if layers and max(layers) >= n_q: 82 | raise ValueError(f'Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. 
A must less than B.') 83 | quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers) 84 | return quantized, codes, torch.mean(commit_loss), quantized_list 85 | 86 | 87 | def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor: 88 | """Encode a given input tensor with the specified sample rate at the given bandwidth. 89 | The RVQ encode method sets the appropriate number of quantizer to use 90 | and returns indices for each quantizer. 91 | Args: 92 | x (torch.Tensor): Input tensor. 93 | n_q (int): Number of quantizer used to quantize. Default: All quantizers. 94 | st (int): Start to encode input from which layers. Default: 0. 95 | """ 96 | n_q = n_q if n_q else self.n_q 97 | st = st or 0 98 | codes = self.vq.encode(x, n_q=n_q, st=st) 99 | return codes 100 | 101 | def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor: 102 | """Decode the given codes to the quantized representation. 103 | Args: 104 | codes (torch.Tensor): Input indices for each quantizer. 105 | st (int): Start to decode input codes from which layers. Default: 0. 106 | """ 107 | quantized = self.vq.decode(codes, st=st) 108 | return quantized -------------------------------------------------------------------------------- /module/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform(inputs, 13 | unnormalized_widths, 14 | unnormalized_heights, 15 | unnormalized_derivatives, 16 | inverse=False, 17 | tails=None, 18 | tail_bound=1., 19 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 20 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 21 | min_derivative=DEFAULT_MIN_DERIVATIVE): 22 | 23 | if tails is None: 24 | spline_fn = rational_quadratic_spline 25 | spline_kwargs = {} 26 | else: 27 | spline_fn = unconstrained_rational_quadratic_spline 28 | spline_kwargs = { 29 | 'tails': tails, 30 | 'tail_bound': tail_bound 31 | } 32 | 33 | outputs, logabsdet = spline_fn( 34 | inputs=inputs, 35 | unnormalized_widths=unnormalized_widths, 36 | unnormalized_heights=unnormalized_heights, 37 | unnormalized_derivatives=unnormalized_derivatives, 38 | inverse=inverse, 39 | min_bin_width=min_bin_width, 40 | min_bin_height=min_bin_height, 41 | min_derivative=min_derivative, 42 | **spline_kwargs 43 | ) 44 | return outputs, logabsdet 45 | 46 | 47 | def searchsorted(bin_locations, inputs, eps=1e-6): 48 | bin_locations[..., -1] += eps 49 | return torch.sum( 50 | inputs[..., None] >= bin_locations, 51 | dim=-1 52 | ) - 1 53 | 54 | 55 | def unconstrained_rational_quadratic_spline(inputs, 56 | unnormalized_widths, 57 | unnormalized_heights, 58 | unnormalized_derivatives, 59 | inverse=False, 60 | tails='linear', 61 | tail_bound=1., 62 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 63 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 64 | min_derivative=DEFAULT_MIN_DERIVATIVE): 65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 66 | outside_interval_mask = ~inside_interval_mask 67 | 68 | outputs = torch.zeros_like(inputs) 69 | logabsdet = torch.zeros_like(inputs) 70 | 71 | if tails == 'linear': 72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 73 | constant = np.log(np.exp(1 - min_derivative) - 1) 74 | unnormalized_derivatives[..., 0] = constant 75 | 
unnormalized_derivatives[..., -1] = constant 76 | 77 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 78 | logabsdet[outside_interval_mask] = 0 79 | else: 80 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 81 | 82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 89 | min_bin_width=min_bin_width, 90 | min_bin_height=min_bin_height, 91 | min_derivative=min_derivative 92 | ) 93 | 94 | return outputs, logabsdet 95 | 96 | def rational_quadratic_spline(inputs, 97 | unnormalized_widths, 98 | unnormalized_heights, 99 | unnormalized_derivatives, 100 | inverse=False, 101 | left=0., right=1., bottom=0., top=1., 102 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 103 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 104 | min_derivative=DEFAULT_MIN_DERIVATIVE): 105 | if torch.min(inputs) < left or torch.max(inputs) > right: 106 | raise ValueError('Input to a transform is not within its domain') 107 | 108 | num_bins = unnormalized_widths.shape[-1] 109 | 110 | if min_bin_width * num_bins > 1.0: 111 | raise ValueError('Minimal bin width too large for the number of bins') 112 | if min_bin_height * num_bins > 1.0: 113 | raise ValueError('Minimal bin height too large for the number of bins') 114 | 115 | widths = F.softmax(unnormalized_widths, dim=-1) 116 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 117 | cumwidths = torch.cumsum(widths, dim=-1) 118 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 119 | cumwidths = (right - left) * cumwidths + left 120 | cumwidths[..., 0] = left 121 | cumwidths[..., -1] = right 122 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 123 | 124 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 125 | 126 | heights = F.softmax(unnormalized_heights, dim=-1) 127 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 128 | cumheights = torch.cumsum(heights, dim=-1) 129 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 130 | cumheights = (top - bottom) * cumheights + bottom 131 | cumheights[..., 0] = bottom 132 | cumheights[..., -1] = top 133 | heights = cumheights[..., 1:] - cumheights[..., :-1] 134 | 135 | if inverse: 136 | bin_idx = searchsorted(cumheights, inputs)[..., None] 137 | else: 138 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 139 | 140 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 141 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 142 | 143 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 144 | delta = heights / widths 145 | input_delta = delta.gather(-1, bin_idx)[..., 0] 146 | 147 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 148 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 149 | 150 | input_heights = heights.gather(-1, bin_idx)[..., 0] 151 | 152 | if inverse: 153 | a = (((inputs - input_cumheights) * (input_derivatives 154 | + input_derivatives_plus_one 155 | - 2 * input_delta) 156 | + input_heights * (input_delta - input_derivatives))) 157 | b = (input_heights * input_derivatives 158 | - (inputs - input_cumheights) * (input_derivatives 159 | + 
input_derivatives_plus_one 160 | - 2 * input_delta)) 161 | c = - input_delta * (inputs - input_cumheights) 162 | 163 | discriminant = b.pow(2) - 4 * a * c 164 | assert (discriminant >= 0).all() 165 | 166 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 167 | outputs = root * input_bin_widths + input_cumwidths 168 | 169 | theta_one_minus_theta = root * (1 - root) 170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 171 | * theta_one_minus_theta) 172 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 173 | + 2 * input_delta * theta_one_minus_theta 174 | + input_derivatives * (1 - root).pow(2)) 175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 176 | 177 | return outputs, -logabsdet 178 | else: 179 | theta = (inputs - input_cumwidths) / input_bin_widths 180 | theta_one_minus_theta = theta * (1 - theta) 181 | 182 | numerator = input_heights * (input_delta * theta.pow(2) 183 | + input_derivatives * theta_one_minus_theta) 184 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 185 | * theta_one_minus_theta) 186 | outputs = input_cumheights + numerator / denominator 187 | 188 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 189 | + 2 * input_delta * theta_one_minus_theta 190 | + input_derivatives * (1 - theta).pow(2)) 191 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 192 | 193 | return outputs, logabsdet 194 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa==0.9.1 2 | matplotlib 3 | numpy 4 | scipy 5 | tensorboard 6 | torch 7 | amfm_decompy 8 | speechtokenizer 9 | tqdm 10 | -------------------------------------------------------------------------------- /resample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from glob import glob 4 | 5 | import librosa 6 | import soundfile 7 | from tqdm import tqdm 8 | from multiprocessing import Pool 9 | 10 | from data_conf import data_root 11 | 12 | def process_wav(wavpath): 13 | wav, _ = librosa.load(wavpath, sr=tgt_sr) 14 | soundfile.write(wavpath, wav, tgt_sr) 15 | 16 | def get_wav_files(path): 17 | wav_files = [] 18 | for root, dirs, files in os.walk(path): 19 | for file in files: 20 | if file.endswith(".wav"): 21 | wav_files.append(os.path.join(root, file)) 22 | return wav_files 23 | 24 | tgt_path = data_root 25 | 26 | num_processes = 10 # You can adjust the number of processes as needed 27 | tgt_sr = 32000 28 | 29 | print("Note: this script will overwrite the original files!") 30 | print("all the wav files under {} will be resampled to {}Hz".format(tgt_path, tgt_sr)) 31 | input("press enter to continue... 
or press Ctrl+C to cancel") 32 | 33 | if __name__ == "__main__": 34 | with Pool(num_processes) as pool: 35 | file_list = get_wav_files(tgt_path) 36 | list(tqdm(pool.imap(process_wav, file_list), total=len(file_list))) 37 | 38 | -------------------------------------------------------------------------------- /resources/structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innnky/ar-vits/aeca3848c6d34946b71348c9f59e7e008cd838ce/resources/structure.png -------------------------------------------------------------------------------- /s1_infer.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import time 4 | import torch 5 | from AR.models.t2s_lightning_module import Text2SemanticLightningModule 6 | from AR.utils.io import load_yaml_config 7 | from text import cleaned_text_to_sequence 8 | from text.cleaner import text_to_sequence, clean_text 9 | 10 | 11 | text = "当然,不同问题之间错综复杂,对应的结论也有冲突.所以我想要的是'平衡',也就是在所有问题中找到一个'最优解'." 12 | text = "当然,不同问题之间错综复杂,对应的结论也有冲突.所以我想要的是'平衡'。" 13 | # text = "幸运的是,此次事故并未造成人员伤亡,但两辆车均受到了不同程度的损伤。事故发生后,许多网友对这名女子的驾驶行为表示了强烈的不解和担忧。同时,也有网友表示,这种行为不仅危害了自己和他人的生命安全,还可能对其他道路使用者造成恐慌和困扰。" 14 | # text = "皆さん、こんにちは、私は派蒙です。今日はみんなが見たいものをください。" 15 | prompt_text = "万一他很崇拜我们呢?嘿嘿," 16 | prompt_wav_path = "/home/fish/genshin_data/zh/派蒙/vo_DQAQ003_1_paimon_06.wav" 17 | 18 | def text2phoneid(text, lang='zh'): 19 | phones = clean_text(text, lang) 20 | print(phones) 21 | return cleaned_text_to_sequence(phones) 22 | 23 | 24 | semantic_data = pd.read_csv('dump/semantic.tsv', delimiter='\t') 25 | 26 | 27 | phones = text2phoneid(text) 28 | prompt_phones = text2phoneid(prompt_text) 29 | prompt_semantic = semantic_data[semantic_data['item_name'] == prompt_wav_path]['semantic_audio'].values[0] 30 | prompt_semantic = torch.LongTensor([int(idx) for idx in prompt_semantic.split(' ')]) 31 | 32 | print(prompt_semantic) 33 | n_semantic = 1024 34 | device = 'cpu' 35 | config = load_yaml_config("configs/s1.yaml") 36 | ckpt_path = 'logs/s1/ckpt/epoch=4-step=3945.ckpt' 37 | 38 | hz = 50 39 | max_sec = config['data']['max_sec'] 40 | 41 | # get models 42 | t2s_model = Text2SemanticLightningModule.load_from_checkpoint( 43 | checkpoint_path=ckpt_path, config=config, map_location=device) 44 | t2s_model.to(device) 45 | t2s_model.eval() 46 | 47 | total = sum([param.nelement() for param in t2s_model.parameters()]) 48 | 49 | print("Number of parameter: %.2fM" % (total / 1e6)) 50 | 51 | 52 | all_phoneme_ids = torch.LongTensor(prompt_phones+phones).to(device).unsqueeze(0) 53 | print(all_phoneme_ids.shape) 54 | all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) 55 | prompt = prompt_semantic.unsqueeze(0).to(device) 56 | st = time.time() 57 | with torch.no_grad(): 58 | pred_semantic = t2s_model.model.infer( 59 | all_phoneme_ids, 60 | all_phoneme_len, 61 | prompt, 62 | top_k=config['inference']['top_k'], 63 | early_stop_num=hz * max_sec) 64 | 65 | print(f'{time.time() - st} sec used in T2S') 66 | 67 | torch.save(pred_semantic.squeeze(0).squeeze(0), 'pred_semantic.pt') 68 | 69 | phones = " ".join([str(i) for i in prompt_phones+phones]) 70 | 71 | os.system(f"python s2_infer.py '{phones}'") -------------------------------------------------------------------------------- /s1_train.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py 2 | 
import argparse 3 | import logging 4 | import os 5 | from pathlib import Path 6 | 7 | import torch 8 | from pytorch_lightning import seed_everything 9 | from pytorch_lightning import Trainer 10 | from pytorch_lightning.callbacks import ModelCheckpoint 11 | from pytorch_lightning.loggers import WandbLogger 12 | from pytorch_lightning.strategies import DDPStrategy 13 | from AR.data.data_module import Text2SemanticDataModule 14 | from AR.models.t2s_lightning_module import Text2SemanticLightningModule 15 | from AR.utils.io import load_yaml_config 16 | logging.getLogger('numba').setLevel(logging.WARNING) 17 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 18 | torch.set_float32_matmul_precision('high') 19 | from AR.utils import get_newest_ckpt 20 | 21 | 22 | def main(args): 23 | output_dir = Path(args.output_dir) 24 | output_dir.mkdir(parents=True, exist_ok=True) 25 | 26 | ckpt_dir = output_dir / 'ckpt' 27 | ckpt_dir.mkdir(parents=True, exist_ok=True) 28 | 29 | config = load_yaml_config(args.config_file) 30 | 31 | seed_everything(config["train"]["seed"], workers=True) 32 | ckpt_callback: ModelCheckpoint = ModelCheckpoint( 33 | save_top_k=-1, 34 | save_on_train_epoch_end=False, 35 | every_n_epochs=config["train"]["save_every_n_epoch"], 36 | dirpath=ckpt_dir) 37 | logger = WandbLogger( 38 | project="ar_s1", 39 | name=output_dir.stem, 40 | save_dir=output_dir, 41 | # resume the loss curve 42 | resume=True, 43 | # id='k19kvsq8' 44 | ) 45 | trainer: Trainer = Trainer( 46 | max_epochs=config["train"]["epochs"], 47 | accelerator='gpu', 48 | devices=-1, 49 | benchmark=False, 50 | fast_dev_run=False, 51 | strategy=DDPStrategy(), 52 | precision=config["train"]["precision"], 53 | logger=logger, 54 | callbacks=[ckpt_callback]) 55 | 56 | model: Text2SemanticLightningModule = Text2SemanticLightningModule( 57 | config, output_dir) 58 | 59 | data_module: Text2SemanticDataModule = Text2SemanticDataModule( 60 | config, 61 | train_semantic_path=args.train_semantic_path, 62 | train_phoneme_path=args.train_phoneme_path, 63 | dev_semantic_path=args.dev_semantic_path, 64 | dev_phoneme_path=args.dev_phoneme_path) 65 | 66 | try: 67 | # 使用正则表达式匹配文件名中的数字部分,并按数字大小进行排序 68 | newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir)) 69 | ckpt_path = ckpt_dir / newest_ckpt_name 70 | except Exception: 71 | ckpt_path = None 72 | print("ckpt_path:", ckpt_path) 73 | trainer.fit(model, data_module, ckpt_path=ckpt_path) 74 | 75 | 76 | # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml 77 | if __name__ == '__main__': 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument( 80 | '-c', 81 | '--config_file', 82 | type=str, 83 | default='configs/s1.yaml', 84 | help='path of config file') 85 | # args for dataset 86 | parser.add_argument( 87 | '--train_semantic_path', 88 | type=str, 89 | default='dump/semantic_train.tsv') 90 | parser.add_argument( 91 | '--train_phoneme_path', type=str, default='dump/phoneme_train.npy') 92 | parser.add_argument( 93 | '--dev_semantic_path', type=str, default='dump/semantic_dev.tsv') 94 | parser.add_argument( 95 | '--dev_phoneme_path', type=str, default='dump/phoneme_dev.npy') 96 | parser.add_argument( 97 | '--output_dir', 98 | type=str, 99 | default='logs/s1', 100 | help='directory to save the results') 101 | 102 | args = parser.parse_args() 103 | logging.info(str(args)) 104 | main(args) 105 | -------------------------------------------------------------------------------- /s2_infer.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import librosa 4 | import soundfile 5 | import torch 6 | 7 | import utils 8 | from module.models import SynthesizerTrn 9 | from module.mel_processing import spectrogram_torch 10 | # from feature_extractor import cnhubert as content_module 11 | 12 | vits_model_cache = None 13 | 14 | 15 | def _load_model(device="cuda"): 16 | global vits_model_cache 17 | if vits_model_cache is not None: 18 | return vits_model_cache 19 | hps = utils.get_hparams_from_file("configs/s2.json") 20 | model_dir = hps.s2_ckpt_dir 21 | net_g = SynthesizerTrn( 22 | hps.data.filter_length // 2 + 1, 23 | hps.train.segment_size // hps.data.hop_length, 24 | n_speakers=hps.data.n_speakers, 25 | **hps.model).to(device) 26 | 27 | utils.load_checkpoint(utils.latest_checkpoint_path(model_dir, "G_*.pth"), net_g, 28 | None, True) 29 | net_g.eval() 30 | vits_model_cache = (hps, net_g) 31 | return hps, net_g 32 | 33 | 34 | def get_spepc(hps, filename): 35 | audio, sampling_rate = utils.load_wav_to_torch(filename) 36 | if sampling_rate != hps.data.sampling_rate: 37 | raise ValueError("{} SR doesn't match target {} SR".format( 38 | sampling_rate, hps.data.sampling_rate)) 39 | audio_norm = audio 40 | audio_norm = audio_norm.unsqueeze(0) 41 | spec = spectrogram_torch(audio_norm, hps.data.filter_length, 42 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 43 | center=False) 44 | return spec 45 | 46 | 47 | def decode_to_file(codes,phonemes, save_path, refer_path, transform='valle'): 48 | device = codes.device 49 | hps, net_g = _load_model(device=device) 50 | if transform=='valle': 51 | codes = codes.transpose(0, 1).unsqueeze(1) 52 | else: 53 | codes = codes.transpose(0, 1) 54 | refer = get_spepc(hps, refer_path).to(device) 55 | audio = net_g.decode(codes,phonemes, refer).detach().cpu().numpy()[0, 0] 56 | soundfile.write(save_path, audio, hps.data.sampling_rate) 57 | 58 | 59 | def encode_from_file(path, device='cpu'): 60 | hps, net_g = _load_model(device=device) 61 | content_model = content_module.get_model().to(device) 62 | wav16k, sr = librosa.load(path, sr=16000) 63 | with torch.no_grad(): 64 | wav16k = torch.from_numpy(wav16k).to(device) 65 | ssl_content = content_module.get_content(content_model, wav_16k_tensor=wav16k) 66 | codes = net_g.extract_latent(ssl_content) 67 | return codes.cpu() 68 | 69 | def encode_semantic_from_wav16k_numpy(wav16k, device='cpu'): 70 | hps, net_g = _load_model(device=device) 71 | content_model = content_module.get_model().to(device) 72 | with torch.no_grad(): 73 | wav16k = torch.from_numpy(wav16k).to(device) 74 | ssl_content = content_module.get_content(content_model, wav_16k_tensor=wav16k) 75 | codes = net_g.extract_latent(ssl_content) 76 | return codes[0, :1, :] 77 | 78 | if __name__ == '__main__': 79 | codes_path = "pred_semantic.pt" 80 | refer_path = "/home/fish/genshin_data/zh/派蒙/vo_DQAQ003_1_paimon_06.wav" 81 | # src_path = "dataset/PaiMeng/vo_DQAQ003_1_paimon_06.wav" 82 | device = 'cpu' 83 | # codes = encode_from_file(src_path, device=device) 84 | codes = torch.load(codes_path).unsqueeze(0).unsqueeze(0) 85 | print('argv', sys.argv[1]) 86 | phonemes = torch.LongTensor([int(i) for i in sys.argv[1].split(" ")]).unsqueeze(0) 87 | print(codes.shape) 88 | print("phonemes", phonemes) 89 | 90 | decode_to_file(codes, phonemes,"tmp.wav", refer_path, transform="raw") -------------------------------------------------------------------------------- /text/__init__.py: 
-------------------------------------------------------------------------------- 1 | from text.symbols import * 2 | 3 | 4 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 5 | 6 | def cleaned_text_to_sequence(cleaned_text): 7 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 8 | Args: 9 | text: string to convert to a sequence 10 | Returns: 11 | List of integers corresponding to the symbols in the text 12 | ''' 13 | phones = [_symbol_to_id[symbol] for symbol in cleaned_text] 14 | return phones 15 | 16 | -------------------------------------------------------------------------------- /text/chinese.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import cn2an 5 | from pypinyin import lazy_pinyin, Style 6 | 7 | from text.symbols import punctuation 8 | from text.tone_sandhi import ToneSandhi 9 | 10 | current_file_path = os.path.dirname(__file__) 11 | pinyin_to_symbol_map = {line.split("\t")[0]: line.strip().split("\t")[1] for line in 12 | open(os.path.join(current_file_path, 'opencpop-strict.txt')).readlines()} 13 | 14 | import jieba.posseg as psg 15 | 16 | 17 | rep_map = { 18 | ':': ',', 19 | ';': ',', 20 | ',': ',', 21 | '。': '.', 22 | '!': '!', 23 | '?': '?', 24 | '\n': '.', 25 | "·": ",", 26 | '、': ",", 27 | '...': '…', 28 | '$': '.', 29 | '—': "-" 30 | } 31 | 32 | tone_modifier = ToneSandhi() 33 | 34 | def replace_punctuation(text): 35 | text = text.replace("嗯", "恩").replace("呣","母") 36 | pattern = re.compile('|'.join(re.escape(p) for p in rep_map.keys())) 37 | 38 | replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) 39 | 40 | replaced_text = re.sub(r'[^\u4e00-\u9fa5'+"".join(punctuation)+r']+', '', replaced_text) 41 | 42 | return replaced_text 43 | 44 | def g2p(text): 45 | text = replace_punctuation(text) 46 | pattern = r'(?<=[{0}])\s*'.format(''.join(punctuation)) 47 | sentences = [i for i in re.split(pattern, text) if i.strip()!=''] 48 | phones = _g2p(sentences) 49 | return phones 50 | 51 | 52 | def _get_initials_finals(word): 53 | initials = [] 54 | finals = [] 55 | orig_initials = lazy_pinyin( 56 | word, neutral_tone_with_five=True, style=Style.INITIALS) 57 | orig_finals = lazy_pinyin( 58 | word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) 59 | for c, v in zip(orig_initials, orig_finals): 60 | initials.append(c) 61 | finals.append(v) 62 | return initials, finals 63 | 64 | 65 | def _g2p(segments): 66 | phones_list = [] 67 | for seg in segments: 68 | pinyins = [] 69 | # Replace all English words in the sentence 70 | seg = re.sub('[a-zA-Z]+', '', seg) 71 | seg_cut = psg.lcut(seg) 72 | initials = [] 73 | finals = [] 74 | seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) 75 | for word, pos in seg_cut: 76 | if pos == 'eng': 77 | continue 78 | sub_initials, sub_finals = _get_initials_finals(word) 79 | sub_finals = tone_modifier.modified_tone(word, pos, 80 | sub_finals) 81 | initials.append(sub_initials) 82 | finals.append(sub_finals) 83 | 84 | # assert len(sub_initials) == len(sub_finals) == len(word) 85 | initials = sum(initials, []) 86 | finals = sum(finals, []) 87 | # 88 | for c, v in zip(initials, finals): 89 | raw_pinyin = c+v 90 | # NOTE: post process for pypinyin outputs 91 | # we discriminate i, ii and iii 92 | if c == v: 93 | assert c in punctuation 94 | phone = [c] 95 | else: 96 | v_without_tone = v[:-1] 97 | tone = v[-1] 98 | 99 | pinyin = c+v_without_tone 100 | assert tone in '12345' 101 | 102 | if c: 103 | # 多音节 104 | v_rep_map = { 
105 | "uei": 'ui', 106 | 'iou': 'iu', 107 | 'uen': 'un', 108 | } 109 | if v_without_tone in v_rep_map.keys(): 110 | pinyin = c+v_rep_map[v_without_tone] 111 | else: 112 | # 单音节 113 | pinyin_rep_map = { 114 | 'ing': 'ying', 115 | 'i': 'yi', 116 | 'in': 'yin', 117 | 'u': 'wu', 118 | } 119 | if pinyin in pinyin_rep_map.keys(): 120 | pinyin = pinyin_rep_map[pinyin] 121 | else: 122 | single_rep_map = { 123 | 'v': 'yu', 124 | 'e': 'e', 125 | 'i': 'y', 126 | 'u': 'w', 127 | } 128 | if pinyin[0] in single_rep_map.keys(): 129 | pinyin = single_rep_map[pinyin[0]]+pinyin[1:] 130 | 131 | assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin) 132 | new_c, new_v = pinyin_to_symbol_map[pinyin].split(' ') 133 | new_v = new_v + tone 134 | phone = [new_c, new_v] 135 | phones_list += phone 136 | return phones_list 137 | 138 | 139 | 140 | def text_normalize(text): 141 | numbers = re.findall(r'\d+(?:\.?\d+)?', text) 142 | for number in numbers: 143 | text = text.replace(number, cn2an.an2cn(number), 1) 144 | return text 145 | 146 | 147 | if __name__ == '__main__': 148 | text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏" 149 | text = "呣呣呣~就是…大人的鼹鼠党吧?" 150 | text = "你好" 151 | text = text_normalize(text) 152 | print(g2p(text)) 153 | 154 | 155 | # # 示例用法 156 | # text = "这是一个示例文本:,你好!这是一个测试..." 157 | # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试 158 | -------------------------------------------------------------------------------- /text/cleaner.py: -------------------------------------------------------------------------------- 1 | from text import chinese, japanese, cleaned_text_to_sequence, symbols, english 2 | 3 | language_module_map = { 4 | 'zh': chinese, 5 | "ja": japanese, 6 | 'en': english 7 | } 8 | 9 | def clean_text(text, language): 10 | language_module = language_module_map[language] 11 | norm_text = language_module.text_normalize(text) 12 | phones = language_module.g2p(norm_text) 13 | 14 | for ph in phones: 15 | assert ph in symbols 16 | return phones 17 | 18 | def text_to_sequence(text, language): 19 | phones = clean_text(text) 20 | return cleaned_text_to_sequence(phones) 21 | 22 | if __name__ == '__main__': 23 | print(clean_text("你好,啊啊啊额、还是到付红四方。")) 24 | 25 | 26 | -------------------------------------------------------------------------------- /text/cmudict_cache.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innnky/ar-vits/aeca3848c6d34946b71348c9f59e7e008cd838ce/text/cmudict_cache.pickle -------------------------------------------------------------------------------- /text/english.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import re 4 | from g2p_en import G2p 5 | 6 | from string import punctuation 7 | 8 | from text import symbols 9 | 10 | current_file_path = os.path.dirname(__file__) 11 | CMU_DICT_PATH = os.path.join(current_file_path, 'cmudict.rep') 12 | CACHE_PATH = os.path.join(current_file_path, 'cmudict_cache.pickle') 13 | _g2p = G2p() 14 | 15 | arpa = {'AH0', 'S', 'AH1', 'EY2', 'AE2', 'EH0', 'OW2', 'UH0', 'NG', 'B', 'G', 'AY0', 'M', 'AA0', 'F', 'AO0', 'ER2', 'UH1', 'IY1', 'AH2', 'DH', 'IY0', 'EY1', 'IH0', 'K', 'N', 'W', 'IY2', 'T', 'AA1', 'ER1', 'EH2', 'OY0', 'UH2', 'UW1', 'Z', 'AW2', 'AW1', 'V', 'UW2', 'AA2', 'ER', 'AW0', 'UW0', 'R', 'OW1', 'EH1', 'ZH', 'AE0', 'IH2', 'IH', 'Y', 'JH', 'P', 'AY1', 'EY0', 'OY2', 'TH', 'HH', 'D', 'ER0', 'CH', 'AO1', 'AE1', 'AO2', 'OY1', 'AY2', 'IH1', 'OW0', 'L', 'SH'} 16 | 17 | 18 | def 
replace_phs(phs): 19 | rep_map = { 20 | ';': ',', 21 | ':': ',', 22 | '\'': '-', 23 | '"': '-' 24 | } 25 | phs_new = [] 26 | for ph in phs: 27 | if ph in symbols: 28 | phs_new.append(ph) 29 | elif ph in rep_map.keys(): 30 | phs_new.append(rep_map[ph]) 31 | else: 32 | print('ph not in symbols: ', ph) 33 | return phs_new 34 | 35 | def read_dict(): 36 | g2p_dict = {} 37 | start_line = 49 38 | with open(CMU_DICT_PATH) as f: 39 | line = f.readline() 40 | line_index = 1 41 | while line: 42 | if line_index >= start_line: 43 | line = line.strip() 44 | word_split = line.split(' ') 45 | word = word_split[0] 46 | 47 | syllable_split = word_split[1].split(' - ') 48 | g2p_dict[word] = [] 49 | for syllable in syllable_split: 50 | phone_split = syllable.split(' ') 51 | g2p_dict[word].append(phone_split) 52 | 53 | line_index = line_index + 1 54 | line = f.readline() 55 | 56 | return g2p_dict 57 | 58 | 59 | def cache_dict(g2p_dict, file_path): 60 | with open(file_path, 'wb') as pickle_file: 61 | pickle.dump(g2p_dict, pickle_file) 62 | 63 | 64 | def get_dict(): 65 | if os.path.exists(CACHE_PATH): 66 | with open(CACHE_PATH, 'rb') as pickle_file: 67 | g2p_dict = pickle.load(pickle_file) 68 | else: 69 | g2p_dict = read_dict() 70 | cache_dict(g2p_dict, CACHE_PATH) 71 | 72 | return g2p_dict 73 | 74 | eng_dict = get_dict() 75 | 76 | 77 | def text_normalize(text): 78 | # todo: eng text normalize 79 | return text.replace(";", ",") 80 | 81 | def g2p(text): 82 | 83 | phones = [] 84 | words = re.split(r"([,;.\-\?\!\s+])", text) 85 | for w in words: 86 | if w.upper() in eng_dict: 87 | phns = eng_dict[w.upper()] 88 | for ph in phns: 89 | phones += ph 90 | else: 91 | phone_list = list(filter(lambda p: p != " ", _g2p(w))) 92 | for ph in phone_list: 93 | if ph in arpa: 94 | phones.append(ph) 95 | else: 96 | phones.append(ph) 97 | 98 | return replace_phs(phones) 99 | 100 | if __name__ == "__main__": 101 | # print(get_dict()) 102 | print(g2p("hello")) 103 | print(g2p("In this; paper, we propose 1 DSPGAN, a GAN-based universal vocoder.")) 104 | # all_phones = set() 105 | # for k, syllables in eng_dict.items(): 106 | # for group in syllables: 107 | # for ph in group: 108 | # all_phones.add(ph) 109 | # print(all_phones) -------------------------------------------------------------------------------- /text/japanese.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/CjangCjengh/vits/blob/main/text/japanese.py 2 | import re 3 | import sys 4 | 5 | import pyopenjtalk 6 | 7 | from text import symbols 8 | 9 | # Regular expression matching Japanese without punctuation marks: 10 | _japanese_characters = re.compile( 11 | r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 12 | 13 | # Regular expression matching non-Japanese characters or punctuation marks: 14 | _japanese_marks = re.compile( 15 | r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 16 | 17 | # List of (symbol, Japanese) pairs for marks: 18 | _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [ 19 | ('%', 'パーセント') 20 | ]] 21 | 22 | 23 | # List of (consonant, sokuon) pairs: 24 | _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [ 25 | (r'Q([↑↓]*[kg])', r'k#\1'), 26 | (r'Q([↑↓]*[tdjʧ])', r't#\1'), 27 | (r'Q([↑↓]*[sʃ])', r's\1'), 28 | (r'Q([↑↓]*[pb])', r'p#\1') 29 | ]] 30 | 31 | # List of (consonant, hatsuon) pairs: 32 | _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [ 33 | 
(r'N([↑↓]*[pbm])', r'm\1'), 34 | (r'N([↑↓]*[ʧʥj])', r'n^\1'), 35 | (r'N([↑↓]*[tdn])', r'n\1'), 36 | (r'N([↑↓]*[kg])', r'ŋ\1') 37 | ]] 38 | 39 | 40 | 41 | def post_replace_ph(ph): 42 | rep_map = { 43 | ':': ',', 44 | ';': ',', 45 | ',': ',', 46 | '。': '.', 47 | '!': '!', 48 | '?': '?', 49 | '\n': '.', 50 | "·": ",", 51 | '、': ",", 52 | '...': '…' 53 | } 54 | if ph in rep_map.keys(): 55 | ph = rep_map[ph] 56 | if ph in symbols: 57 | return ph 58 | if ph not in symbols: 59 | ph = 'UNK' 60 | return ph 61 | 62 | def symbols_to_japanese(text): 63 | for regex, replacement in _symbols_to_japanese: 64 | text = re.sub(regex, replacement, text) 65 | return text 66 | 67 | 68 | def preprocess_jap(text): 69 | '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html''' 70 | text = symbols_to_japanese(text) 71 | sentences = re.split(_japanese_marks, text) 72 | marks = re.findall(_japanese_marks, text) 73 | text = [] 74 | for i, sentence in enumerate(sentences): 75 | if re.match(_japanese_characters, sentence): 76 | p = pyopenjtalk.g2p(sentence) 77 | text += p.split(" ") 78 | 79 | if i < len(marks): 80 | text += [marks[i].replace(' ', '')] 81 | return text 82 | 83 | def text_normalize(text): 84 | # todo: jap text normalize 85 | return text 86 | 87 | def g2p(norm_text): 88 | phones = preprocess_jap(norm_text) 89 | phones = [post_replace_ph(i) for i in phones] 90 | # todo: implement tones and word2ph 91 | return phones 92 | 93 | 94 | if __name__ == '__main__': 95 | for line in open("../../../Downloads/transcript_utf8.txt").readlines(): 96 | text = line.split(":")[1] 97 | phones = g2p(text) 98 | print(phones) 99 | -------------------------------------------------------------------------------- /text/opencpop-strict.txt: -------------------------------------------------------------------------------- 1 | a AA a 2 | ai AA ai 3 | an AA an 4 | ang AA ang 5 | ao AA ao 6 | ba b a 7 | bai b ai 8 | ban b an 9 | bang b ang 10 | bao b ao 11 | bei b ei 12 | ben b en 13 | beng b eng 14 | bi b i 15 | bian b ian 16 | biao b iao 17 | bie b ie 18 | bin b in 19 | bing b ing 20 | bo b o 21 | bu b u 22 | ca c a 23 | cai c ai 24 | can c an 25 | cang c ang 26 | cao c ao 27 | ce c e 28 | cei c ei 29 | cen c en 30 | ceng c eng 31 | cha ch a 32 | chai ch ai 33 | chan ch an 34 | chang ch ang 35 | chao ch ao 36 | che ch e 37 | chen ch en 38 | cheng ch eng 39 | chi ch ir 40 | chong ch ong 41 | chou ch ou 42 | chu ch u 43 | chua ch ua 44 | chuai ch uai 45 | chuan ch uan 46 | chuang ch uang 47 | chui ch ui 48 | chun ch un 49 | chuo ch uo 50 | ci c i0 51 | cong c ong 52 | cou c ou 53 | cu c u 54 | cuan c uan 55 | cui c ui 56 | cun c un 57 | cuo c uo 58 | da d a 59 | dai d ai 60 | dan d an 61 | dang d ang 62 | dao d ao 63 | de d e 64 | dei d ei 65 | den d en 66 | deng d eng 67 | di d i 68 | dia d ia 69 | dian d ian 70 | diao d iao 71 | die d ie 72 | ding d ing 73 | diu d iu 74 | dong d ong 75 | dou d ou 76 | du d u 77 | duan d uan 78 | dui d ui 79 | dun d un 80 | duo d uo 81 | e EE e 82 | ei EE ei 83 | en EE en 84 | eng EE eng 85 | er EE er 86 | fa f a 87 | fan f an 88 | fang f ang 89 | fei f ei 90 | fen f en 91 | feng f eng 92 | fo f o 93 | fou f ou 94 | fu f u 95 | ga g a 96 | gai g ai 97 | gan g an 98 | gang g ang 99 | gao g ao 100 | ge g e 101 | gei g ei 102 | gen g en 103 | geng g eng 104 | gong g ong 105 | gou g ou 106 | gu g u 107 | gua g ua 108 | guai g uai 109 | guan g uan 110 | guang g uang 111 | gui g ui 112 | gun g un 113 | guo g uo 114 | ha h a 115 | hai h ai 116 | han h an 117 | hang h ang 118 | 
hao h ao 119 | he h e 120 | hei h ei 121 | hen h en 122 | heng h eng 123 | hong h ong 124 | hou h ou 125 | hu h u 126 | hua h ua 127 | huai h uai 128 | huan h uan 129 | huang h uang 130 | hui h ui 131 | hun h un 132 | huo h uo 133 | ji j i 134 | jia j ia 135 | jian j ian 136 | jiang j iang 137 | jiao j iao 138 | jie j ie 139 | jin j in 140 | jing j ing 141 | jiong j iong 142 | jiu j iu 143 | ju j v 144 | jv j v 145 | juan j van 146 | jvan j van 147 | jue j ve 148 | jve j ve 149 | jun j vn 150 | jvn j vn 151 | ka k a 152 | kai k ai 153 | kan k an 154 | kang k ang 155 | kao k ao 156 | ke k e 157 | kei k ei 158 | ken k en 159 | keng k eng 160 | kong k ong 161 | kou k ou 162 | ku k u 163 | kua k ua 164 | kuai k uai 165 | kuan k uan 166 | kuang k uang 167 | kui k ui 168 | kun k un 169 | kuo k uo 170 | la l a 171 | lai l ai 172 | lan l an 173 | lang l ang 174 | lao l ao 175 | le l e 176 | lei l ei 177 | leng l eng 178 | li l i 179 | lia l ia 180 | lian l ian 181 | liang l iang 182 | liao l iao 183 | lie l ie 184 | lin l in 185 | ling l ing 186 | liu l iu 187 | lo l o 188 | long l ong 189 | lou l ou 190 | lu l u 191 | luan l uan 192 | lun l un 193 | luo l uo 194 | lv l v 195 | lve l ve 196 | ma m a 197 | mai m ai 198 | man m an 199 | mang m ang 200 | mao m ao 201 | me m e 202 | mei m ei 203 | men m en 204 | meng m eng 205 | mi m i 206 | mian m ian 207 | miao m iao 208 | mie m ie 209 | min m in 210 | ming m ing 211 | miu m iu 212 | mo m o 213 | mou m ou 214 | mu m u 215 | na n a 216 | nai n ai 217 | nan n an 218 | nang n ang 219 | nao n ao 220 | ne n e 221 | nei n ei 222 | nen n en 223 | neng n eng 224 | ni n i 225 | nian n ian 226 | niang n iang 227 | niao n iao 228 | nie n ie 229 | nin n in 230 | ning n ing 231 | niu n iu 232 | nong n ong 233 | nou n ou 234 | nu n u 235 | nuan n uan 236 | nun n un 237 | nuo n uo 238 | nv n v 239 | nve n ve 240 | o OO o 241 | ou OO ou 242 | pa p a 243 | pai p ai 244 | pan p an 245 | pang p ang 246 | pao p ao 247 | pei p ei 248 | pen p en 249 | peng p eng 250 | pi p i 251 | pian p ian 252 | piao p iao 253 | pie p ie 254 | pin p in 255 | ping p ing 256 | po p o 257 | pou p ou 258 | pu p u 259 | qi q i 260 | qia q ia 261 | qian q ian 262 | qiang q iang 263 | qiao q iao 264 | qie q ie 265 | qin q in 266 | qing q ing 267 | qiong q iong 268 | qiu q iu 269 | qu q v 270 | qv q v 271 | quan q van 272 | qvan q van 273 | que q ve 274 | qve q ve 275 | qun q vn 276 | qvn q vn 277 | ran r an 278 | rang r ang 279 | rao r ao 280 | re r e 281 | ren r en 282 | reng r eng 283 | ri r ir 284 | rong r ong 285 | rou r ou 286 | ru r u 287 | rua r ua 288 | ruan r uan 289 | rui r ui 290 | run r un 291 | ruo r uo 292 | sa s a 293 | sai s ai 294 | san s an 295 | sang s ang 296 | sao s ao 297 | se s e 298 | sen s en 299 | seng s eng 300 | sha sh a 301 | shai sh ai 302 | shan sh an 303 | shang sh ang 304 | shao sh ao 305 | she sh e 306 | shei sh ei 307 | shen sh en 308 | sheng sh eng 309 | shi sh ir 310 | shou sh ou 311 | shu sh u 312 | shua sh ua 313 | shuai sh uai 314 | shuan sh uan 315 | shuang sh uang 316 | shui sh ui 317 | shun sh un 318 | shuo sh uo 319 | si s i0 320 | song s ong 321 | sou s ou 322 | su s u 323 | suan s uan 324 | sui s ui 325 | sun s un 326 | suo s uo 327 | ta t a 328 | tai t ai 329 | tan t an 330 | tang t ang 331 | tao t ao 332 | te t e 333 | tei t ei 334 | teng t eng 335 | ti t i 336 | tian t ian 337 | tiao t iao 338 | tie t ie 339 | ting t ing 340 | tong t ong 341 | tou t ou 342 | tu t u 343 | tuan t uan 344 | tui t ui 345 | tun t un 346 | tuo t uo 347 | wa w a 348 | 
wai w ai 349 | wan w an 350 | wang w ang 351 | wei w ei 352 | wen w en 353 | weng w eng 354 | wo w o 355 | wu w u 356 | xi x i 357 | xia x ia 358 | xian x ian 359 | xiang x iang 360 | xiao x iao 361 | xie x ie 362 | xin x in 363 | xing x ing 364 | xiong x iong 365 | xiu x iu 366 | xu x v 367 | xv x v 368 | xuan x van 369 | xvan x van 370 | xue x ve 371 | xve x ve 372 | xun x vn 373 | xvn x vn 374 | ya y a 375 | yan y En 376 | yang y ang 377 | yao y ao 378 | ye y E 379 | yi y i 380 | yin y in 381 | ying y ing 382 | yo y o 383 | yong y ong 384 | you y ou 385 | yu y v 386 | yv y v 387 | yuan y van 388 | yvan y van 389 | yue y ve 390 | yve y ve 391 | yun y vn 392 | yvn y vn 393 | za z a 394 | zai z ai 395 | zan z an 396 | zang z ang 397 | zao z ao 398 | ze z e 399 | zei z ei 400 | zen z en 401 | zeng z eng 402 | zha zh a 403 | zhai zh ai 404 | zhan zh an 405 | zhang zh ang 406 | zhao zh ao 407 | zhe zh e 408 | zhei zh ei 409 | zhen zh en 410 | zheng zh eng 411 | zhi zh ir 412 | zhong zh ong 413 | zhou zh ou 414 | zhu zh u 415 | zhua zh ua 416 | zhuai zh uai 417 | zhuan zh uan 418 | zhuang zh uang 419 | zhui zh ui 420 | zhun zh un 421 | zhuo zh uo 422 | zi z i0 423 | zong z ong 424 | zou z ou 425 | zu z u 426 | zuan z uan 427 | zui z ui 428 | zun z un 429 | zuo z uo 430 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | punctuation = ['!', '?', '…', ",", ".", '-'] 4 | pu_symbols = punctuation + ["SP", "UNK"] 5 | pad = '_' 6 | 7 | c = ['AA', 'EE', 'OO', 'b', 'c', 'ch', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 'sh', 't', 'w', 'x', 'y', 'z', 'zh'] 8 | v = ['E1', 'En1', 'a1', 'ai1', 'an1', 'ang1', 'ao1', 'e1', 'ei1', 'en1', 'eng1', 'er1', 'i1', 'i01', 'ia1', 'ian1', 'iang1', 'iao1', 'ie1', 'in1', 'ing1', 'iong1', 'ir1', 'iu1', 'o1', 'ong1', 'ou1', 'u1', 'ua1', 'uai1', 'uan1', 'uang1', 'ui1', 'un1', 'uo1', 'v1', 'van1', 've1', 'vn1', 'E2', 'En2', 'a2', 'ai2', 'an2', 'ang2', 'ao2', 'e2', 'ei2', 'en2', 'eng2', 'er2', 'i2', 'i02', 'ia2', 'ian2', 'iang2', 'iao2', 'ie2', 'in2', 'ing2', 'iong2', 'ir2', 'iu2', 'o2', 'ong2', 'ou2', 'u2', 'ua2', 'uai2', 'uan2', 'uang2', 'ui2', 'un2', 'uo2', 'v2', 'van2', 've2', 'vn2', 'E3', 'En3', 'a3', 'ai3', 'an3', 'ang3', 'ao3', 'e3', 'ei3', 'en3', 'eng3', 'er3', 'i3', 'i03', 'ia3', 'ian3', 'iang3', 'iao3', 'ie3', 'in3', 'ing3', 'iong3', 'ir3', 'iu3', 'o3', 'ong3', 'ou3', 'u3', 'ua3', 'uai3', 'uan3', 'uang3', 'ui3', 'un3', 'uo3', 'v3', 'van3', 've3', 'vn3', 'E4', 'En4', 'a4', 'ai4', 'an4', 'ang4', 'ao4', 'e4', 'ei4', 'en4', 'eng4', 'er4', 'i4', 'i04', 'ia4', 'ian4', 'iang4', 'iao4', 'ie4', 'in4', 'ing4', 'iong4', 'ir4', 'iu4', 'o4', 'ong4', 'ou4', 'u4', 'ua4', 'uai4', 'uan4', 'uang4', 'ui4', 'un4', 'uo4', 'v4', 'van4', 've4', 'vn4', 'E5', 'En5', 'a5', 'ai5', 'an5', 'ang5', 'ao5', 'e5', 'ei5', 'en5', 'eng5', 'er5', 'i5', 'i05', 'ia5', 'ian5', 'iang5', 'iao5', 'ie5', 'in5', 'ing5', 'iong5', 'ir5', 'iu5', 'o5', 'ong5', 'ou5', 'u5', 'ua5', 'uai5', 'uan5', 'uang5', 'ui5', 'un5', 'uo5', 'v5', 'van5', 've5', 'vn5'] 9 | 10 | v_without_tone = ['E', 'En', 'a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'i0', 'ia', 'ian', 'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'ir', 'iu', 'o', 'ong', 'ou', 'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've', 'vn'] 11 | 12 | # japanese 13 | ja_symbols = ['I', 'N', 'U', 'a', 'b', 'by', 'ch', 'cl', 'd', 'dy', 'e', 'f', 'g', 'gy', 'h', 'hy', 
'i', 'j', 'k', 'ky', 14 | 'm', 'my', 'n', 'ny', 'o', 'p', 'py', 'r', 'ry', 's', 'sh', 't', 'ts', 'u', 'v', 'w', 'y', 'z'] 15 | 16 | arpa = {'AH0', 'S', 'AH1', 'EY2', 'AE2', 'EH0', 'OW2', 'UH0', 'NG', 'B', 'G', 'AY0', 'M', 'AA0', 'F', 'AO0', 'ER2', 'UH1', 'IY1', 'AH2', 'DH', 'IY0', 'EY1', 'IH0', 'K', 'N', 'W', 'IY2', 'T', 'AA1', 'ER1', 'EH2', 'OY0', 'UH2', 'UW1', 'Z', 'AW2', 'AW1', 'V', 'UW2', 'AA2', 'ER', 'AW0', 'UW0', 'R', 'OW1', 'EH1', 'ZH', 'AE0', 'IH2', 'IH', 'Y', 'JH', 'P', 'AY1', 'EY0', 'OY2', 'TH', 'HH', 'D', 'ER0', 'CH', 'AO1', 'AE1', 'AO2', 'OY1', 'AY2', 'IH1', 'OW0', 'L', 'SH'} 17 | 18 | symbols = [pad] + c + v + ja_symbols + pu_symbols + list(arpa) 19 | symbols = sorted(set(symbols)) 20 | if __name__ == '__main__': 21 | print(len(symbols)) --------------------------------------------------------------------------------
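
A minimal usage sketch of the piecewise rational-quadratic spline flow in module/transforms.py shown above. The tensor shapes, the tail_bound value, and the round-trip check are illustrative assumptions, not values taken from this repository's configs.

import torch
from module.transforms import piecewise_rational_quadratic_transform

# Toy shapes (assumed for illustration): batch of 2, 5 positions, 10 spline bins.
B, T, num_bins = 2, 5, 10
x = torch.randn(B, T)
unnormalized_widths = torch.randn(B, T, num_bins)
unnormalized_heights = torch.randn(B, T, num_bins)
# With tails='linear' the function pads the derivative tensor itself, so only
# the num_bins - 1 interior derivatives are supplied here.
unnormalized_derivatives = torch.randn(B, T, num_bins - 1)

y, logabsdet = piecewise_rational_quadratic_transform(
    x, unnormalized_widths, unnormalized_heights, unnormalized_derivatives,
    inverse=False, tails='linear', tail_bound=5.0)
x_rec, inv_logabsdet = piecewise_rational_quadratic_transform(
    y, unnormalized_widths, unnormalized_heights, unnormalized_derivatives,
    inverse=True, tails='linear', tail_bound=5.0)
# The spline is invertible: x_rec matches x up to numerical error, and
# logabsdet + inv_logabsdet is approximately zero.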
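
Similarly, a minimal sketch of the text front-end defined under text/: clean_text (text/cleaner.py) normalizes the input and runs the language-specific g2p, and cleaned_text_to_sequence (text/__init__.py) maps the resulting phoneme symbols to the integer IDs consumed by s1_infer.py. The example sentence and the commented expected output are illustrative.

# Assumes the repository root is on PYTHONPATH and the g2p dependencies
# (pypinyin, jieba, cn2an, g2p_en, pyopenjtalk) are installed.
from text.cleaner import clean_text
from text import cleaned_text_to_sequence

phones = clean_text("你好", 'zh')        # expected to yield something like ['n', 'i3', 'h', 'ao3']
ids = cleaned_text_to_sequence(phones)   # integer indices into text.symbols.symbols
print(phones)
print(ids)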