├── .gitignore ├── LICENSE ├── README.md ├── cli.py ├── config └── config_gan.yaml ├── configs ├── config_adm.yaml ├── config_gan.yaml └── config_plm.yaml ├── examples └── mel_step_400k_re_loss_0.4771.png ├── infer.py ├── models ├── megatts2.py └── trainer.py ├── modules ├── __init__.py ├── convnet.py ├── datamodule.py ├── dscrm.py ├── embedding.py ├── mrte.py ├── quantization │ ├── __init__.py │ ├── ac.py │ ├── core_vq.py │ └── vq.py ├── tokenizer.py ├── transformer.py └── vqpe.py ├── prepare_ds.py ├── requirements.txt └── utils ├── distrib.py ├── mandarin_pinyin_to_mfa_lty.dict ├── symbol_table.py ├── textgrid.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | data/ 163 | logs/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 LSimon95 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # megatts2 2 | Unofficial implementation of Megatts2 3 | 4 | ## TODO 5 | ### Base test 6 | - [x] Prepare dataset 7 | - [x] VQ-GAN 8 | - [x] ADM 9 | - [x] PLM 10 | ### Better version 11 | - [ ] Replace Hifigan with Bigvgan 12 | - [ ] Mix training Chinese and English 13 | - [ ] Train on about 1k hours of speech 14 | - [ ] Webui 15 | 16 | ## Install mfa 17 | 1. conda create -n aligner && conda activate aligner 18 | 2. conda install -c conda-forge montreal-forced-aligner=2.2.17 19 | 20 | ## Prepare dataset 21 | 1. Prepare wav and txt files to ./data/wav 22 | 2. Run `python3 prepare_ds.py --stage 0 --num_workers 4 --wavtxt_path data/wavs --text_grid_path data/textgrids --ds_path data/ds` 23 | 3. mfa model download acoustic mandarin_mfa 24 | 4. 
mfa align data/wavs utils/mandarin_pinyin_to_mfa_lty.dict mandarin_mfa data/textgrids --clean -j 12 -t /workspace/tmp 25 | 5. Run `python3 prepare_ds.py --stage 1 --num_workers 4 --wavtxt_path data/wavs --text_grid_path data/textgrids --ds_path data/ds` 26 | 6. Run `python3 prepare_ds.py --stage 2 --generator_config configs/config_gan.yaml --generator_ckpt generator.ckpt` after training generator. 27 | 28 | ## Train 29 | Training procedure refers to Pytorch-lightning 30 | 31 | ## Infer test 32 | `python infer.py` 33 | 34 | ## Citing 35 | ```bibtex 36 | @misc{2307.07218, 37 | Author = {Ziyue Jiang and Jinglin Liu and Yi Ren and Jinzheng He and Chen Zhang and Zhenhui Ye and Pengfei Wei and Chunfeng Wang and Xiang Yin and Zejun Ma and Zhou Zhao}, 38 | Title = {Mega-TTS 2: Zero-Shot Text-to-Speech with Arbitrary Length Speech Prompts}, 39 | Year = {2023}, 40 | Eprint = {arXiv:2307.07218}, 41 | } 42 | ``` 43 | 44 | ## License 45 | - MIT 46 | - Support by Simon of [ZideAI](https://zideai.com/) -------------------------------------------------------------------------------- /cli.py: -------------------------------------------------------------------------------- 1 | # main.py 2 | from lightning.pytorch.cli import LightningCLI 3 | 4 | from models.trainer import MegaGANTrainer, MegaPLMTrainer, MegaADMTrainer 5 | from modules.datamodule import TTSDataModule, test 6 | 7 | 8 | 9 | def cli_main(): 10 | cli = LightningCLI(MegaADMTrainer, TTSDataModule) 11 | 12 | if __name__ == "__main__": 13 | cli_main() 14 | # note: it is good practice to implement the CLI in a function and call it in the main if block -------------------------------------------------------------------------------- /config/config_gan.yaml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.1.0 2 | seed_everything: true 3 | trainer: 4 | logger: 5 | class_path: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 6 | init_args: 7 | save_dir: logs/ 8 | callbacks: 9 | - class_path: lightning.pytorch.callbacks.ModelSummary 10 | init_args: 11 | max_depth: 3 12 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 13 | init_args: 14 | monitor: val/loss_re 15 | filename: nikatts_ar_checkpoint_{epoch}_{step}_{val/loss_re:.4f} 16 | save_top_k: 5 17 | save_last: true 18 | every_n_epochs: 1 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | # ~ 3 epochs 23 | max_steps: 600000 24 | # # You might want to limit val batches when evaluating all the metrics, as they are time-consuming 25 | # limit_val_batches: 100 26 | accelerator: gpu 27 | log_every_n_steps: 100 28 | val_check_interval: 1500 29 | check_val_every_n_epoch: 1 30 | 31 | # strategy: ddp 32 | # devices: [0, 1] 33 | # use_distributed_sampler: false 34 | 35 | devices: [0] 36 | model: 37 | G: 38 | class_path: models.megatts2.MegaVQ 39 | init_args: 40 | mrte: 41 | class_path: modules.mrte.MRTE 42 | init_args: 43 | mel_bins: 80 44 | mel_frames: 256 45 | mel_activation: ReLU 46 | mel_kernel_size: 3 47 | mel_stride: 16 48 | mel_n_stack: 5 49 | mel_n_block: 2 50 | content_ff_dim: 1024 51 | content_n_heads: 2 52 | content_n_layers: 8 53 | hidden_size: 512 54 | duration_token_ms: 16.0 55 | phone_vocab_size: 320 56 | dropout: 0.1 57 | sample_rate: 16000 58 | vqpe: 59 | class_path: modules.vqpe.VQProsodyEncoder 60 | init_args: 61 | mel_bins: 80 62 | stride: 8 63 | hidden_size: 384 64 | kernel_size: 5 65 | n_stack: 3 66 | n_block: 2 67 | vq_bins: 1024 68 | vq_dim: 256 69 | 
activation: ReLU 70 | kernel_size: 5 71 | activation: ReLU 72 | hidden_size: 512 73 | decoder_n_stack: 4 74 | decoder_n_block: 2 75 | D: 76 | class_path: modules.dscrm.Discriminator 77 | init_args: 78 | time_lengths: 79 | - 32 80 | - 64 81 | - 128 82 | freq_length: 80 83 | kernel: 84 | - 3 85 | - 3 86 | c_in: 1 87 | hidden_size: 192 88 | initial_learning_rate: 3e-5 89 | warmup_steps: 200.0 90 | 91 | G_commit_loss_coeff: 0.15 92 | G_vq_loss_coeff: 0.05 93 | G_adv_loss_coeff: 1.0 94 | train_dtype: bfloat16 95 | class_path: models.trainer.MegaGANTrainer 96 | data: 97 | ds_path: /root/autodl-tmp/megatts2/data/ds/ 98 | max_duration_batch: 120 99 | min_duration: 2.1 100 | max_duration: 20 101 | num_buckets: 10 102 | num_workers: 4 103 | class_path: modules.datamodule.TTSDataModule 104 | ckpt_path: last.ckpt -------------------------------------------------------------------------------- /configs/config_adm.yaml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.1.0 2 | seed_everything: true 3 | trainer: 4 | logger: 5 | class_path: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 6 | init_args: 7 | save_dir: logs/ 8 | callbacks: 9 | - class_path: lightning.pytorch.callbacks.ModelSummary 10 | init_args: 11 | max_depth: 3 12 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 13 | init_args: 14 | monitor: val/loss 15 | filename: nikatts_ar_checkpoint_{epoch}_{step}_{val/loss:.4f} 16 | save_top_k: 5 17 | save_last: true 18 | every_n_epochs: 1 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | max_steps: 50000 23 | # # You might want to limit val batches when evaluating all the metrics, as they are time-consuming 24 | # limit_val_batches: 100 25 | accelerator: gpu 26 | log_every_n_steps: 100 27 | val_check_interval: 1000 28 | check_val_every_n_epoch: 1 29 | 30 | # strategy: ddp 31 | # devices: [0, 1] 32 | # use_distributed_sampler: false 33 | 34 | devices: [0] 35 | model: 36 | adm: 37 | class_path: models.megatts2.MegaADM 38 | init_args: 39 | n_layers: 8 40 | n_heads: 8 41 | emb_dim: 256 42 | tc_latent_dim: 512 43 | tc_emb_dim: 512 44 | dropout: 0.1 45 | max_duration_token: 256 46 | initial_learning_rate: 2e-5 47 | warmup_steps: 200.0 48 | train_dtype: float32 49 | class_path: models.trainer.MegaPLMTrainer 50 | data: 51 | ds_path: /root/autodl-tmp/megatts2/data/ds/ 52 | dataset: MegaADMDataset 53 | max_duration_batch: 400 54 | min_duration: 2.1 55 | max_duration: 20 56 | num_buckets: 10 57 | num_workers: 4 58 | class_path: modules.datamodule.TTSDataModule 59 | ckpt_path: null -------------------------------------------------------------------------------- /configs/config_gan.yaml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.1.0 2 | seed_everything: true 3 | trainer: 4 | logger: 5 | class_path: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 6 | init_args: 7 | save_dir: logs/ 8 | callbacks: 9 | - class_path: lightning.pytorch.callbacks.ModelSummary 10 | init_args: 11 | max_depth: 3 12 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 13 | init_args: 14 | monitor: val/loss_re 15 | filename: nikatts_ar_checkpoint_{epoch}_{step}_{val/loss_re:.4f} 16 | save_top_k: 5 17 | save_last: true 18 | every_n_epochs: 1 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | # ~ 3 epochs 23 | max_steps: 600000 24 | # # You might want to 
limit val batches when evaluating all the metrics, as they are time-consuming 25 | # limit_val_batches: 100 26 | accelerator: gpu 27 | log_every_n_steps: 100 28 | val_check_interval: 1500 29 | check_val_every_n_epoch: 1 30 | 31 | # strategy: ddp 32 | # devices: [0, 1] 33 | # use_distributed_sampler: false 34 | 35 | devices: [0] 36 | model: 37 | G: 38 | class_path: models.megatts2.MegaG 39 | init_args: 40 | mrte: 41 | class_path: modules.mrte.MRTE 42 | init_args: 43 | mel_bins: 80 44 | mel_frames: 256 45 | mel_activation: ReLU 46 | mel_kernel_size: 3 47 | mel_stride: 16 48 | mel_n_layer: 5 49 | mel_n_stack: 5 50 | mel_n_block: 2 51 | content_ff_dim: 1024 52 | content_n_heads: 2 53 | content_n_layers: 8 54 | hidden_size: 512 55 | duration_token_ms: 16.0 56 | phone_vocab_size: 320 57 | dropout: 0.1 58 | sample_rate: 16000 59 | vqpe: 60 | class_path: modules.vqpe.VQProsodyEncoder 61 | init_args: 62 | mel_bins: 20 63 | stride: 8 64 | hidden_size: 384 65 | kernel_size: 5 66 | n_layers: 3 67 | n_stacks: 5 68 | n_blocks: 2 69 | vq_bins: 1024 70 | vq_dim: 256 71 | activation: ReLU 72 | kernel_size: 5 73 | activation: ReLU 74 | hidden_size: 512 75 | decoder_n_stack: 4 76 | decoder_n_block: 2 77 | D: 78 | class_path: modules.dscrm.Discriminator 79 | init_args: 80 | time_lengths: 81 | - 32 82 | - 64 83 | - 128 84 | freq_length: 80 85 | kernel: 86 | - 3 87 | - 3 88 | c_in: 1 89 | hidden_size: 192 90 | initial_learning_rate: 3e-5 91 | warmup_steps: 200.0 92 | 93 | G_commit_loss_coeff: 0.15 94 | G_vq_loss_coeff: 0.05 95 | G_adv_loss_coeff: 1.0 96 | train_dtype: bfloat16 97 | class_path: models.trainer.MegaGANTrainer 98 | data: 99 | ds_path: /root/autodl-tmp/megatts2/data/ds/ 100 | max_duration_batch: 60 101 | min_duration: 2.1 102 | max_duration: 20 103 | num_buckets: 10 104 | num_workers: 4 105 | class_path: modules.datamodule.TTSDataModule 106 | ckpt_path: null -------------------------------------------------------------------------------- /configs/config_plm.yaml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.1.0 2 | seed_everything: true 3 | trainer: 4 | logger: 5 | class_path: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 6 | init_args: 7 | save_dir: logs/ 8 | callbacks: 9 | - class_path: lightning.pytorch.callbacks.ModelSummary 10 | init_args: 11 | max_depth: 3 12 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 13 | init_args: 14 | monitor: val/loss 15 | filename: nikatts_ar_checkpoint_{epoch}_{step}_{val/loss:.4f} 16 | save_top_k: 5 17 | save_last: true 18 | every_n_epochs: 1 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | max_steps: 100000 23 | # # You might want to limit val batches when evaluating all the metrics, as they are time-consuming 24 | # limit_val_batches: 100 25 | accelerator: gpu 26 | log_every_n_steps: 100 27 | val_check_interval: 5000 28 | check_val_every_n_epoch: 1 29 | 30 | # strategy: ddp 31 | # devices: [0, 1] 32 | # use_distributed_sampler: false 33 | 34 | devices: [0] 35 | model: 36 | plm: 37 | class_path: models.megatts2.MegaPLM 38 | init_args: 39 | n_layers: 12 40 | n_heads: 16 41 | vq_dim: 512 42 | tc_latent_dim: 512 43 | vq_bins: 1024 44 | dropout: 0.1 45 | initial_learning_rate: 1e-4 46 | warmup_steps: 200.0 47 | train_dtype: bfloat16 48 | class_path: models.trainer.MegaPLMTrainer 49 | data: 50 | ds_path: /root/autodl-tmp/megatts2/data/ds/ 51 | dataset: MegaPLMDataset 52 | min_duration: 2.1 53 | max_duration: 20 54 | 
num_workers: 4 55 | max_n_cuts: 15 56 | class_path: modules.datamodule.TTSDataModule 57 | ckpt_path: null -------------------------------------------------------------------------------- /examples/mel_step_400k_re_loss_0.4771.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LSimon95/megatts2/2ab81a1234791e809cdcfca12e3c7b3d0cc2a9f3/examples/mel_step_400k_re_loss_0.4771.png -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | from models.megatts2 import Megatts 2 | 3 | if __name__ == '__main__': 4 | megatts = Megatts( 5 | g_ckpt='generator.ckpt', 6 | g_config='configs/config_gan.yaml', 7 | plm_ckpt='plm.ckpt', 8 | plm_config='configs/config_plm.yaml', 9 | adm_ckpt='adm.ckpt', 10 | adm_config='configs/config_adm.yaml', 11 | symbol_table='/root/autodl-tmp/megatts2/data/ds/unique_text_tokens.k2symbols' 12 | ) 13 | 14 | megatts.eval() 15 | 16 | megatts( 17 | '/root/autodl-tmp/megatts2/data/test', 18 | '八百标兵奔北坡北坡炮兵并排跑炮兵怕把标兵碰标兵怕碰炮兵炮黑化黑灰化肥灰会挥发发灰黑讳为黑灰花会回飞', 19 | ) -------------------------------------------------------------------------------- /models/megatts2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from modules.mrte import MRTE 6 | from modules.vqpe import VQProsodyEncoder 7 | from modules.convnet import ConvNet 8 | from modules.embedding import SinePositionalEmbedding, TokenEmbedding 9 | 10 | from modules.transformer import TransformerEncoder, TransformerEncoderLayer 11 | 12 | from einops import rearrange 13 | 14 | import yaml 15 | 16 | from utils.utils import instantiate_class 17 | import glob 18 | import librosa 19 | 20 | from modules.tokenizer import extract_mel_spec, TextTokenizer, HIFIGAN_SR, HIFIGAN_HOP_LENGTH 21 | from modules.datamodule import TokensCollector 22 | 23 | import numpy as np 24 | from modules.mrte import LengthRegulator 25 | from speechbrain.pretrained import HIFIGAN 26 | 27 | import torchaudio 28 | 29 | 30 | class MegaG(nn.Module): 31 | def __init__( 32 | self, 33 | mrte: MRTE, 34 | vqpe: VQProsodyEncoder, 35 | kernel_size: int = 5, 36 | activation: str = 'ReLU', 37 | hidden_size: int = 512, 38 | decoder_n_stack: int = 4, 39 | decoder_n_block: int = 2 40 | 41 | ): 42 | super(MegaG, self).__init__() 43 | 44 | self.mrte = mrte 45 | self.vqpe = vqpe 46 | self.decoder = ConvNet( 47 | in_channels=mrte.hidden_size + vqpe.vq.dimension, 48 | out_channels=mrte.mel_bins, 49 | hidden_size=hidden_size, 50 | n_stacks=decoder_n_stack, 51 | n_blocks=decoder_n_block, 52 | kernel_size=kernel_size, 53 | activation=activation, 54 | ) 55 | 56 | def forward( 57 | self, 58 | duration_tokens: torch.Tensor, # (B, T) 59 | phone: torch.Tensor, # (B, T) 60 | phone_lens: torch.Tensor, # (B,) 61 | mel_mrte: torch.Tensor, # (B, T, mel_bins) 62 | mel_vqpe: torch.Tensor, # (B, T, mel_bins) 63 | ): 64 | zq, commit_loss, vq_loss, _ = self.vqpe(mel_vqpe) 65 | x = self.mrte(duration_tokens, phone, phone_lens, mel_mrte) 66 | 67 | x = torch.cat([x, zq], dim=-1) 68 | 69 | x = rearrange(x, 'B T D -> B D T') 70 | x = self.decoder(x) 71 | x = rearrange(x, 'B D T -> B T D') 72 | 73 | return x, commit_loss, vq_loss 74 | 75 | def s2_latent( 76 | self, 77 | phone: torch.Tensor, # (B, T) 78 | phone_lens: torch.Tensor, # (B,) 79 | mel_mrte: torch.Tensor, # (B, T, mel_bins) 80 | mel_vqpe: torch.Tensor, 
# (B, T, mel_bins) 81 | ): 82 | _, _, _, codes = self.vqpe(mel_vqpe) 83 | x = self.mrte.tc_latent(phone, phone_lens, mel_mrte) 84 | return x, codes 85 | 86 | @classmethod 87 | def from_hparams(self, config_path: str) -> "MegaG": 88 | 89 | with open(config_path, "r") as f: 90 | config = yaml.safe_load(f) 91 | 92 | G_config = init = config['model']['G'] 93 | 94 | mrte = instantiate_class( 95 | args=(), init=G_config['init_args']['mrte']) 96 | vqpe = instantiate_class( 97 | args=(), init=G_config['init_args']['vqpe']) 98 | 99 | G_config['init_args']['mrte'] = mrte 100 | G_config['init_args']['vqpe'] = vqpe 101 | 102 | G = instantiate_class(args=(), init=G_config) 103 | 104 | return G 105 | 106 | @classmethod 107 | def from_pretrained(self, ckpt: str, config: str) -> "MegaG": 108 | 109 | G = MegaG.from_hparams(config) 110 | 111 | state_dict = {} 112 | for k, v in torch.load(ckpt)['state_dict'].items(): 113 | if k.startswith('G.'): 114 | state_dict[k[2:]] = v 115 | 116 | G.load_state_dict(state_dict, strict=True) 117 | return G 118 | 119 | 120 | class MegaPLM(nn.Module): 121 | def __init__( 122 | self, 123 | n_layers: int = 12, 124 | n_heads: int = 16, 125 | vq_dim: int = 512, 126 | tc_latent_dim: int = 512, 127 | vq_bins: int = 1024, 128 | dropout: float = 0.1, 129 | ): 130 | super(MegaPLM, self).__init__() 131 | d_model = vq_dim + tc_latent_dim 132 | self.plm = TransformerEncoder( 133 | TransformerEncoderLayer( 134 | dim=d_model, 135 | ff_dim=d_model * 4, 136 | n_heads=n_heads, 137 | dropout=dropout, 138 | conv_ff=False, 139 | ), 140 | num_layers=n_layers, 141 | ) 142 | 143 | self.predict_layer = nn.Linear(d_model, vq_bins, bias=False) 144 | 145 | self.pos = SinePositionalEmbedding(d_model) 146 | self.pc_embedding = nn.Embedding(vq_bins + 2, vq_dim) 147 | 148 | def forward( 149 | self, 150 | tc_latent: torch.Tensor, # (B, T, D) 151 | p_codes: torch.Tensor, # (B, T) 152 | lens: torch.Tensor, # (B,) 153 | ): 154 | pc_emb = self.pc_embedding(p_codes[:, :-1]) 155 | x_emb = torch.cat([tc_latent, pc_emb], dim=-1) 156 | x_pos = self.pos(x_emb) 157 | 158 | x = self.plm(x_pos, lens, causal=True) 159 | logits = self.predict_layer(x) 160 | 161 | target = p_codes[:, 1:] 162 | 163 | return logits, target 164 | 165 | def infer( 166 | self, 167 | tc_latent: torch.Tensor, # (B, T, D) 168 | ): 169 | T = tc_latent.shape[1] 170 | p_code = torch.Tensor([1024]).to( 171 | tc_latent.device).type(torch.int64).unsqueeze(0) 172 | for t in range(T): 173 | pc_emb = self.pc_embedding(p_code) 174 | x_emb = torch.cat([tc_latent[:, 0:t+1, :], pc_emb], dim=-1) 175 | x_pos = self.pos(x_emb) 176 | 177 | x = self.plm(x_pos) 178 | logits = self.predict_layer(x)[:, -1:, :] 179 | p_code = torch.cat([p_code, logits.argmax(dim=-1)], dim=1) 180 | 181 | return p_code[:, 1:] 182 | 183 | @classmethod 184 | def from_pretrained(cls, ckpt: str, config: str) -> "MegaPLM": 185 | 186 | with open(config, "r") as f: 187 | config = yaml.safe_load(f) 188 | 189 | plm_config = config['model']['plm'] 190 | plm = instantiate_class(args=(), init=plm_config) 191 | 192 | state_dict = {} 193 | for k, v in torch.load(ckpt)['state_dict'].items(): 194 | if k.startswith('plm.'): 195 | state_dict[k[4:]] = v 196 | 197 | plm.load_state_dict(state_dict, strict=True) 198 | return plm 199 | 200 | 201 | class MegaADM(nn.Module): 202 | def __init__( 203 | self, 204 | n_layers: int = 8, 205 | n_heads: int = 8, 206 | emb_dim: int = 256, 207 | tc_latent_dim: int = 512, 208 | tc_emb_dim: int = 256, 209 | dropout: float = 0.1, 210 | max_duration_token: int = 256, 211 | 
): 212 | super(MegaADM, self).__init__() 213 | 214 | d_model = emb_dim + tc_emb_dim 215 | self.adm = TransformerEncoder( 216 | TransformerEncoderLayer( 217 | dim=d_model, 218 | ff_dim=emb_dim * 4, 219 | n_heads=n_heads, 220 | dropout=dropout, 221 | conv_ff=False, 222 | ), 223 | num_layers=n_layers, 224 | ) 225 | 226 | self.dt_linear_emb = nn.Linear(1, emb_dim, bias=False) 227 | self.tc_linear_emb = nn.Linear(tc_latent_dim, tc_emb_dim, bias=False) 228 | self.pos_emb = SinePositionalEmbedding(d_model) 229 | self.predict_layer = nn.Linear(d_model, 1, bias=False) 230 | 231 | self.max_duration_token = max_duration_token 232 | 233 | def forward( 234 | self, 235 | tc_latents: torch.Tensor, # (B, T, D) 236 | duration_tokens: torch.Tensor, # (B, T) 237 | lens: torch.Tensor, # (B,) 238 | ): 239 | dt_emb = self.dt_linear_emb(duration_tokens[:, :-1]) 240 | tc_emb = self.tc_linear_emb(tc_latents) 241 | x_emb = torch.cat([tc_emb, dt_emb], dim=-1) 242 | x_pos = self.pos_emb(x_emb) 243 | 244 | x = self.adm(x_pos, lens, causal=True) 245 | duration_tokens_predict = self.predict_layer(x)[..., 0] 246 | 247 | target = duration_tokens[:, 1:, 0] 248 | 249 | # fill padding with 0 250 | # max_len = duration_tokens.size(1) - 1 251 | # seq_range = torch.arange(0, max_len, device=duration_tokens_predict.device) 252 | # expaned_lengths = seq_range.unsqueeze(0).expand(lens.size(0), max_len) 253 | # mask = expaned_lengths >= lens.unsqueeze(-1) 254 | # duration_tokens_predict = duration_tokens_predict.masked_fill(mask, 0) 255 | return duration_tokens_predict, target 256 | 257 | def infer( 258 | self, 259 | tc_latents: torch.Tensor, # (B, T, D) 260 | ): 261 | T = tc_latents.shape[1] 262 | p_code = torch.Tensor([0]).to( 263 | tc_latents.device).unsqueeze(0).unsqueeze(1) 264 | for t in range(T): 265 | dt_emb = self.dt_linear_emb(p_code) 266 | tc_emb = self.tc_linear_emb(tc_latents[:, 0:t+1, :]) 267 | 268 | x_emb = torch.cat([tc_emb, dt_emb], dim=-1) 269 | x_pos = self.pos_emb(x_emb) 270 | 271 | x = self.adm(x_pos) 272 | dt_predict = self.predict_layer(x)[:, -1:, :] 273 | p_code = torch.cat([p_code, dt_predict], dim=1) 274 | 275 | return (p_code[:, 1:, :] + 0.5).to(torch.int32).clamp(1, 128) 276 | 277 | @classmethod 278 | def from_pretrained(self, ckpt: str, config: str) -> "MegaADM": 279 | 280 | with open(config, "r") as f: 281 | config = yaml.safe_load(f) 282 | 283 | adm_config = config['model']['adm'] 284 | adm = instantiate_class(args=(), init=adm_config) 285 | 286 | state_dict = {} 287 | for k, v in torch.load(ckpt)['state_dict'].items(): 288 | if k.startswith('adm.'): 289 | state_dict[k[4:]] = v 290 | 291 | adm.load_state_dict(state_dict, strict=True) 292 | return adm 293 | 294 | 295 | class Megatts(nn.Module): 296 | def __init__( 297 | self, 298 | g_ckpt: str, 299 | g_config: str, 300 | plm_ckpt: str, 301 | plm_config: str, 302 | adm_ckpt: str, 303 | adm_config: str, 304 | symbol_table: str 305 | ): 306 | super(Megatts, self).__init__() 307 | 308 | self.generator = MegaG.from_pretrained(g_ckpt, g_config) 309 | self.generator.eval() 310 | self.plm = MegaPLM.from_pretrained(plm_ckpt, plm_config) 311 | self.plm.eval() 312 | self.adm = MegaADM.from_pretrained(adm_ckpt, adm_config) 313 | self.adm.eval() 314 | 315 | self.tt = TextTokenizer() 316 | self.ttc = TokensCollector(symbol_table) 317 | 318 | self.lr = LengthRegulator( 319 | HIFIGAN_HOP_LENGTH, 16000, (HIFIGAN_HOP_LENGTH / HIFIGAN_SR * 1000)) 320 | 321 | self.hifi_gan = HIFIGAN.from_hparams( 322 | source="speechbrain/tts-hifigan-libritts-16kHz") 323 | 
self.hifi_gan.eval() 324 | 325 | def forward( 326 | self, 327 | wavs_dir: str, 328 | text: str, 329 | ): 330 | mels_prompt = None 331 | # Make mrte mels 332 | wavs = glob.glob(f'{wavs_dir}/*.wav') 333 | mels = torch.empty(0) 334 | for wav in wavs: 335 | y = librosa.load(wav, sr=HIFIGAN_SR)[0] 336 | y = librosa.util.normalize(y) 337 | # y = librosa.effects.trim(y, top_db=20)[0] 338 | y = torch.from_numpy(y) 339 | 340 | mel_spec = extract_mel_spec(y).transpose(0, 1) 341 | mels = torch.cat([mels, mel_spec], dim=0) 342 | 343 | if mels_prompt is None: 344 | mels_prompt = mel_spec 345 | 346 | mels = mels.unsqueeze(0) 347 | 348 | # G2P 349 | phone_tokens = self.ttc.phone2token( 350 | self.tt.tokenize_lty(self.tt.tokenize(text))) 351 | phone_tokens = phone_tokens.unsqueeze(0) 352 | 353 | with torch.no_grad(): 354 | tc_latent = self.generator.mrte.tc_latent(phone_tokens, mels) 355 | dt = self.adm.infer(tc_latent)[..., 0] 356 | tc_latent_expand = self.lr(tc_latent, dt) 357 | tc_latent = F.max_pool1d(tc_latent_expand.transpose( 358 | 1, 2), 8, ceil_mode=True).transpose(1, 2) 359 | p_codes = self.plm.infer(tc_latent) 360 | 361 | zq = self.generator.vqpe.vq.decode(p_codes.unsqueeze(0)) 362 | zq = rearrange( 363 | zq, "B D T -> B T D").unsqueeze(2).contiguous().expand(-1, -1, 8, -1) 364 | zq = rearrange(zq, "B T S D -> B (T S) D") 365 | x = torch.cat( 366 | [tc_latent_expand, zq[:, :tc_latent_expand.shape[1], :]], dim=-1) 367 | x = rearrange(x, 'B T D -> B D T') 368 | x = self.generator.decoder(x) 369 | 370 | audio = self.hifi_gan.decode_batch(x.cpu()) 371 | audio_prompt = self.hifi_gan.decode_batch( 372 | mels_prompt.unsqueeze(0).transpose(1, 2).cpu()) 373 | audio = torch.cat([audio_prompt, audio], dim=-1) 374 | 375 | torchaudio.save('test.wav', audio[0], HIFIGAN_SR) 376 | -------------------------------------------------------------------------------- /models/trainer.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | 3 | import torch 4 | import torchaudio 5 | import torch.nn.functional as F 6 | 7 | import transformers 8 | 9 | import numpy as np 10 | import math 11 | 12 | from .megatts2 import MegaG, MegaPLM, MegaADM 13 | from modules.dscrm import Discriminator 14 | from modules.tokenizer import HIFIGAN_SR 15 | 16 | from utils.utils import plot_spectrogram_to_numpy 17 | 18 | from speechbrain.pretrained import HIFIGAN 19 | 20 | from torchmetrics.classification import MulticlassAccuracy 21 | 22 | class MegaGANTrainer(pl.LightningModule): 23 | def __init__( 24 | self, 25 | G: MegaG, 26 | D: Discriminator, 27 | initial_learning_rate: float, 28 | warmup_steps: float = 200, 29 | G_commit_loss_coeff: float = 10, 30 | G_vq_loss_coeff: float = 10, 31 | G_adv_loss_coeff: float = 1.0, 32 | 33 | train_dtype: str = "float32", 34 | **kwargs 35 | ): 36 | 37 | super().__init__() 38 | self.automatic_optimization = False 39 | self.save_hyperparameters(ignore=['G', 'D']) 40 | self.G = G 41 | self.D = D 42 | self.validation_step_outputs = [] 43 | 44 | if self.hparams.train_dtype == "float32": 45 | self.train_dtype = torch.float32 46 | elif self.hparams.train_dtype == "bfloat16": 47 | self.train_dtype = torch.bfloat16 48 | print("Using bfloat16") 49 | 50 | def configure_optimizers(self): 51 | D_params = [ 52 | {"params": self.D.parameters()} 53 | ] 54 | G_params = [ 55 | {"params": self.G.parameters()} 56 | ] 57 | 58 | D_opt = torch.optim.AdamW( 59 | D_params, lr=self.hparams.initial_learning_rate) 60 | G_opt = torch.optim.AdamW( 61 | G_params, 
lr=self.hparams.initial_learning_rate) 62 | 63 | D_sch = transformers.get_cosine_schedule_with_warmup( 64 | D_opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.trainer.max_steps // 2 65 | ) 66 | G_sch = transformers.get_cosine_schedule_with_warmup( 67 | G_opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.trainer.max_steps // 2 68 | ) 69 | 70 | return ( 71 | [D_opt, G_opt], 72 | [{"scheduler": D_sch, "interval": "step"}, { 73 | "scheduler": G_sch, "interval": "step"}], 74 | ) 75 | 76 | def forward(self, batch: dict): 77 | y_hat, commit_loss, vq_loss = self.G( 78 | duration_tokens=batch["duration_tokens"], 79 | phone=batch["phone_tokens"], 80 | phone_lens=batch["tokens_lens"], 81 | mel_mrte=batch["mel_timbres"], 82 | mel_vqpe=batch["mel_targets"] 83 | ) 84 | 85 | return y_hat, commit_loss, vq_loss 86 | 87 | def training_step(self, batch: dict, batch_idx, **kwargs): 88 | opt1, opt2 = self.optimizers() 89 | sch1, sch2 = self.lr_schedulers() 90 | 91 | with torch.cuda.amp.autocast(dtype=self.train_dtype): 92 | self.G.train() 93 | y_hat, G_loss_commit, G_loss_vq = self(batch) 94 | 95 | # Train discriminator 96 | y = batch["mel_targets"] 97 | D_outputs = self.D(y) 98 | D_loss_real = 0.5 * torch.mean((D_outputs["y"] - 1) ** 2) 99 | 100 | D_outputs = self.D(y_hat.detach()) 101 | D_loss_fake = 0.5 * torch.mean(D_outputs["y"] ** 2) 102 | 103 | D_loss_total = D_loss_real + D_loss_fake 104 | 105 | opt1.zero_grad() 106 | self.manual_backward(D_loss_total) 107 | opt1.step() 108 | sch1.step() 109 | 110 | # Train generator 111 | G_loss_re = F.l1_loss(y, y_hat) 112 | 113 | G_loss = G_loss_re + G_loss_commit * self.hparams.G_commit_loss_coeff + \ 114 | G_loss_vq * self.hparams.G_vq_loss_coeff 115 | 116 | G_loss_adv = 0.5 * torch.mean((self.D(y_hat)["y"] - 1) ** 2) 117 | G_loss_total = G_loss_adv * self.hparams.G_adv_loss_coeff + G_loss 118 | 119 | opt2.zero_grad() 120 | self.manual_backward(G_loss_total) 121 | opt2.step() 122 | sch2.step() 123 | 124 | if batch_idx % 5 == 0: 125 | self.log("train/D_loss_total", D_loss_total, prog_bar=True) 126 | self.log("train/D_loss_real", D_loss_real) 127 | self.log("train/D_loss_fake", D_loss_fake) 128 | 129 | self.log("train/G_loss_total", G_loss_total, prog_bar=True) 130 | self.log("train/G_loss_adv", G_loss_adv) 131 | self.log("train/G_loss", G_loss) 132 | self.log("train/G_loss_commit", G_loss_commit) 133 | self.log("train/G_loss_vq", G_loss_vq) 134 | self.log("train/G_loss_re", G_loss_re) 135 | 136 | def on_validation_epoch_start(self): 137 | pass 138 | 139 | def validation_step(self, batch: torch.Tensor, **kwargs): 140 | 141 | y = batch["mel_targets"] 142 | with torch.no_grad(): 143 | self.G.eval() 144 | y_hat, _, _ = self(batch) 145 | 146 | loss_re = F.l1_loss(y, y_hat) 147 | 148 | self.validation_step_outputs.append({ 149 | "y": y[0], 150 | "y_hat": y_hat[0], 151 | "loss_re": loss_re, 152 | }) 153 | 154 | def on_validation_epoch_end(self): 155 | outputs = self.validation_step_outputs 156 | if self.global_rank == 0: 157 | 158 | mel = outputs[0]["y"].transpose(0, 1) 159 | mel_hat = outputs[0]["y_hat"].transpose(0, 1) 160 | 161 | self.logger.experiment.add_image( 162 | "val/mel_analyse", 163 | plot_spectrogram_to_numpy( 164 | mel.data.cpu().numpy(), mel_hat.data.cpu().numpy()), 165 | self.global_step, 166 | dataformats="HWC", 167 | ) 168 | 169 | with torch.no_grad(): 170 | hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-libritts-16kHz") 171 | hifi_gan.eval() 172 | 173 | audio_target = 
hifi_gan.decode_batch(mel.unsqueeze(0).cpu()) 174 | audio_hat = hifi_gan.decode_batch(mel_hat.unsqueeze(0).cpu()) 175 | 176 | self.logger.experiment.add_audio( 177 | "val/audio_target", 178 | audio_target[0], 179 | self.global_step, 180 | sample_rate=HIFIGAN_SR, 181 | ) 182 | 183 | self.logger.experiment.add_audio( 184 | "val/audio_hat", 185 | audio_hat[0], 186 | self.global_step, 187 | sample_rate=HIFIGAN_SR, 188 | ) 189 | 190 | loss_re = torch.mean(torch.stack( 191 | [x["loss_re"] for x in outputs])) 192 | 193 | self.log("val/loss_re", loss_re, sync_dist=True) 194 | 195 | self.validation_step_outputs = [] 196 | 197 | class MegaPLMTrainer(pl.LightningModule): 198 | def __init__( 199 | self, 200 | plm: MegaPLM, 201 | initial_learning_rate: float, 202 | warmup_steps: float = 200, 203 | train_dtype: str = "float32", 204 | **kwargs 205 | ): 206 | super().__init__() 207 | self.save_hyperparameters(ignore=['plm']) 208 | self.validation_step_outputs = [] 209 | 210 | if self.hparams.train_dtype == "float32": 211 | self.train_dtype = torch.float32 212 | elif self.hparams.train_dtype == "bfloat16": 213 | self.train_dtype = torch.bfloat16 214 | print("Using bfloat16") 215 | 216 | self.plm = plm 217 | 218 | self.accuracy_metric = MulticlassAccuracy( 219 | 1024, 220 | top_k=10, 221 | average="micro", 222 | multidim_average="global", 223 | ignore_index=1024 + 1 224 | ) 225 | 226 | def configure_optimizers(self): 227 | plm_params = [ 228 | {"params": self.plm.parameters()} 229 | ] 230 | 231 | plm_opt = torch.optim.AdamW( 232 | plm_params, lr=self.hparams.initial_learning_rate) 233 | 234 | plm_sch = transformers.get_cosine_schedule_with_warmup( 235 | plm_opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.trainer.max_steps 236 | ) 237 | 238 | return ( 239 | [plm_opt], 240 | [{"scheduler": plm_sch, "interval": "step"}], 241 | ) 242 | 243 | def forward(self, batch: dict): 244 | logits, y = self.plm( 245 | tc_latent=batch["tc_latents"], 246 | p_codes=batch["p_codes"], 247 | lens=batch["lens"] 248 | ) 249 | 250 | logits = logits.transpose(1, 2) 251 | 252 | # ignore padding 253 | loss = F.cross_entropy(logits, y, reduction="sum", ignore_index=1024 + 1) 254 | loss_log = loss / y.shape[0] / y.shape[1] 255 | ac10 = self.accuracy_metric(logits.detach(), y) 256 | 257 | return loss, loss_log, ac10 258 | 259 | def training_step(self, batch: dict, batch_idx, **kwargs): 260 | with torch.cuda.amp.autocast(dtype=self.train_dtype): 261 | self.plm.train() 262 | loss, loss_log, ac10 = self(batch) 263 | 264 | if batch_idx % 5 == 0: 265 | self.log("train/ac10", ac10, prog_bar=True) 266 | self.log("train/loss", loss_log, prog_bar=True) 267 | 268 | return loss 269 | 270 | def on_validation_epoch_start(self): 271 | pass 272 | 273 | def validation_step(self, batch: torch.Tensor, **kwargs): 274 | with torch.no_grad(): 275 | self.plm.eval() 276 | _, loss_log, ac10 = self(batch) 277 | 278 | self.validation_step_outputs.append({ 279 | "loss_log": loss_log, 280 | "ac10": ac10 281 | }) 282 | 283 | def on_validation_epoch_end(self): 284 | outputs = self.validation_step_outputs 285 | if self.global_rank == 0: 286 | loss_log = torch.mean(torch.stack( 287 | [x["loss_log"] for x in outputs])) 288 | ac10 = torch.mean(torch.stack( 289 | [x["ac10"] for x in outputs])) 290 | 291 | self.log("val/loss", loss_log, sync_dist=True) 292 | self.log("val/ac10", ac10, sync_dist=True) 293 | 294 | self.validation_step_outputs = [] 295 | 296 | class MegaADMTrainer(pl.LightningModule): 297 | def __init__( 298 | self, 299 | adm: 
MegaADM, 300 | initial_learning_rate: float, 301 | warmup_steps: float = 200, 302 | train_dtype: str = "float32", 303 | **kwargs 304 | ): 305 | super().__init__() 306 | self.save_hyperparameters(ignore=['adm']) 307 | self.validation_step_outputs = [] 308 | 309 | if self.hparams.train_dtype == "float32": 310 | self.train_dtype = torch.float32 311 | elif self.hparams.train_dtype == "bfloat16": 312 | self.train_dtype = torch.bfloat16 313 | print("Using bfloat16") 314 | 315 | self.adm = adm 316 | 317 | def configure_optimizers(self): 318 | adm_params = [ 319 | {"params": self.adm.parameters()} 320 | ] 321 | 322 | adm_opt = torch.optim.AdamW( 323 | adm_params, lr=self.hparams.initial_learning_rate) 324 | 325 | adm_sch = transformers.get_cosine_schedule_with_warmup( 326 | adm_opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.trainer.max_steps 327 | ) 328 | 329 | return ( 330 | [adm_opt], 331 | [{"scheduler": adm_sch, "interval": "step"}], 332 | ) 333 | 334 | def forward(self, batch: dict): 335 | duration_tokens_predict, target = self.adm( 336 | tc_latents=batch["tc_latents"], 337 | duration_tokens=batch["duration_tokens"], 338 | lens=batch["lens"] 339 | ) 340 | 341 | # ignore padding 342 | loss = F.mse_loss(duration_tokens_predict, target, reduction="sum") 343 | loss_log = loss / target.shape[0] / target.shape[1] 344 | 345 | return loss, loss_log 346 | 347 | def training_step(self, batch: dict, batch_idx, **kwargs): 348 | with torch.cuda.amp.autocast(dtype=self.train_dtype): 349 | self.adm.train() 350 | loss, loss_log = self(batch) 351 | 352 | if batch_idx % 5 == 0: 353 | self.log("train/loss", loss_log, prog_bar=True) 354 | 355 | return loss 356 | 357 | def on_validation_epoch_start(self): 358 | pass 359 | 360 | def validation_step(self, batch: torch.Tensor, **kwargs): 361 | with torch.no_grad(): 362 | self.adm.eval() 363 | _, loss_log = self(batch) 364 | 365 | self.validation_step_outputs.append({ 366 | "loss_log": loss_log 367 | }) 368 | 369 | def on_validation_epoch_end(self): 370 | outputs = self.validation_step_outputs 371 | if self.global_rank == 0: 372 | loss_log = torch.mean(torch.stack( 373 | [x["loss_log"] for x in outputs])) 374 | 375 | self.log("val/loss", loss_log, sync_dist=True) 376 | 377 | self.validation_step_outputs = [] -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LSimon95/megatts2/2ab81a1234791e809cdcfca12e3c7b3d0cc2a9f3/modules/__init__.py -------------------------------------------------------------------------------- /modules/convnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from einops import rearrange 5 | 6 | from typing import List 7 | 8 | 9 | class ConvBlock(nn.Module): 10 | def __init__(self, hidden_size, kernel_size, activation): 11 | super(ConvBlock, self).__init__() 12 | 13 | self.conv = nn.Conv1d( 14 | in_channels=hidden_size, 15 | out_channels=hidden_size, 16 | kernel_size=kernel_size, 17 | padding=(kernel_size - 1) // 2, 18 | ) 19 | self.norm = nn.LayerNorm(hidden_size) 20 | self.activation = getattr(nn, activation)() 21 | self.dropout = nn.Dropout(0.1) 22 | 23 | def forward(self, x): 24 | x = self.activation(x) 25 | x = self.dropout(x) 26 | x = self.conv(x) 27 | 28 | x = rearrange(x, "B D T -> B T D") 29 | x = self.norm(x) 30 | x = rearrange(x, "B T D -> B D T") 31 | return x 
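# ConvBlock keeps the (B, D, T) layout end to end: activation and dropout run
# before the same-padded Conv1d, and LayerNorm is applied over the channel
# dimension by transposing to (B, T, D) and back. ConvStack below simply chains
# n_blocks of these blocks, and ResidualBlockStack wraps each ConvStack in a
# skip connection (x = x + stack(x)).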
32 | 33 | 34 | class ConvStack(nn.Module): 35 | def __init__(self, hidden_size, n_blocks, kernel_size, activation): 36 | super(ConvStack, self).__init__() 37 | 38 | blocks = [] 39 | for i in range(n_blocks): 40 | blocks += [ 41 | ConvBlock( 42 | hidden_size=hidden_size, 43 | kernel_size=kernel_size, 44 | activation=activation, 45 | ) 46 | ] 47 | self.blocks = nn.Sequential(*blocks) 48 | 49 | def forward(self, x): 50 | return self.blocks(x) 51 | 52 | class ResidualBlockStack(nn.Module): 53 | def __init__(self, hidden_size, n_stacks, n_blocks, kernel_size, activation): 54 | super(ResidualBlockStack, self).__init__() 55 | 56 | self.conv_stacks = [] 57 | 58 | for i in range(n_stacks): 59 | self.conv_stacks += [ 60 | ConvStack( 61 | hidden_size=hidden_size, 62 | n_blocks=n_blocks, 63 | kernel_size=kernel_size, 64 | activation=activation, 65 | ) 66 | ] 67 | self.conv_stacks = nn.Sequential(*self.conv_stacks) 68 | 69 | def forward(self, x): 70 | for conv_stack in self.conv_stacks: 71 | x = x + conv_stack(x) 72 | return x 73 | 74 | class ConvNet(nn.Module): 75 | def __init__( 76 | self, 77 | in_channels: int, 78 | out_channels: int, 79 | hidden_size: int, 80 | n_stacks: int, 81 | n_blocks: int, 82 | kernel_size: int, 83 | activation: str, 84 | last_layer_avg_pooling: bool = False, 85 | ): 86 | super(ConvNet, self).__init__() 87 | 88 | self.first_layer = nn.Conv1d( 89 | in_channels=in_channels, 90 | out_channels=hidden_size, 91 | kernel_size=kernel_size, 92 | stride=1, 93 | padding=(kernel_size - 1) // 2, 94 | ) 95 | 96 | self.conv_stack = ResidualBlockStack( 97 | hidden_size=hidden_size, 98 | n_stacks=n_stacks, 99 | n_blocks=n_blocks, 100 | kernel_size=kernel_size, 101 | activation=activation, 102 | ) 103 | 104 | if last_layer_avg_pooling: 105 | self.last_layer = nn.AdaptiveAvgPool1d(1) 106 | else: 107 | self.last_layer = nn.Conv1d( 108 | in_channels=hidden_size, 109 | out_channels=out_channels, 110 | kernel_size=kernel_size, 111 | stride=1, 112 | padding=(kernel_size - 1) // 2, 113 | ) 114 | 115 | def forward(self, x): 116 | x = self.first_layer(x) 117 | x = self.conv_stack(x) 118 | x = self.last_layer(x) 119 | return x 120 | 121 | class ConvNetDoubleLayer(nn.Module): 122 | def __init__( 123 | self, 124 | hidden_size: int, 125 | n_stacks: int, 126 | n_blocks: int, 127 | middle_layer: nn.Module, 128 | kernel_size: int, 129 | activation: str, 130 | ): 131 | super(ConvNetDoubleLayer, self).__init__() 132 | self.conv_stack1 = ResidualBlockStack( 133 | hidden_size=hidden_size, 134 | n_stacks=n_stacks, 135 | n_blocks=n_blocks, 136 | kernel_size=kernel_size, 137 | activation=activation, 138 | ) 139 | 140 | self.middle_layer = middle_layer 141 | 142 | self.conv_stack2 = ResidualBlockStack( 143 | hidden_size=hidden_size, 144 | n_stacks=n_stacks, 145 | n_blocks=n_blocks, 146 | kernel_size=kernel_size, 147 | activation=activation, 148 | ) 149 | 150 | def forward(self, x): 151 | x = self.conv_stack1(x) 152 | x = self.middle_layer(x) 153 | x = self.conv_stack2(x) 154 | return x 155 | 156 | class ConvNetDouble(nn.Module): 157 | def __init__( 158 | self, 159 | in_channels: int, 160 | out_channels: int, 161 | hidden_size: int, 162 | n_layers: int, 163 | n_stacks: int, 164 | n_blocks: int, 165 | middle_layer: nn.Module, 166 | kernel_size: int, 167 | activation: str, 168 | ): 169 | super(ConvNetDouble, self).__init__() 170 | 171 | self.first_layer = first_conv = nn.Conv1d( 172 | in_channels=in_channels, 173 | out_channels=hidden_size, 174 | kernel_size=kernel_size, 175 | stride=1, 176 | padding=(kernel_size - 1) 
// 2, 177 | ) 178 | 179 | self.layers = [] 180 | for i in range(n_layers): 181 | self.layers += [ 182 | ConvNetDoubleLayer( 183 | hidden_size=hidden_size, 184 | n_stacks=n_stacks, 185 | n_blocks=n_blocks, 186 | middle_layer=middle_layer, 187 | kernel_size=kernel_size, 188 | activation=activation, 189 | ) 190 | ] 191 | 192 | self.layers = nn.Sequential(*self.layers) 193 | 194 | self.last_layer = nn.Conv1d( 195 | in_channels=hidden_size, 196 | out_channels=out_channels, 197 | kernel_size=kernel_size, 198 | stride=1, 199 | padding=(kernel_size - 1) // 2, 200 | ) 201 | 202 | def forward(self, x): 203 | x = self.first_layer(x) 204 | 205 | x_out = self.layers[0](x) 206 | for layer in self.layers[1:]: 207 | x_out = x_out + layer(x) 208 | 209 | x = self.last_layer(x_out) 210 | return x 211 | 212 | def test(): 213 | x = torch.rand(2, 128, 240) 214 | convnet = ConvNet( 215 | in_channels=128, 216 | out_channels=128, 217 | hidden_size=128, 218 | n_stacks=2, 219 | n_blocks=2, 220 | kernel_size=3, 221 | activation="ReLU", 222 | ) 223 | y = convnet(x) 224 | print(y.shape) 225 | 226 | convnet = ConvNetDouble( 227 | in_channels=128, 228 | out_channels=128, 229 | hidden_size=128, 230 | n_layers=2, 231 | n_stacks=2, 232 | n_blocks=2, 233 | middle_layer=nn.MaxPool1d( 234 | kernel_size=8, 235 | stride=8, 236 | ), 237 | kernel_size=3, 238 | activation="ReLU", 239 | ) 240 | y = convnet(x) 241 | print(y.shape) 242 | 243 | -------------------------------------------------------------------------------- /modules/datamodule.py: -------------------------------------------------------------------------------- 1 | import lightning.pytorch as pl 2 | 3 | import random 4 | from concurrent.futures import ThreadPoolExecutor 5 | 6 | from lhotse import CutSet, load_manifest 7 | from lhotse.dataset.collation import collate_features 8 | from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler 9 | from lhotse.dataset.input_strategies import ( 10 | _get_executor 11 | ) 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.utils.data import DataLoader 16 | 17 | from typing import Dict, Tuple, List, Type 18 | 19 | from utils.symbol_table import SymbolTable 20 | 21 | from tqdm.auto import tqdm 22 | 23 | import numpy as np 24 | 25 | from modules.mrte import LengthRegulator 26 | 27 | from .tokenizer import HIFIGAN_SR, HIFIGAN_HOP_LENGTH 28 | 29 | 30 | class TokensCollector(): 31 | def __init__(self, symbols_table: str) -> None: 32 | unique_tokens = SymbolTable.from_file(symbols_table).symbols 33 | self.token2idx = {token: idx for idx, 34 | token in enumerate(unique_tokens)} 35 | 36 | def __call__(self, cuts: CutSet) -> (List, torch.Tensor): 37 | 38 | phone_tokens_list = [] 39 | duration_tokens_list = [] 40 | lens = [] 41 | for cut in cuts: 42 | phone_tokens = cut.supervisions[0].custom['phone_tokens'] 43 | duration_tokens = cut.supervisions[0].custom['duration_tokens'] 44 | 45 | phone_tokens_list.append(self.phone2token(phone_tokens)) 46 | duration_tokens_list.append(torch.Tensor(duration_tokens)) 47 | 48 | lens.append(len(phone_tokens)) 49 | 50 | max_len = max(lens) 51 | phone_tokens_list_padded = [] 52 | duration_tokens_list_padded = [] 53 | for i in range(len(phone_tokens_list)): 54 | phone_tokens_list_padded.append(F.pad( 55 | phone_tokens_list[i], (0, max_len - lens[i]), mode='constant', value=0)) 56 | duration_tokens_list_padded.append(F.pad( 57 | duration_tokens_list[i], (0, max_len - lens[i]), mode='constant', value=0)) 58 | 59 | phone_tokens = torch.stack(phone_tokens_list_padded) 60 | 
duration_tokens = torch.stack( 61 | duration_tokens_list_padded).type(torch.int64) 62 | lens = torch.Tensor(lens).to(dtype=torch.int32) 63 | 64 | return phone_tokens, duration_tokens, lens 65 | 66 | def phone2token(self, phone: List) -> int: 67 | return torch.Tensor( 68 | [self.token2idx[token] for token in phone] 69 | ).type(torch.int64) 70 | 71 | 72 | class TTSDataset(torch.utils.data.Dataset): 73 | def __init__( 74 | self, 75 | spk2cuts: Dict, 76 | ds_path: str, 77 | n_same_spk_samples: int = 10 78 | ): 79 | super().__init__() 80 | self.tokens_collector = TokensCollector( 81 | f'{ds_path}/unique_text_tokens.k2symbols') 82 | self.spk2cuts = spk2cuts 83 | self.n_same_spk_samples = n_same_spk_samples 84 | 85 | def __getitem__(self, cuts: CutSet) -> Dict: 86 | phone_tokens, duration_tokens, tokens_lens = self.tokens_collector( 87 | cuts) 88 | mel_targets, mel_target_lens = collate_features( 89 | cuts, 90 | executor=_get_executor(8, executor_type=ThreadPoolExecutor),) 91 | 92 | # align duration token and mel_target_lens 93 | for i in range(mel_target_lens.shape[0]): 94 | sum_duration = torch.sum(duration_tokens[i]) 95 | assert sum_duration <= mel_target_lens[i] 96 | if sum_duration < mel_target_lens[i]: 97 | mel_target_lens[i] = sum_duration 98 | 99 | max_len = max(mel_target_lens) 100 | mel_targets = mel_targets[:, :max_len, :] 101 | 102 | mel_timbres_list = [] 103 | mel_timbre_lens_list = [] 104 | n_sample = random.randint(2, self.n_same_spk_samples) 105 | for cut in cuts: 106 | same_spk_cuts = self.spk2cuts[cut.supervisions[0].speaker] 107 | same_spk_cuts = same_spk_cuts.sample( 108 | n_cuts=min(n_sample, len(same_spk_cuts))) 109 | 110 | mel_timbres_same_spk, mel_timbre_lens_same_spk = collate_features( 111 | same_spk_cuts, 112 | executor=_get_executor(8, executor_type=ThreadPoolExecutor),) 113 | 114 | mel_timbre = mel_timbres_same_spk[0, :mel_timbre_lens_same_spk[0]] 115 | for i in range(1, mel_timbres_same_spk.shape[0]): 116 | mel_timbre = torch.cat( 117 | [mel_timbre, mel_timbres_same_spk[i, :mel_timbre_lens_same_spk[i]]], dim=0) 118 | mel_timbres_list.append(mel_timbre) 119 | mel_timbre_lens_list.append(mel_timbre.shape[0]) 120 | 121 | mel_timbres_list_cutted = [] 122 | min_mel_timbres_len = min(mel_timbre_lens_list) 123 | for mel_timbre in mel_timbres_list: 124 | mel_timbres_list_cutted.append(mel_timbre[:min_mel_timbres_len, :]) 125 | 126 | mel_timbres = torch.stack(mel_timbres_list_cutted).type(torch.float32) 127 | 128 | batch = { 129 | "phone_tokens": phone_tokens, 130 | "duration_tokens": duration_tokens, 131 | "tokens_lens": tokens_lens, 132 | "mel_targets": mel_targets, 133 | "mel_target_lens": mel_target_lens, 134 | "mel_timbres": mel_timbres, 135 | } 136 | 137 | return batch 138 | 139 | 140 | class MegaPLMDataset(torch.utils.data.Dataset): 141 | def __init__( 142 | self, 143 | spk2cuts: Dict, 144 | ds_path: str, 145 | lr: LengthRegulator, 146 | n_same_spk_samples: int = 10, 147 | vq_bins: int = 1024, 148 | 149 | ): 150 | super().__init__() 151 | self.spk2cuts = spk2cuts 152 | 153 | self.bos = torch.Tensor([vq_bins]) 154 | self.eos = torch.Tensor([vq_bins + 1]) 155 | 156 | self.n_same_spk_samples = n_same_spk_samples 157 | self.lr = lr 158 | 159 | self.tokens_collector = TokensCollector( 160 | f'{ds_path}/unique_text_tokens.k2symbols') 161 | self.ds_path = ds_path 162 | 163 | def read_latent(self, cut) -> Dict: 164 | 165 | id = cut.recording_id 166 | spk = cut.supervisions[0].speaker 167 | 168 | latents = np.load(f'{self.ds_path}/latents/{spk}/{id}.npy', 169 | 
allow_pickle=True).item() 170 | tc_latent = torch.from_numpy(latents['tc_latent']) 171 | duration_tokens = torch.Tensor( 172 | cut.supervisions[0].custom['duration_tokens']).unsqueeze(0).to(dtype=torch.int32) 173 | tc_latent = self.lr(tc_latent, duration_tokens) 174 | p_code = torch.from_numpy(latents['p_code']) 175 | tc_latent = F.max_pool1d(tc_latent.transpose( 176 | 1, 2), 8, ceil_mode=True).transpose(1, 2) 177 | 178 | return tc_latent, p_code 179 | 180 | def __getitem__(self, cuts_sample: CutSet) -> Dict: 181 | 182 | p_code_spks = [] 183 | tc_latent_spks = [] 184 | lens = [] 185 | 186 | for cut in cuts_sample: 187 | 188 | spk = cut.supervisions[0].speaker 189 | 190 | same_spk_cuts = self.spk2cuts[spk] 191 | same_spk_cuts = same_spk_cuts.sample( 192 | n_cuts=self.n_same_spk_samples) 193 | 194 | tc_latent, p_code = self.read_latent(cut) 195 | 196 | tc_latent_spk = tc_latent[0, ...] 197 | p_code_spk = torch.cat([p_code[0, 0, :]]) 198 | 199 | assert tc_latent_spk.shape[0] == p_code_spk.shape[0] 200 | 201 | for cut_spk in same_spk_cuts: 202 | tc_latent_spk_cat, p_code_spk_cat = self.read_latent(cut_spk) 203 | 204 | tc_latent_spk = torch.cat( 205 | [tc_latent_spk_cat[0, ...], tc_latent_spk], dim=0) 206 | p_code_spk = torch.cat( 207 | [p_code_spk_cat[0, 0, :], p_code_spk], dim=0) 208 | 209 | assert tc_latent_spk.shape[0] == p_code_spk.shape[0] 210 | assert torch.max(p_code_spk) < 1024 211 | 212 | p_code_spk = torch.cat([self.bos, p_code_spk], dim=0) 213 | lens.append(p_code_spk.shape[0] - 1) 214 | 215 | p_code_spks.append(p_code_spk) 216 | tc_latent_spks.append(tc_latent_spk) 217 | 218 | max_len = max(lens) 219 | 220 | # pad 221 | p_code_spks_padded = [] 222 | tc_latent_spks_padded = [] 223 | 224 | for i in range(len(p_code_spks)): 225 | p_code_spks_padded.append(F.pad( 226 | p_code_spks[i], (0, max_len - lens[i]), mode='constant', value=self.eos.item())) 227 | tc_latent_spks_padded.append(F.pad( 228 | tc_latent_spks[i], (0, 0, 0, max_len - lens[i]), mode='constant', value=0)) 229 | 230 | p_code_spks = torch.stack(p_code_spks_padded).type(torch.int64) 231 | tc_latent_spks = torch.stack(tc_latent_spks_padded).type(torch.float32) 232 | lens = torch.Tensor(lens).to(dtype=torch.int32) 233 | 234 | batch = { 235 | "p_codes": p_code_spks, 236 | "tc_latents": tc_latent_spks, 237 | "lens": lens, 238 | } 239 | 240 | return batch 241 | 242 | 243 | class MegaADMDataset(torch.utils.data.Dataset): 244 | def __init__(self, ds_path: str): 245 | self.tokens_collector = TokensCollector( 246 | f'{ds_path}/unique_text_tokens.k2symbols') 247 | self.ds_path = ds_path 248 | self.max_duration_token = 128 249 | 250 | def __getitem__(self, cuts_sample: CutSet) -> Dict: 251 | duration_token_list = [] 252 | tc_latent_list = [] 253 | lens = [] 254 | for cut in cuts_sample: 255 | spk = cut.supervisions[0].speaker 256 | id = cut.recording_id 257 | 258 | duration_tokens = cut.supervisions[0].custom['duration_tokens'] 259 | if np.max(duration_tokens) >= self.max_duration_token: 260 | continue 261 | 262 | duration_tokens = torch.Tensor( 263 | duration_tokens).to(dtype=torch.int32) 264 | 265 | latents = np.load(f'{self.ds_path}/latents/{spk}/{id}.npy', 266 | allow_pickle=True).item() 267 | tc_latent = torch.from_numpy(latents['tc_latent'])[0] 268 | assert tc_latent.shape[0] == duration_tokens.shape[0] 269 | 270 | duration_token_list.append(duration_tokens) 271 | tc_latent_list.append(tc_latent) 272 | lens.append(duration_tokens.shape[0]) 273 | 274 | max_len = max(lens) 275 | 276 | # pad 277 | duration_token_list_padded = [] 
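# duration tokens get a single leading 0 (the value fed at the first
# autoregressive step of MegaADM, which shifts inputs/targets by one) plus
# trailing zeros up to max_len; tc_latents are only right-padded along time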
278 | tc_latent_list_padded = [] 279 | for i in range(len(duration_token_list)): 280 | duration_token_list_padded.append(F.pad( 281 | duration_token_list[i], (1, max_len - lens[i]), mode='constant', value=0)) 282 | tc_latent_list_padded.append(F.pad( 283 | tc_latent_list[i], (0, 0, 0, max_len - lens[i]), mode='constant', value=0)) 284 | 285 | duration_tokens = torch.stack(duration_token_list_padded).type( 286 | torch.float32).unsqueeze(-1) 287 | tc_latents = torch.stack(tc_latent_list_padded).type(torch.float32) 288 | lens = torch.Tensor(lens).to(dtype=torch.int32) 289 | 290 | batch = { 291 | "duration_tokens": duration_tokens, 292 | "tc_latents": tc_latents, 293 | "lens": lens, 294 | } 295 | 296 | return batch 297 | 298 | 299 | def make_spk_cutset(cuts: CutSet) -> Dict[str, CutSet]: 300 | spk2cuts = {} 301 | for cut in tqdm(cuts, desc="Making spk2cuts"): 302 | spk = cut.supervisions[0].speaker 303 | if spk not in spk2cuts: 304 | spk2cuts[spk] = cuts.filter( 305 | lambda c: c.supervisions[0].speaker == spk).to_eager() 306 | 307 | return spk2cuts 308 | 309 | 310 | class TTSDataModule(pl.LightningDataModule): 311 | def __init__( 312 | self, 313 | ds_path: str = 'data', 314 | max_duration_batch: float = 80, 315 | min_duration: float = 1.5, 316 | max_duration: float = 20, 317 | max_n_cuts: int = 3, 318 | num_buckets: int = 2, 319 | num_workers: int = 4, 320 | dataset: str = 'TTSDataset', 321 | **kwargs 322 | ) -> None: 323 | super().__init__() 324 | self.save_hyperparameters(ignore=['class_path']) 325 | 326 | def setup(self, stage: str = None) -> None: 327 | 328 | def filter_duration(c): 329 | if c.duration < self.hparams.min_duration or c.duration > self.hparams.max_duration: 330 | return False 331 | return True 332 | 333 | seed = random.randint(0, 100000) 334 | cs_train = load_manifest(f'{self.hparams.ds_path}/cuts_train.jsonl.gz') 335 | cs_train = cs_train.filter(filter_duration) 336 | 337 | if not self.hparams.dataset == 'MegaADMDataset': 338 | spk2cuts = make_spk_cutset(cs_train) 339 | 340 | if self.hparams.dataset == 'TTSDataset' or self.hparams.dataset == 'MegaADMDataset': 341 | if self.hparams.dataset == 'TTSDataset': 342 | dataset = TTSDataset(spk2cuts, self.hparams.ds_path, 10) 343 | else: 344 | dataset = MegaADMDataset(self.hparams.ds_path) 345 | 346 | sampler = DynamicBucketingSampler( 347 | cs_train, 348 | max_duration=self.hparams.max_duration_batch, 349 | shuffle=True, 350 | num_buckets=self.hparams.num_buckets, 351 | drop_last=False, 352 | seed=seed, 353 | ) 354 | elif self.hparams.dataset == 'MegaPLMDataset': 355 | lr = LengthRegulator( 356 | HIFIGAN_HOP_LENGTH, 16000, (HIFIGAN_HOP_LENGTH / HIFIGAN_SR * 1000)) 357 | dataset = MegaPLMDataset( 358 | spk2cuts, self.hparams.ds_path, lr, 10, 1024) 359 | 360 | sampler = SimpleCutSampler( 361 | cs_train, 362 | max_cuts=self.hparams.max_n_cuts, 363 | shuffle=True, 364 | drop_last=False, 365 | seed=seed, 366 | ) 367 | else: 368 | raise ValueError(f'Unsupported dataset: {self.hparams.dataset}') 369 | 370 | self.train_dl = DataLoader( 371 | dataset, 372 | batch_size=None, 373 | num_workers=self.hparams.num_workers, 374 | pin_memory=True, 375 | sampler=sampler, 376 | ) 377 | 378 | cs_valid = load_manifest(f'{self.hparams.ds_path}/cuts_valid.jsonl.gz') 379 | cs_valid = cs_valid.filter(filter_duration) 380 | 381 | if not self.hparams.dataset == 'MegaADMDataset': 382 | spk2cuts = make_spk_cutset(cs_valid) 383 | 384 | if self.hparams.dataset == 'TTSDataset' or self.hparams.dataset == 'MegaADMDataset': 385 | sampler = DynamicBucketingSampler( 
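# the validation sampler below reuses the same bucketing settings as the training one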
386 | cs_valid, 387 | max_duration=self.hparams.max_duration_batch, 388 | shuffle=True, 389 | num_buckets=self.hparams.num_buckets, 390 | drop_last=False, 391 | seed=seed, 392 | ) 393 | elif self.hparams.dataset == 'MegaPLMDataset': 394 | sampler = SimpleCutSampler( 395 | cs_valid, 396 | max_cuts=self.hparams.max_n_cuts, 397 | shuffle=True, 398 | drop_last=False, 399 | seed=seed, 400 | ) 401 | else: 402 | raise ValueError(f'Unsupported dataset: {self.hparams.dataset}') 403 | 404 | self.valid_dl = DataLoader( 405 | dataset, 406 | batch_size=None, 407 | num_workers=self.hparams.num_workers, 408 | sampler=sampler, 409 | ) 410 | 411 | def train_dataloader(self) -> DataLoader: 412 | return self.train_dl 413 | 414 | def val_dataloader(self) -> DataLoader: 415 | return self.valid_dl 416 | 417 | def test_dataloader(self) -> DataLoader: 418 | return None 419 | 420 | 421 | def test(): 422 | 423 | cs_valid = load_manifest("data/ds/cuts_valid.jsonl.gz") 424 | spk2cuts = make_spk_cutset(cs_valid) 425 | 426 | valid_dl = DataLoader( 427 | TTSDataset(spk2cuts, 'data/ds', 10), 428 | batch_size=None, 429 | num_workers=0, 430 | sampler=DynamicBucketingSampler( 431 | cs_valid, 432 | max_duration=10, 433 | shuffle=True, 434 | num_buckets=5, 435 | drop_last=False, 436 | seed=20000 437 | ), 438 | ) 439 | 440 | # for batch in valid_dl: 441 | # print(batch['phone_tokens'].shape) 442 | # print(batch['duration_tokens'].shape) 443 | # print(batch['tokens_lens'].shape) 444 | # print(batch['mel_targets'].shape) 445 | # print(batch['mel_target_lens'].shape) 446 | # print(batch['mel_timbres'].shape) 447 | # break 448 | 449 | lr = LengthRegulator(240, 16000, 15) 450 | 451 | valid_dl = DataLoader( 452 | MegaPLMDataset(spk2cuts, 'data/ds', lr, 10, 453 | (HIFIGAN_HOP_LENGTH / HIFIGAN_SR * 1000)), 454 | batch_size=None, 455 | num_workers=0, 456 | sampler=SimpleCutSampler( 457 | cs_valid, 458 | max_cuts=3, 459 | shuffle=True, 460 | drop_last=True, 461 | seed=20000, 462 | ), 463 | ) 464 | 465 | # for batch in valid_dl: 466 | # print(batch['p_codes'].shape) 467 | # print(batch['tc_latents'].shape) 468 | # print(batch['lens'].shape) 469 | 470 | cs = cs_valid + load_manifest("data/ds/cuts_train.jsonl.gz") 471 | 472 | valid_dl = DataLoader( 473 | MegaADMDataset('data/ds'), 474 | batch_size=None, 475 | num_workers=0, 476 | sampler=DynamicBucketingSampler( 477 | cs, 478 | max_duration=20, 479 | shuffle=True, 480 | num_buckets=5, 481 | drop_last=False, 482 | seed=20000 483 | ), 484 | ) 485 | 486 | for batch in valid_dl: 487 | pass 488 | # print(batch['duration_tokens'].shape) 489 | # print(batch['tc_latents'].shape) 490 | # print(batch['lens'].shape) 491 | -------------------------------------------------------------------------------- /modules/dscrm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copy of SyntaSpeech 3 | https://github.com/yerfor/SyntaSpeech/blob/main/modules/tts/syntaspeech/multi_window_disc.py 4 | """ 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class SingleWindowDisc(nn.Module): 12 | def __init__(self, time_length, freq_length=80, kernel=(3, 3), c_in=1, hidden_size=128): 13 | super().__init__() 14 | padding = (kernel[0] // 2, kernel[1] // 2) 15 | self.model = nn.ModuleList([ 16 | nn.Sequential(*[ 17 | nn.Conv2d(c_in, hidden_size, kernel, (2, 2), padding), 18 | nn.LeakyReLU(0.2, inplace=True), 19 | nn.Dropout2d(0.25), 20 | nn.BatchNorm2d(hidden_size, 0.8) 21 | ]), 22 | nn.Sequential(*[ 23 | nn.Conv2d(hidden_size, hidden_size, 
kernel, (2, 2), padding), 24 | nn.LeakyReLU(0.2, inplace=True), 25 | nn.Dropout2d(0.25), 26 | nn.BatchNorm2d(hidden_size, 0.8) 27 | ]), 28 | nn.Sequential(*[ 29 | nn.Conv2d(hidden_size, hidden_size, kernel, (2, 2), padding), 30 | nn.LeakyReLU(0.2, inplace=True), 31 | nn.Dropout2d(0.25), 32 | ]), 33 | ]) 34 | ds_size = (time_length // 2 ** 3, (freq_length + 7) // 2 ** 3) 35 | self.adv_layer = nn.Linear(hidden_size * ds_size[0] * ds_size[1], 1) 36 | 37 | def forward(self, x): 38 | """ 39 | :param x: [B, C, T, n_bins] 40 | :return: validity: [B, 1], h: List of hiddens 41 | """ 42 | h = [] 43 | for l in self.model: 44 | x = l(x) 45 | h.append(x) 46 | x = x.view(x.shape[0], -1) 47 | validity = self.adv_layer(x) # [B, 1] 48 | return validity, h 49 | 50 | 51 | class MultiWindowDiscriminator(nn.Module): 52 | def __init__(self, time_lengths, freq_length=80, kernel=(3, 3), c_in=1, hidden_size=128): 53 | super(MultiWindowDiscriminator, self).__init__() 54 | self.win_lengths = time_lengths 55 | self.discriminators = nn.ModuleList() 56 | 57 | for time_length in time_lengths: 58 | self.discriminators += [SingleWindowDisc( 59 | time_length, freq_length, kernel, c_in=c_in, hidden_size=hidden_size)] 60 | 61 | def forward(self, x, x_len, start_frames_wins=None): 62 | ''' 63 | Args: 64 | x (tensor): input mel, (B, c_in, T, n_bins). 65 | x_length (tensor): len of per mel. (B,). 66 | 67 | Returns: 68 | tensor : (B). 69 | ''' 70 | validity = [] 71 | if start_frames_wins is None: 72 | start_frames_wins = [None] * len(self.discriminators) 73 | h = [] 74 | for i, start_frames in zip(range(len(self.discriminators)), start_frames_wins): 75 | x_clip, start_frames = self.clip( 76 | x, x_len, self.win_lengths[i], start_frames) # (B, win_length, C) 77 | start_frames_wins[i] = start_frames 78 | if x_clip is None: 79 | continue 80 | x_clip, h_ = self.discriminators[i](x_clip) 81 | h += h_ 82 | validity.append(x_clip) 83 | if len(validity) != len(self.discriminators): 84 | return None, start_frames_wins, h 85 | validity = sum(validity) # [B] 86 | return validity, start_frames_wins, h 87 | 88 | def clip(self, x, x_len, win_length, start_frames=None): 89 | '''Ramdom clip x to win_length. 90 | Args: 91 | x (tensor) : (B, c_in, T, n_bins). 92 | cond (tensor) : (B, T, H). 93 | x_len (tensor) : (B,). 94 | win_length (int): target clip length 95 | 96 | Returns: 97 | (tensor) : (B, c_in, win_length, n_bins). 
98 | 99 | ''' 100 | T_start = 0 101 | T_end = x_len.max() - win_length 102 | assert T_end >= 0 103 | T_end = T_end.item() 104 | if start_frames is None: 105 | start_frame = np.random.randint(low=T_start, high=T_end + 1) 106 | start_frames = [start_frame] * x.size(0) 107 | else: 108 | start_frame = start_frames[0] 109 | x_batch = x[:, :, start_frame: start_frame + win_length] 110 | return x_batch, start_frames 111 | 112 | 113 | class Discriminator(nn.Module): 114 | def __init__(self, time_lengths=[32, 64, 128], freq_length=80, kernel=(3, 3), c_in=1, 115 | hidden_size=128): 116 | super(Discriminator, self).__init__() 117 | self.time_lengths = time_lengths 118 | self.discriminator = MultiWindowDiscriminator( 119 | freq_length=freq_length, 120 | time_lengths=time_lengths, 121 | kernel=kernel, 122 | c_in=c_in, hidden_size=hidden_size 123 | ) 124 | 125 | def forward(self, x, start_frames_wins=None): 126 | """ 127 | 128 | :param x: [B, T, 80] 129 | :param return_y_only: 130 | :return: 131 | """ 132 | if len(x.shape) == 3: 133 | x = x[:, None, :, :] # [B,1,T,80] 134 | x_len = x.sum([1, -1]).ne(0).int().sum([-1]) 135 | ret = {'y_c': None, 'y': None} 136 | ret['y'], start_frames_wins, ret['h'] = self.discriminator( 137 | x, x_len, start_frames_wins=start_frames_wins) 138 | 139 | ret['start_frames_wins'] = start_frames_wins 140 | return ret 141 | 142 | 143 | def test(): 144 | x = torch.randn(2, 1, 1030, 80) 145 | model = Discriminator() 146 | ret = model(x) 147 | print(ret['y']) 148 | print(ret['start_frames_wins']) 149 | print(ret['h'][0].shape) 150 | -------------------------------------------------------------------------------- /modules/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
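# --- Added note (not in the original file): a minimal usage sketch for the two
# modules defined below, assuming hidden_size=512 and the 320-symbol phone
# vocabulary used by MRTE elsewhere in this repo. Shapes are (B, T, dim_model).
#
#     emb = TokenEmbedding(dim_model=512, vocab_size=320, dropout=0.1)
#     pos = SinePositionalEmbedding(dim_model=512, dropout=0.1)
#     tokens = torch.randint(0, 320, (2, 10))   # (B, T) int64 phone ids
#     x = pos(emb(tokens))                      # (2, 10, 512) embeddings + sine PE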
14 | 15 | import math 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | class TokenEmbedding(nn.Module): 22 | def __init__( 23 | self, 24 | dim_model: int, 25 | vocab_size: int, 26 | dropout: float = 0.0, 27 | ): 28 | super().__init__() 29 | 30 | self.vocab_size = vocab_size 31 | self.dim_model = dim_model 32 | 33 | self.dropout = torch.nn.Dropout(p=dropout) 34 | self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) 35 | 36 | @property 37 | def weight(self) -> torch.Tensor: 38 | return self.word_embeddings.weight 39 | 40 | def embedding(self, index: int) -> torch.Tensor: 41 | return self.word_embeddings.weight[index : index + 1] 42 | 43 | def forward(self, x: torch.Tensor): 44 | X = self.word_embeddings(x) 45 | X = self.dropout(X) 46 | 47 | return X 48 | 49 | 50 | class SinePositionalEmbedding(nn.Module): 51 | def __init__( 52 | self, 53 | dim_model: int, 54 | dropout: float = 0.0, 55 | scale: bool = False, 56 | alpha: bool = False, 57 | ): 58 | super().__init__() 59 | self.dim_model = dim_model 60 | self.x_scale = math.sqrt(dim_model) if scale else 1.0 61 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) 62 | self.dropout = torch.nn.Dropout(p=dropout) 63 | 64 | self.reverse = False 65 | self.pe = None 66 | self.extend_pe(torch.tensor(0.0).expand(1, 4000)) 67 | 68 | def extend_pe(self, x, offset = 0): 69 | """Reset the positional encodings.""" 70 | x_size = x.size(1) + offset 71 | if self.pe is not None: 72 | if self.pe.size(1) >= x_size: 73 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 74 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 75 | return 76 | pe = torch.zeros(x_size, self.dim_model) 77 | if self.reverse: 78 | position = torch.arange( 79 | x_size - 1, -1, -1.0, dtype=torch.float32 80 | ).unsqueeze(1) 81 | else: 82 | position = torch.arange( 83 | 0, x_size, dtype=torch.float32 84 | ).unsqueeze(1) 85 | div_term = torch.exp( 86 | torch.arange(0, self.dim_model, 2, dtype=torch.float32) 87 | * -(math.log(10000.0) / self.dim_model) 88 | ) 89 | pe[:, 0::2] = torch.sin(position * div_term) 90 | pe[:, 1::2] = torch.cos(position * div_term) 91 | pe = pe.unsqueeze(0) 92 | self.pe = pe.to(device=x.device, dtype=x.dtype).detach() 93 | 94 | def forward(self, x: torch.Tensor, offset : int = 0) -> torch.Tensor: 95 | self.extend_pe(x, offset) 96 | output = x.unsqueeze(-1) if x.ndim == 2 else x 97 | output = output * self.x_scale + self.alpha * self.pe[:, offset : x.size(1) + offset] 98 | return self.dropout(output) 99 | -------------------------------------------------------------------------------- /modules/mrte.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange 6 | 7 | from typing import List 8 | 9 | from modules.embedding import TokenEmbedding, SinePositionalEmbedding 10 | from modules.convnet import ConvNetDouble, ConvNet 11 | from modules.transformer import (TransformerEncoder, 12 | TransformerEncoderLayer, 13 | MultiHeadAttention) 14 | from utils.utils import make_attn_mask 15 | 16 | from modules.tokenizer import ( 17 | HIFIGAN_SR, 18 | HIFIGAN_MEL_CHANNELS, 19 | HIFIGAN_HOP_LENGTH 20 | ) 21 | 22 | 23 | def create_alignment(base_mat, duration_tokens): 24 | N, L = duration_tokens.shape 25 | for i in range(N): 26 | count = 0 27 | for j in range(L): 28 | for k in range(duration_tokens[i][j]): 29 | base_mat[i][count+k][j] = 1 30 | count = count + duration_tokens[i][j] 31 | return 
base_mat 32 | 33 | 34 | class LengthRegulator(nn.Module): 35 | """ Length Regulator from FastSpeech """ 36 | 37 | def __init__(self, mel_frames, sample_rate, duration_token_ms): 38 | super(LengthRegulator, self).__init__() 39 | 40 | assert (mel_frames / sample_rate * 1000 / duration_token_ms) == 1 41 | 42 | def forward( 43 | self, 44 | x: torch.Tensor, # (B, T, D) 45 | duration_tokens: torch.Tensor, # (B, T) int for duration 46 | mel_max_length=None 47 | ): 48 | 49 | bsz, input_len, _ = x.size() 50 | 51 | expand_max_len = torch.max(torch.sum(duration_tokens, -1), -1)[0].int() 52 | 53 | alignment = torch.zeros(bsz, expand_max_len, input_len).numpy() 54 | alignment = create_alignment(alignment, duration_tokens.cpu().numpy()) 55 | alignment = torch.from_numpy(alignment).to(x.device) 56 | output = alignment @ x 57 | if mel_max_length: 58 | output = F.pad( 59 | output, (0, 0, 0, mel_max_length-output.size(1), 0, 0)) 60 | return output 61 | 62 | 63 | class MRTE(nn.Module): 64 | def __init__( 65 | self, 66 | mel_bins: int = HIFIGAN_MEL_CHANNELS, 67 | mel_frames: int = HIFIGAN_HOP_LENGTH, 68 | mel_activation: str = 'ReLU', 69 | mel_kernel_size: int = 3, 70 | mel_stride: int = 16, 71 | mel_n_layer: int = 5, 72 | mel_n_stack: int = 5, 73 | mel_n_block: int = 2, 74 | content_ff_dim: int = 1024, 75 | content_n_heads: int = 2, 76 | content_n_layers: int = 8, 77 | hidden_size: int = 512, 78 | duration_token_ms: float = ( 79 | HIFIGAN_HOP_LENGTH / HIFIGAN_SR * 1000), 80 | phone_vocab_size: int = 320, 81 | dropout: float = 0.1, 82 | sample_rate: int = HIFIGAN_SR, 83 | ): 84 | super(MRTE, self).__init__() 85 | 86 | self.n_heads = content_n_heads 87 | self.mel_bins = mel_bins 88 | self.hidden_size = hidden_size 89 | 90 | self.phone_embedding = TokenEmbedding( 91 | dim_model=hidden_size, 92 | vocab_size=phone_vocab_size, 93 | dropout=dropout, 94 | ) 95 | 96 | self.phone_pos_embedding = SinePositionalEmbedding( 97 | dim_model=hidden_size, 98 | dropout=dropout, 99 | ) 100 | 101 | self.mel_encoder_middle_layer = nn.Conv1d( 102 | in_channels=hidden_size, 103 | out_channels=hidden_size, 104 | kernel_size=mel_stride + 1, 105 | stride=mel_stride, 106 | padding=(mel_stride) // 2, 107 | ) 108 | self.mel_encoder = ConvNetDouble( 109 | in_channels=mel_bins, 110 | out_channels=hidden_size, 111 | hidden_size=hidden_size, 112 | n_layers=mel_n_layer, 113 | n_stacks=mel_n_stack, 114 | n_blocks=mel_n_block, 115 | middle_layer=self.mel_encoder_middle_layer, 116 | kernel_size=mel_kernel_size, 117 | activation=mel_activation, 118 | ) 119 | 120 | self.phone_encoder = TransformerEncoder( 121 | TransformerEncoderLayer( 122 | dim=hidden_size, 123 | ff_dim=content_ff_dim, 124 | conv_ff=True, 125 | n_heads=content_n_heads, 126 | dropout=dropout, 127 | ), 128 | num_layers=content_n_layers, 129 | ) 130 | 131 | self.mha = MultiHeadAttention( 132 | qkv_dim=hidden_size, 133 | n_heads=1, 134 | dropout=dropout, 135 | ) 136 | self.norm = nn.LayerNorm(hidden_size) 137 | self.activation = nn.ReLU() 138 | 139 | self.length_regulator = LengthRegulator( 140 | mel_frames, sample_rate, duration_token_ms) 141 | 142 | 143 | # self.test_pllm = TransformerEncoder( 144 | # TransformerEncoderLayer( 145 | # dim=1024, 146 | # ff_dim=1024, 147 | # conv_ff=True, 148 | # n_heads=16, 149 | # dropout=dropout, 150 | # ), 151 | # num_layers=12, 152 | # ) 153 | 154 | def tc_latent( 155 | self, 156 | phone: torch.Tensor, # (B, T) 157 | mel: torch.Tensor, # (B, T, mel_bins) 158 | ): 159 | phone_emb = self.phone_embedding(phone) 160 | phone_pos = 
self.phone_pos_embedding(phone_emb)
161 | 
162 |         mel = rearrange(mel, 'B T D -> B D T')
163 |         mel_context = self.mel_encoder(mel)
164 |         mel_context = rearrange(mel_context, 'B D T -> B T D')
165 |         phone_x = self.phone_encoder(phone_pos)
166 | 
167 |         tc_latent = self.mha(phone_x, kv=mel_context)
168 |         tc_latent = self.norm(tc_latent)
169 |         tc_latent = self.activation(tc_latent)
170 | 
171 |         return tc_latent
172 | 
173 |     def forward(
174 |             self,
175 |             duration_tokens: torch.Tensor,  # (B, T)
176 |             phone: torch.Tensor,  # (B, T)
177 |             phone_lens: torch.Tensor,  # (B,); not used by tc_latent
178 |             mel: torch.Tensor,  # (B, T, mel_bins)
179 |     ):
180 |         tc_latent = self.tc_latent(phone, mel)
181 | 
182 |         out = self.length_regulator(tc_latent, duration_tokens)
183 |         return out
184 | 
185 | 
186 | def test():
187 |     lr_in = torch.randn(2, 10, 128)
188 |     lr = LengthRegulator(240, 16000, 15)
189 | 
190 |     duration_tokens = torch.tensor(
191 |         [[1, 2, 3, 4], [1, 2, 3, 5]]).to(dtype=torch.int32)
192 | 
193 |     out = lr(lr_in, duration_tokens)
194 |     assert out.shape == (2, 11, 128)
195 | 
196 |     mrte = MRTE(
197 |         mel_bins = HIFIGAN_MEL_CHANNELS,
198 |         mel_frames = HIFIGAN_HOP_LENGTH,
199 |         content_ff_dim = 1024,
200 |         content_n_heads = 2,
201 |         content_n_layers = 8,
202 |         hidden_size = 512,
203 |         mel_activation = 'ReLU',
204 |         mel_kernel_size = 3,
205 |         mel_stride = 16,
206 |         mel_n_stack = 5,
207 |         mel_n_block = 2,
208 |         duration_token_ms = (
209 |             HIFIGAN_HOP_LENGTH / HIFIGAN_SR * 1000),
210 |         phone_vocab_size = 320,
211 |         dropout = 0.1,
212 |         sample_rate = HIFIGAN_SR,
213 |     )
214 |     mrte = mrte.to('cuda')
215 | 
216 |     duration_tokens = torch.tensor([[1, 2, 3, 4], [1, 1, 1, 2]]).to(
217 |         dtype=torch.int32).to('cuda')
218 | 
219 |     t = torch.randint(0, 320, (2, 10)).to(dtype=torch.int64).to('cuda')
220 |     tl = torch.tensor([6, 10]).to(dtype=torch.int64).to('cuda')
221 |     m = torch.randn(2, 347, HIFIGAN_MEL_CHANNELS).to('cuda')
222 | 
223 |     out = mrte(duration_tokens, t, tl, m)
224 | 
--------------------------------------------------------------------------------
/modules/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | # flake8: noqa
8 | from .vq import QuantizedResult, ResidualVectorQuantizer
9 | 
--------------------------------------------------------------------------------
/modules/quantization/ac.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | """Arithmetic coder."""
8 | 
9 | import io
10 | import math
11 | import random
12 | import typing as tp
13 | import torch
14 | 
15 | from ..binary import BitPacker, BitUnpacker
16 | 
17 | 
18 | def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
19 |                                roundoff: float = 1e-8, min_range: int = 2,
20 |                                check: bool = True) -> torch.Tensor:
21 |     """Turn the given PDF into a quantized CDF that splits
22 |     [0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
23 |     to the PDF.
24 | 
25 |     Args:
26 |         pdf (torch.Tensor): probability distribution, shape should be `[N]`.
27 | total_range_bits (int): see `ArithmeticCoder`, the typical range we expect 28 | during the coding process is `[0, 2 ** total_range_bits - 1]`. 29 | roundoff (float): will round the pdf up to that level to remove difference coming 30 | from e.g. evaluating the Language Model on different architectures. 31 | min_range (int): minimum range width. Should always be at least 2 for numerical 32 | stability. Use this to avoid pathological behavior is a value 33 | that is expected to be rare actually happens in real life. 34 | check (bool): if True, checks that nothing bad happened, can be deactivated for speed. 35 | """ 36 | pdf = pdf.detach() 37 | if roundoff: 38 | pdf = (pdf / roundoff).floor() * roundoff 39 | # interpolate with uniform distribution to achieve desired minimum probability. 40 | total_range = 2 ** total_range_bits 41 | cardinality = len(pdf) 42 | alpha = min_range * cardinality / total_range 43 | assert alpha <= 1, "you must reduce min_range" 44 | ranges = (((1 - alpha) * total_range) * pdf).floor().long() 45 | ranges += min_range 46 | quantized_cdf = torch.cumsum(ranges, dim=-1) 47 | if min_range < 2: 48 | raise ValueError("min_range must be at least 2.") 49 | if check: 50 | assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1] 51 | if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range: 52 | raise ValueError("You must increase your total_range_bits.") 53 | return quantized_cdf 54 | 55 | 56 | class ArithmeticCoder: 57 | """ArithmeticCoder, 58 | Let us take a distribution `p` over `N` symbols, and assume we have a stream 59 | of random variables `s_t` sampled from `p`. Let us assume that we have a budget 60 | of `B` bits that we can afford to write on device. There are `2**B` possible numbers, 61 | corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single 62 | sequence `(s_t)` by doing the following: 63 | 64 | 1) Initialize the current range to` [0 ** 2 B - 1]`. 65 | 2) For each time step t, split the current range into contiguous chunks, 66 | one for each possible outcome, with size roughly proportional to `p`. 67 | For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks 68 | would be `{[0, 2], [3, 3]}`. 69 | 3) Select the chunk corresponding to `s_t`, and replace the current range with this. 70 | 4) When done encoding all the values, just select any value remaining in the range. 71 | 72 | You will notice that this procedure can fail: for instance if at any point in time 73 | the range is smaller than `N`, then we can no longer assign a non-empty chunk to each 74 | possible outcome. Intuitively, the more likely a value is, the less the range width 75 | will reduce, and the longer we can go on encoding values. This makes sense: for any efficient 76 | coding scheme, likely outcomes would take less bits, and more of them can be coded 77 | with a fixed budget. 78 | 79 | In practice, we do not know `B` ahead of time, but we have a way to inject new bits 80 | when the current range decreases below a given limit (given by `total_range_bits`), without 81 | having to redo all the computations. If we encode mostly likely values, we will seldom 82 | need to inject new bits, but a single rare value can deplete our stock of entropy! 83 | 84 | In this explanation, we assumed that the distribution `p` was constant. In fact, the present 85 | code works for any sequence `(p_t)` possibly different for each timestep. 
86 | We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller 87 | the KL between the true distribution and `p_t`, the most efficient the coding will be. 88 | 89 | Args: 90 | fo (IO[bytes]): file-like object to which the bytes will be written to. 91 | total_range_bits (int): the range `M` described above is `2 ** total_range_bits. 92 | Any time the current range width fall under this limit, new bits will 93 | be injected to rescale the initial range. 94 | """ 95 | 96 | def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): 97 | assert total_range_bits <= 30 98 | self.total_range_bits = total_range_bits 99 | self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time. 100 | self.low: int = 0 101 | self.high: int = 0 102 | self.max_bit: int = -1 103 | self._dbg: tp.List[tp.Any] = [] 104 | self._dbg2: tp.List[tp.Any] = [] 105 | 106 | @property 107 | def delta(self) -> int: 108 | """Return the current range width.""" 109 | return self.high - self.low + 1 110 | 111 | def _flush_common_prefix(self): 112 | # If self.low and self.high start with the sames bits, 113 | # those won't change anymore as we always just increase the range 114 | # by powers of 2, and we can flush them out to the bit stream. 115 | assert self.high >= self.low, (self.low, self.high) 116 | assert self.high < 2 ** (self.max_bit + 1) 117 | while self.max_bit >= 0: 118 | b1 = self.low >> self.max_bit 119 | b2 = self.high >> self.max_bit 120 | if b1 == b2: 121 | self.low -= (b1 << self.max_bit) 122 | self.high -= (b1 << self.max_bit) 123 | assert self.high >= self.low, (self.high, self.low, self.max_bit) 124 | assert self.low >= 0 125 | self.max_bit -= 1 126 | self.packer.push(b1) 127 | else: 128 | break 129 | 130 | def push(self, symbol: int, quantized_cdf: torch.Tensor): 131 | """Push the given symbol on the stream, flushing out bits 132 | if possible. 133 | 134 | Args: 135 | symbol (int): symbol to encode with the AC. 136 | quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` 137 | to build this from your pdf estimate. 138 | """ 139 | while self.delta < 2 ** self.total_range_bits: 140 | self.low *= 2 141 | self.high = self.high * 2 + 1 142 | self.max_bit += 1 143 | 144 | range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item() 145 | range_high = quantized_cdf[symbol].item() - 1 146 | effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))) 147 | effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))) 148 | assert self.low <= self.high 149 | self.high = self.low + effective_high 150 | self.low = self.low + effective_low 151 | assert self.low <= self.high, (effective_low, effective_high, range_low, range_high) 152 | self._dbg.append((self.low, self.high)) 153 | self._dbg2.append((self.low, self.high)) 154 | outs = self._flush_common_prefix() 155 | assert self.low <= self.high 156 | assert self.max_bit >= -1 157 | assert self.max_bit <= 61, self.max_bit 158 | return outs 159 | 160 | def flush(self): 161 | """Flush the remaining information to the stream. 162 | """ 163 | while self.max_bit >= 0: 164 | b1 = (self.low >> self.max_bit) & 1 165 | self.packer.push(b1) 166 | self.max_bit -= 1 167 | self.packer.flush() 168 | 169 | 170 | class ArithmeticDecoder: 171 | """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation. 
172 | 173 | Note that this must be called with **exactly** the same parameters and sequence 174 | of quantized cdf as the arithmetic encoder or the wrong values will be decoded. 175 | 176 | If the AC encoder current range is [L, H], with `L` and `H` having the some common 177 | prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream. 178 | For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside 179 | `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained 180 | for a specific sequence of symbols and a binary-search allows us to decode those symbols. 181 | At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols, 182 | and we will need to read new bits from the stream and repeat the process. 183 | 184 | """ 185 | def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): 186 | self.total_range_bits = total_range_bits 187 | self.low: int = 0 188 | self.high: int = 0 189 | self.current: int = 0 190 | self.max_bit: int = -1 191 | self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time. 192 | # Following is for debugging 193 | self._dbg: tp.List[tp.Any] = [] 194 | self._dbg2: tp.List[tp.Any] = [] 195 | self._last: tp.Any = None 196 | 197 | @property 198 | def delta(self) -> int: 199 | return self.high - self.low + 1 200 | 201 | def _flush_common_prefix(self): 202 | # Given the current range [L, H], if both have a common prefix, 203 | # we know we can remove it from our representation to avoid handling large numbers. 204 | while self.max_bit >= 0: 205 | b1 = self.low >> self.max_bit 206 | b2 = self.high >> self.max_bit 207 | if b1 == b2: 208 | self.low -= (b1 << self.max_bit) 209 | self.high -= (b1 << self.max_bit) 210 | self.current -= (b1 << self.max_bit) 211 | assert self.high >= self.low 212 | assert self.low >= 0 213 | self.max_bit -= 1 214 | else: 215 | break 216 | 217 | def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]: 218 | """Pull a symbol, reading as many bits from the stream as required. 219 | This returns `None` when the stream has been exhausted. 220 | 221 | Args: 222 | quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` 223 | to build this from your pdf estimate. This must be **exatly** 224 | the same cdf as the one used at encoding time. 
225 | """ 226 | while self.delta < 2 ** self.total_range_bits: 227 | bit = self.unpacker.pull() 228 | if bit is None: 229 | return None 230 | self.low *= 2 231 | self.high = self.high * 2 + 1 232 | self.current = self.current * 2 + bit 233 | self.max_bit += 1 234 | 235 | def bin_search(low_idx: int, high_idx: int): 236 | # Binary search is not just for coding interviews :) 237 | if high_idx < low_idx: 238 | raise RuntimeError("Binary search failed") 239 | mid = (low_idx + high_idx) // 2 240 | range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 241 | range_high = quantized_cdf[mid].item() - 1 242 | effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))) 243 | effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))) 244 | low = effective_low + self.low 245 | high = effective_high + self.low 246 | if self.current >= low: 247 | if self.current <= high: 248 | return (mid, low, high, self.current) 249 | else: 250 | return bin_search(mid + 1, high_idx) 251 | else: 252 | return bin_search(low_idx, mid - 1) 253 | 254 | self._last = (self.low, self.high, self.current, self.max_bit) 255 | sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) 256 | self._dbg.append((self.low, self.high, self.current)) 257 | self._flush_common_prefix() 258 | self._dbg2.append((self.low, self.high, self.current)) 259 | 260 | return sym 261 | 262 | 263 | def test(): 264 | torch.manual_seed(1234) 265 | random.seed(1234) 266 | for _ in range(4): 267 | pdfs = [] 268 | cardinality = random.randrange(4000) 269 | steps = random.randrange(100, 500) 270 | fo = io.BytesIO() 271 | encoder = ArithmeticCoder(fo) 272 | symbols = [] 273 | for step in range(steps): 274 | pdf = torch.softmax(torch.randn(cardinality), dim=0) 275 | pdfs.append(pdf) 276 | q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) 277 | symbol = torch.multinomial(pdf, 1).item() 278 | symbols.append(symbol) 279 | encoder.push(symbol, q_cdf) 280 | encoder.flush() 281 | 282 | fo.seek(0) 283 | decoder = ArithmeticDecoder(fo) 284 | for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): 285 | q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) 286 | decoded_symbol = decoder.pull(q_cdf) 287 | assert decoded_symbol == symbol, idx 288 | assert decoder.pull(torch.zeros(1)) is None 289 | 290 | 291 | if __name__ == "__main__": 292 | test() 293 | -------------------------------------------------------------------------------- /modules/quantization/core_vq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # This implementation is inspired from 8 | # https://github.com/lucidrains/vector-quantize-pytorch 9 | # which is released under MIT License. 
Hereafter, the original license: 10 | # MIT License 11 | # 12 | # Copyright (c) 2020 Phil Wang 13 | # 14 | # Permission is hereby granted, free of charge, to any person obtaining a copy 15 | # of this software and associated documentation files (the "Software"), to deal 16 | # in the Software without restriction, including without limitation the rights 17 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 18 | # copies of the Software, and to permit persons to whom the Software is 19 | # furnished to do so, subject to the following conditions: 20 | # 21 | # The above copyright notice and this permission notice shall be included in all 22 | # copies or substantial portions of the Software. 23 | # 24 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 25 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 26 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 27 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 28 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 29 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 30 | # SOFTWARE. 31 | 32 | """Core vector quantization implementation.""" 33 | 34 | import typing as tp 35 | import warnings 36 | 37 | from einops import rearrange, repeat 38 | import torch 39 | from torch import nn 40 | import torch.nn.functional as F 41 | 42 | from utils import distrib 43 | 44 | 45 | def default(val: tp.Any, d: tp.Any) -> tp.Any: 46 | return val if val is not None else d 47 | 48 | 49 | def ema_inplace(moving_avg, new, decay: float): 50 | moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) 51 | 52 | 53 | def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): 54 | return (x + epsilon) / (x.sum() + n_categories * epsilon) 55 | 56 | 57 | def uniform_init(*shape: int): 58 | t = torch.empty(shape) 59 | nn.init.kaiming_uniform_(t) 60 | return t 61 | 62 | 63 | def sample_vectors(samples, num: int): 64 | num_samples, device = samples.shape[0], samples.device 65 | 66 | if num_samples >= num: 67 | indices = torch.randperm(num_samples, device=device)[:num] 68 | else: 69 | indices = torch.randint(0, num_samples, (num,), device=device) 70 | 71 | return samples[indices] 72 | 73 | 74 | def kmeans(samples, num_clusters: int, num_iters: int = 10): 75 | dim, dtype = samples.shape[-1], samples.dtype 76 | 77 | means = sample_vectors(samples, num_clusters) 78 | 79 | for _ in range(num_iters): 80 | diffs = rearrange(samples, "n d -> n () d") - rearrange( 81 | means, "c d -> () c d" 82 | ) 83 | dists = -(diffs ** 2).sum(dim=-1) 84 | 85 | buckets = dists.max(dim=-1).indices 86 | bins = torch.bincount(buckets, minlength=num_clusters) 87 | zero_mask = bins == 0 88 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 89 | 90 | new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) 91 | new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) 92 | new_means = new_means / bins_min_clamped[..., None] 93 | 94 | means = torch.where(zero_mask[..., None], means, new_means) 95 | 96 | return means, bins 97 | 98 | 99 | class EuclideanCodebook(nn.Module): 100 | """Codebook with Euclidean distance. 101 | Args: 102 | dim (int): Dimension. 103 | codebook_size (int): Codebook size. 104 | kmeans_init (bool): Whether to use k-means to initialize the codebooks. 
105 | If set to true, run the k-means algorithm on the first training batch and use 106 | the learned centroids as initialization. 107 | kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. 108 | decay (float): Decay for exponential moving average over the codebooks. 109 | epsilon (float): Epsilon value for numerical stability. 110 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 111 | that have an exponential moving average cluster size less than the specified threshold with 112 | randomly selected vector from the current batch. 113 | """ 114 | def __init__( 115 | self, 116 | dim: int, 117 | codebook_size: int, 118 | kmeans_init: int = False, 119 | kmeans_iters: int = 10, 120 | decay: float = 0.99, 121 | epsilon: float = 1e-5, 122 | threshold_ema_dead_code: int = 2, 123 | ): 124 | super().__init__() 125 | self.decay = decay 126 | init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros 127 | embed = init_fn(codebook_size, dim) 128 | 129 | self.codebook_size = codebook_size 130 | 131 | self.kmeans_iters = kmeans_iters 132 | self.epsilon = epsilon 133 | self.threshold_ema_dead_code = threshold_ema_dead_code 134 | 135 | self.register_buffer("inited", torch.Tensor([not kmeans_init])) 136 | self.register_buffer("cluster_size", torch.zeros(codebook_size)) 137 | self.register_buffer("embed", embed) 138 | self.register_buffer("embed_avg", embed.clone()) 139 | 140 | @torch.jit.ignore 141 | def init_embed_(self, data): 142 | if self.inited: 143 | return 144 | 145 | embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) 146 | self.embed.data.copy_(embed) 147 | self.embed_avg.data.copy_(embed.clone()) 148 | self.cluster_size.data.copy_(cluster_size) 149 | self.inited.data.copy_(torch.Tensor([True])) 150 | # Make sure all buffers across workers are in sync after initialization 151 | distrib.broadcast_tensors(self.buffers()) 152 | 153 | def replace_(self, samples, mask): 154 | modified_codebook = torch.where( 155 | mask[..., None], sample_vectors(samples, self.codebook_size), self.embed 156 | ) 157 | self.embed.data.copy_(modified_codebook) 158 | 159 | def expire_codes_(self, batch_samples): 160 | if self.threshold_ema_dead_code == 0: 161 | return 162 | 163 | expired_codes = self.cluster_size < self.threshold_ema_dead_code 164 | if not torch.any(expired_codes): 165 | return 166 | 167 | batch_samples = rearrange(batch_samples, "... d -> (...) d") 168 | self.replace_(batch_samples, mask=expired_codes) 169 | distrib.broadcast_tensors(self.buffers()) 170 | 171 | def preprocess(self, x): 172 | x = rearrange(x, "... d -> (...) 
d") 173 | return x 174 | 175 | def quantize(self, x): 176 | embed = self.embed.t() 177 | dist = -( 178 | x.pow(2).sum(1, keepdim=True) 179 | - 2 * x @ embed 180 | + embed.pow(2).sum(0, keepdim=True) 181 | ) 182 | embed_ind = dist.max(dim=-1).indices 183 | return embed_ind 184 | 185 | def postprocess_emb(self, embed_ind, shape): 186 | return embed_ind.view(*shape[:-1]) 187 | 188 | def dequantize(self, embed_ind): 189 | quantize = F.embedding(embed_ind, self.embed) 190 | return quantize 191 | 192 | def encode(self, x): 193 | shape = x.shape 194 | # pre-process 195 | x = self.preprocess(x) 196 | # quantize 197 | embed_ind = self.quantize(x) 198 | # post-process 199 | embed_ind = self.postprocess_emb(embed_ind, shape) 200 | return embed_ind 201 | 202 | def decode(self, embed_ind): 203 | quantize = self.dequantize(embed_ind) 204 | return quantize 205 | 206 | def forward(self, x): 207 | shape, dtype = x.shape, x.dtype 208 | x = self.preprocess(x) 209 | 210 | self.init_embed_(x) 211 | 212 | embed_ind = self.quantize(x) 213 | embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) 214 | embed_ind = self.postprocess_emb(embed_ind, shape) 215 | quantize = self.dequantize(embed_ind) 216 | 217 | if self.training: 218 | # We do the expiry of code at that point as buffers are in sync 219 | # and all the workers will take the same decision. 220 | self.expire_codes_(x) 221 | ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) 222 | embed_sum = x.t() @ embed_onehot 223 | ema_inplace(self.embed_avg, embed_sum.t(), self.decay) 224 | cluster_size = ( 225 | laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) 226 | * self.cluster_size.sum() 227 | ) 228 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) 229 | self.embed.data.copy_(embed_normalized) 230 | 231 | return quantize, embed_ind 232 | 233 | 234 | class VectorQuantization(nn.Module): 235 | """Vector quantization implementation. 236 | Currently supports only euclidean distance. 237 | Args: 238 | dim (int): Dimension 239 | codebook_size (int): Codebook size 240 | codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. 241 | decay (float): Decay for exponential moving average over the codebooks. 242 | epsilon (float): Epsilon value for numerical stability. 243 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks. 244 | kmeans_iters (int): Number of iterations used for kmeans initialization. 245 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 246 | that have an exponential moving average cluster size less than the specified threshold with 247 | randomly selected vector from the current batch. 248 | commitment_weight (float): Weight for commitment loss. 
249 | """ 250 | def __init__( 251 | self, 252 | dim: int, 253 | codebook_size: int, 254 | codebook_dim: tp.Optional[int] = None, 255 | decay: float = 0.99, 256 | epsilon: float = 1e-5, 257 | kmeans_init: bool = True, 258 | kmeans_iters: int = 50, 259 | threshold_ema_dead_code: int = 2, 260 | commitment_weight: float = 1., 261 | ): 262 | super().__init__() 263 | _codebook_dim: int = default(codebook_dim, dim) 264 | 265 | requires_projection = _codebook_dim != dim 266 | self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) 267 | self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) 268 | 269 | self.epsilon = epsilon 270 | self.commitment_weight = commitment_weight 271 | 272 | self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, 273 | kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, 274 | decay=decay, epsilon=epsilon, 275 | threshold_ema_dead_code=threshold_ema_dead_code) 276 | self.codebook_size = codebook_size 277 | 278 | @property 279 | def codebook(self): 280 | return self._codebook.embed 281 | 282 | def encode(self, x): 283 | x = rearrange(x, "b d n -> b n d") 284 | x = self.project_in(x) 285 | embed_in = self._codebook.encode(x) 286 | return embed_in 287 | 288 | def decode(self, embed_ind): 289 | quantize = self._codebook.decode(embed_ind) 290 | quantize = self.project_out(quantize) 291 | quantize = rearrange(quantize, "b n d -> b d n") 292 | return quantize 293 | 294 | def forward(self, x): 295 | device = x.device 296 | x = rearrange(x, "b d n -> b n d") 297 | x = self.project_in(x) 298 | 299 | quantize, embed_ind = self._codebook(x) 300 | 301 | if self.training: 302 | quantize = x + (quantize - x).detach() 303 | 304 | loss = torch.tensor([0.0], device=device, requires_grad=self.training) 305 | 306 | if self.training: 307 | warnings.warn('When using RVQ in training model, first check ' 308 | 'https://github.com/facebookresearch/encodec/issues/25 . ' 309 | 'The bug wasn\'t fixed here for reproducibility.') 310 | if self.commitment_weight > 0: 311 | commit_loss = F.mse_loss(quantize.detach(), x) 312 | loss = loss + commit_loss * self.commitment_weight 313 | 314 | quantize = self.project_out(quantize) 315 | quantize = rearrange(quantize, "b n d -> b d n") 316 | return quantize, embed_ind, loss 317 | 318 | 319 | class ResidualVectorQuantization(nn.Module): 320 | """Residual vector quantization implementation. 321 | Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf 322 | """ 323 | def __init__(self, *, num_quantizers, **kwargs): 324 | super().__init__() 325 | self.layers = nn.ModuleList( 326 | [VectorQuantization(**kwargs) for _ in range(num_quantizers)] 327 | ) 328 | 329 | def forward(self, x, n_q: tp.Optional[int] = None): 330 | quantized_out = 0.0 331 | residual = x 332 | 333 | all_losses = [] 334 | all_indices = [] 335 | 336 | n_q = n_q or len(self.layers) 337 | 338 | for layer in self.layers[:n_q]: 339 | quantized, indices, loss = layer(residual) 340 | residual = residual - quantized 341 | quantized_out = quantized_out + quantized 342 | 343 | all_indices.append(indices) 344 | all_losses.append(loss) 345 | 346 | out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) 347 | return quantized_out, out_indices, out_losses 348 | 349 | def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: 350 | residual = x 351 | all_indices = [] 352 | n_q = n_q or len(self.layers) 353 | for layer in self.layers[:n_q]: 354 | indices = layer.encode(residual) 355 | quantized = layer.decode(indices) 356 | residual = residual - quantized 357 | all_indices.append(indices) 358 | out_indices = torch.stack(all_indices) 359 | return out_indices 360 | 361 | def decode(self, q_indices: torch.Tensor) -> torch.Tensor: 362 | quantized_out = torch.tensor(0.0, device=q_indices.device) 363 | for i, indices in enumerate(q_indices): 364 | layer = self.layers[i] 365 | quantized = layer.decode(indices) 366 | quantized_out = quantized_out + quantized 367 | return quantized_out 368 | -------------------------------------------------------------------------------- /modules/quantization/vq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Residual vector quantizer implementation.""" 8 | 9 | from dataclasses import dataclass, field 10 | import math 11 | import typing as tp 12 | 13 | import torch 14 | from torch import nn 15 | 16 | from .core_vq import ResidualVectorQuantization 17 | 18 | 19 | @dataclass 20 | class QuantizedResult: 21 | quantized: torch.Tensor 22 | codes: torch.Tensor 23 | bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. 24 | penalty: tp.Optional[torch.Tensor] = None 25 | metrics: dict = field(default_factory=dict) 26 | 27 | 28 | class ResidualVectorQuantizer(nn.Module): 29 | """Residual Vector Quantizer. 30 | Args: 31 | dimension (int): Dimension of the codebooks. 32 | n_q (int): Number of residual vector quantizers used. 33 | bins (int): Codebook size. 34 | decay (float): Decay for exponential moving average over the codebooks. 35 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks. 36 | kmeans_iters (int): Number of iterations used for kmeans initialization. 37 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 38 | that have an exponential moving average cluster size less than the specified threshold with 39 | randomly selected vector from the current batch. 
40 | """ 41 | def __init__( 42 | self, 43 | dimension: int = 256, 44 | n_q: int = 8, 45 | bins: int = 1024, 46 | decay: float = 0.99, 47 | kmeans_init: bool = True, 48 | kmeans_iters: int = 50, 49 | threshold_ema_dead_code: int = 2, 50 | ): 51 | super().__init__() 52 | self.n_q = n_q 53 | self.dimension = dimension 54 | self.bins = bins 55 | self.decay = decay 56 | self.kmeans_init = kmeans_init 57 | self.kmeans_iters = kmeans_iters 58 | self.threshold_ema_dead_code = threshold_ema_dead_code 59 | self.vq = ResidualVectorQuantization( 60 | dim=self.dimension, 61 | codebook_size=self.bins, 62 | num_quantizers=self.n_q, 63 | decay=self.decay, 64 | kmeans_init=self.kmeans_init, 65 | kmeans_iters=self.kmeans_iters, 66 | threshold_ema_dead_code=self.threshold_ema_dead_code, 67 | ) 68 | 69 | def forward(self, x: torch.Tensor) -> QuantizedResult: 70 | """Residual vector quantization on the given input tensor. 71 | Args: 72 | x (torch.Tensor): Input tensor. 73 | frame_rate (int): Sample rate of the input tensor. 74 | bandwidth (float): Target bandwidth. 75 | Returns: 76 | QuantizedResult: 77 | The quantized (or approximately quantized) representation with 78 | the associated bandwidth and any penalty term for the loss. 79 | """ 80 | quantized, codes, commit_loss = self.vq(x, n_q=self.n_q) 81 | return quantized, codes, commit_loss 82 | 83 | def get_num_quantizers_for_bandwidth(self, frame_rate: int, bandwidth: tp.Optional[float] = None) -> int: 84 | """Return n_q based on specified target bandwidth. 85 | """ 86 | bw_per_q = self.get_bandwidth_per_quantizer(frame_rate) 87 | n_q = self.n_q 88 | if bandwidth and bandwidth > 0.: 89 | # bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented as 90 | # bandwidth == 6.0 91 | n_q = int(max(1, math.floor(bandwidth * 1000 / bw_per_q))) 92 | return n_q 93 | 94 | def get_bandwidth_per_quantizer(self, frame_rate: int): 95 | """Return bandwidth per quantizer for a given input frame rate. 96 | Each quantizer encodes a frame with lg(bins) bits. 97 | """ 98 | return math.log2(self.bins) * frame_rate 99 | 100 | def encode(self, x: torch.Tensor, frame_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor: 101 | """Encode a given input tensor with the specified frame rate at the given bandwidth. 102 | The RVQ encode method sets the appropriate number of quantizers to use 103 | and returns indices for each quantizer. 104 | """ 105 | n_q = self.get_num_quantizers_for_bandwidth(frame_rate, bandwidth) 106 | codes = self.vq.encode(x, n_q=n_q) 107 | return codes 108 | 109 | def decode(self, codes: torch.Tensor) -> torch.Tensor: 110 | """Decode the given codes to the quantized representation. 
111 | """ 112 | quantized = self.vq.decode(codes) 113 | return quantized 114 | -------------------------------------------------------------------------------- /modules/tokenizer.py: -------------------------------------------------------------------------------- 1 | from pypinyin import pinyin, Style 2 | from phonemizer.separator import Separator 3 | 4 | import re 5 | from dataclasses import dataclass 6 | 7 | from lhotse.features import FeatureExtractor 8 | from lhotse.utils import Seconds, compute_num_frames 9 | 10 | import numpy as np 11 | import torch 12 | 13 | from typing import Union 14 | 15 | import re 16 | 17 | from speechbrain.lobes.models.FastSpeech2 import mel_spectogram 18 | 19 | HIFIGAN_SR = 16000 20 | HIFIGAN_HOP_LENGTH = 256 21 | HIFIGAN_WIN_LENGTH = 1024 22 | HIFIGAN_MEL_CHANNELS = 80 23 | HIFIGAN_NFFT = 1024 24 | HIFIGAN_MAX_FREQ = 8000 25 | 26 | 27 | def get_pinyin2lty(): 28 | pinyin2lty = {} 29 | with open('utils/mandarin_pinyin_to_mfa_lty.dict', 'r') as f: 30 | lines = f.readlines() 31 | 32 | for line in lines: 33 | ele = re.split(r'\t', line) 34 | 35 | ity_phones = re.split(r'[ ]+', ele[-1].strip()) 36 | pinyin2lty[ele[0]] = ity_phones 37 | 38 | return pinyin2lty 39 | 40 | 41 | class TextTokenizer: 42 | def __init__(self) -> None: 43 | 44 | self.separator = Separator(word="_", syllable="-", phone="|") 45 | self.pinyin2lty = get_pinyin2lty() 46 | 47 | def phonemize(self, text: str) -> str: 48 | text = re.sub(r'[^\w\s]+', ' ', text) # remove punctuation 49 | text = re.sub(r'[ ]+', ' ', text) # remove extra spaces 50 | text = text.lower() 51 | 52 | phonemizeds = [] 53 | for text_eng_chn in re.split(r"[^\w\s']+", text): 54 | # split chinese and english 55 | for text in re.split(r"([a-z ]+)", text_eng_chn): 56 | text = text.strip() 57 | if text == '' or text == "'": 58 | continue 59 | d = [] 60 | if re.match(r"[a-z ']+", text): 61 | for word in re.split(r"[ ]+", text): 62 | phonemizeds.append(word) 63 | else: 64 | phones = [] 65 | for n, py in enumerate( 66 | pinyin( 67 | text, style=Style.TONE3, neutral_tone_with_five=True 68 | ) 69 | ): 70 | if not py[0][-1].isalnum(): 71 | raise ValueError 72 | phones.append(py[0]) 73 | phonemizeds.append(self.separator.phone.join(phones)) 74 | 75 | phonemizeds = f'{self.separator.word}'.join( 76 | [phones for phones in phonemizeds]) 77 | return phonemizeds 78 | 79 | def tokenize(self, text): 80 | phones = [] 81 | for word in re.split('([_-])', self.phonemize(text.strip())): 82 | if len(word): 83 | for phone in re.split('\|', word): 84 | if len(phone): 85 | phones.append(phone) 86 | 87 | return phones 88 | 89 | def tokenize_lty(self, pinyin_tokens): 90 | lty_tokens_list = [] 91 | 92 | for token in pinyin_tokens: 93 | if token in self.pinyin2lty.keys(): 94 | lty_tokens = self.pinyin2lty[token] 95 | lty_tokens_list += lty_tokens 96 | else: 97 | lty_tokens_list.append(token) 98 | return lty_tokens_list 99 | 100 | 101 | @dataclass 102 | class AudioFeatExtraConfig: 103 | frame_shift: Seconds = HIFIGAN_HOP_LENGTH / HIFIGAN_SR 104 | feature_dim: int = HIFIGAN_MEL_CHANNELS 105 | 106 | 107 | def extract_mel_spec(samples): 108 | mel_spec, _ = mel_spectogram( 109 | audio=samples, 110 | sample_rate=HIFIGAN_SR, 111 | hop_length=HIFIGAN_HOP_LENGTH, 112 | win_length=HIFIGAN_WIN_LENGTH, 113 | n_mels=HIFIGAN_MEL_CHANNELS, 114 | n_fft=HIFIGAN_NFFT, 115 | f_min=0, 116 | f_max=HIFIGAN_MAX_FREQ, 117 | power=1, 118 | normalized=False, 119 | min_max_energy_norm=True, 120 | norm="slaney", 121 | mel_scale="slaney", 122 | compression=True 123 | ) 124 | 125 | 
return mel_spec 126 | 127 | 128 | class MelSpecExtractor(FeatureExtractor): 129 | name = "mel_spec" 130 | config_type = AudioFeatExtraConfig 131 | 132 | @property 133 | def frame_shift(self) -> Seconds: 134 | return self.config.frame_shift 135 | 136 | def feature_dim(self, sampling_rate: int) -> int: 137 | return self.config.feature_dim 138 | 139 | def extract(self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int) -> np.ndarray: 140 | assert sampling_rate == HIFIGAN_SR 141 | if not isinstance(samples, torch.Tensor): 142 | samples = torch.from_numpy(samples) 143 | torch.set_num_threads(1) 144 | # Hifigan 145 | 146 | samples = samples.squeeze() 147 | mel_spec = extract_mel_spec(samples) 148 | 149 | duration = round(samples.shape[-1] / sampling_rate, ndigits=12) 150 | num_frames = compute_num_frames( 151 | duration=duration, 152 | frame_shift=self.frame_shift, 153 | sampling_rate=sampling_rate, 154 | ) 155 | return mel_spec.squeeze(0).permute(1, 0)[:num_frames, :].numpy() 156 | 157 | 158 | if __name__ == '__main__': 159 | tt = TextTokenizer() 160 | 161 | txt = 'Hellow你好啊,我是Simon,你叫什么名字?What is your name?' 162 | phones = tt.phonemize(txt) 163 | print(phones) 164 | print(tt.tokenize(txt)) 165 | print(tt.tokenize_lty(tt.tokenize(txt))) 166 | # assert phones == 'hellow_ni3_hao3_wo3_shi4_simon_ni3_jiao4_shen2_me5_ming2_zi4_what_is_your_name' 167 | # print(tt.tokenize(txt)) 168 | -------------------------------------------------------------------------------- /modules/transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from ctypes import Union 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from einops import rearrange 8 | 9 | from utils.utils import make_attn_mask 10 | 11 | 12 | def _get_clones(module, N): 13 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 14 | 15 | 16 | class MultiHeadAttention(nn.Module): 17 | def __init__(self, qkv_dim, n_heads=8, dropout=0.): 18 | super().__init__() 19 | 20 | assert qkv_dim % n_heads == 0 21 | self.n_heads = n_heads 22 | self.head_dim = qkv_dim // n_heads 23 | self.dropout = dropout 24 | self.qkv_dim = qkv_dim 25 | 26 | self.w_q = nn.Linear(qkv_dim, qkv_dim, bias=True) 27 | self.w_k = nn.Linear(qkv_dim, qkv_dim, bias=True) 28 | self.w_v = nn.Linear(qkv_dim, qkv_dim, bias=True) 29 | 30 | self.out_proj = nn.Sequential( 31 | nn.Linear(qkv_dim, qkv_dim), 32 | nn.Dropout(dropout), 33 | ) 34 | 35 | def forward(self, q, kv=None, mask=None): 36 | 37 | bsz, tgt_len, _ = q.size() 38 | src_len = kv.size(1) if kv is not None else tgt_len 39 | 40 | if kv is None: 41 | k = self.w_k(q) 42 | v = self.w_v(q) 43 | q = self.w_q(q) 44 | else: 45 | k = self.w_k(kv) 46 | v = self.w_v(kv) 47 | q = self.w_q(q) 48 | 49 | q = q.view(bsz, tgt_len, self.n_heads, self.head_dim).transpose(1, 2) 50 | k = k.view(bsz, src_len, self.n_heads, self.head_dim).transpose(1, 2) 51 | v = v.view(bsz, src_len, self.n_heads, self.head_dim).transpose(1, 2) 52 | att = F.scaled_dot_product_attention( 53 | q, k, v, mask, self.dropout if self.training else 0.0, False) 54 | 55 | att = att.transpose(1, 2).contiguous().view(bsz, tgt_len, self.qkv_dim) 56 | 57 | return self.out_proj(att) 58 | 59 | class TransformerEncoderLayer(nn.Module): 60 | def __init__(self, dim, ff_dim, conv_ff=False, n_heads=8, dropout=0.): 61 | super().__init__() 62 | 63 | self.dim = dim 64 | self.conv_ff = conv_ff 65 | self.n_heads = n_heads 66 | 67 | self.norm1 = nn.LayerNorm(dim) 68 | self.norm2 = 
nn.LayerNorm(dim) 69 | 70 | self.attn = MultiHeadAttention(dim, n_heads=n_heads, dropout=dropout) 71 | 72 | self.dropout = nn.Dropout(dropout) 73 | 74 | if conv_ff: 75 | self.ff = nn.Sequential( 76 | nn.Conv1d(dim, ff_dim, kernel_size=5, padding=2), 77 | nn.ReLU(), 78 | nn.Conv1d(ff_dim, dim, kernel_size=5, padding=2), 79 | ) 80 | else: 81 | self.ff = nn.Sequential( 82 | nn.Linear(dim, ff_dim), 83 | nn.ReLU(), 84 | self.dropout, 85 | nn.Linear(ff_dim, dim), 86 | ) 87 | 88 | def forward( 89 | self, 90 | x: torch.Tensor, 91 | mask: torch.Tensor = None 92 | ): 93 | 94 | x = x + self.attn(self.norm1(x), mask=mask) 95 | if self.conv_ff: 96 | x = self.norm2(x) 97 | x = rearrange(x, 'B T D -> B D T') 98 | x = x + self.ff(x) 99 | x = rearrange(x, 'B D T -> B T D') 100 | else: 101 | x = x + self.ff(self.norm2(x)) 102 | return x 103 | 104 | 105 | class TransformerEncoder(nn.Module): 106 | def __init__( 107 | self, 108 | encoder_layer: TransformerEncoderLayer, 109 | num_layers: int, 110 | norm=None 111 | ): 112 | super().__init__() 113 | 114 | self.layers = _get_clones(encoder_layer, num_layers) 115 | self.num_layers = num_layers 116 | self.norm = norm 117 | 118 | def forward( 119 | self, 120 | x: torch.Tensor, 121 | x_lens: torch.Tensor = None, 122 | causal: bool = False 123 | ) -> torch.Tensor: 124 | 125 | if x_lens is not None: 126 | mask = make_attn_mask(x_lens, self.layers[0].n_heads, causal=causal) 127 | else: 128 | mask = None 129 | for layer in self.layers: 130 | x = layer(x, mask=mask) 131 | if self.norm is not None: 132 | x = self.norm(x) 133 | return x 134 | 135 | 136 | def test(): 137 | 138 | x = torch.zeros([3, 7, 128]).to('cuda') 139 | context = torch.zeros([3, 20, 128]).to('cuda') 140 | 141 | x_lens = torch.Tensor([3, 4, 7]).to('cuda').to(torch.int32) 142 | context_lens = torch.Tensor([11, 12, 20]).to('cuda').to(torch.int32) 143 | 144 | encoder = TransformerEncoder( 145 | TransformerEncoderLayer( 146 | 128, 147 | 4 * 128, 148 | n_heads=8, 149 | dropout=0.1, 150 | conv_ff=True 151 | ), 152 | 12, 153 | nn.LayerNorm(128) 154 | ).to('cuda') 155 | 156 | 157 | context = encoder(context, x_lens=context_lens) 158 | print(context.shape) 159 | -------------------------------------------------------------------------------- /modules/vqpe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from modules.convnet import ConvNetDouble 6 | from modules.tokenizer import ( 7 | HIFIGAN_MEL_CHANNELS 8 | ) 9 | from modules.quantization import ResidualVectorQuantizer 10 | 11 | from einops import rearrange 12 | 13 | class VQProsodyEncoder(nn.Module): 14 | def __init__( 15 | self, 16 | mel_bins: int = HIFIGAN_MEL_CHANNELS, 17 | stride:int = 8, 18 | hidden_size: int = 384, 19 | kernel_size: int = 5, 20 | n_layers: int = 3, 21 | n_stacks: int = 5, 22 | n_blocks: int = 2, 23 | vq_bins: int = 1024, 24 | vq_dim: int = 256, 25 | activation: str = 'ReLU', 26 | ): 27 | super(VQProsodyEncoder, self).__init__() 28 | self.stride = stride 29 | self.mel_bins = mel_bins 30 | 31 | self.convnet = ConvNetDouble( 32 | in_channels=mel_bins, 33 | out_channels=vq_dim, 34 | hidden_size=hidden_size, 35 | n_layers=n_layers, 36 | n_stacks=n_stacks, 37 | n_blocks=n_blocks, 38 | middle_layer=nn.MaxPool1d(stride, ceil_mode=True), 39 | kernel_size=kernel_size, 40 | activation=activation, 41 | ) 42 | 43 | self.vq = ResidualVectorQuantizer( 44 | dimension=vq_dim, 45 | n_q=1, 46 | bins=vq_bins, 47 | decay=0.99 48 | ) 49 | 
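    # --- Added note (not in the original file): shape walk-through of forward()
    # below, assuming the defaults above (mel_bins=80, stride=8, vq_dim=256):
    #   mel (B, T, 80) -> transpose -> ConvNetDouble (MaxPool1d middle layer)
    #   downsamples time by `stride` -> ze (B, 256, ceil(T / 8)) -> single-stage
    #   RVQ -> zq, codes, commit_loss; vq_loss = MSE(ze.detach(), zq).
    #   zq is then repeated `stride` times along time and cropped back to T, so
    #   the returned prosody latent (B, T, 256) is frame-aligned with the mel.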
50 | def forward( 51 | self, 52 | mel: torch.Tensor, # (B, T, mel_bins) 53 | ): 54 | mel_len = mel.size(1) 55 | mel = mel[..., :self.mel_bins] 56 | mel = rearrange(mel, "B T D -> B D T") 57 | ze = self.convnet(mel) 58 | zq, codes, commit_loss = self.vq(ze) 59 | vq_loss = F.mse_loss(ze.detach(), zq) 60 | zq = rearrange(zq, "B D T -> B T D").unsqueeze(2).contiguous().expand(-1, -1, self.stride, -1) 61 | zq = rearrange(zq, "B T S D -> B (T S) D")[:, :mel_len, :] 62 | return zq, commit_loss, vq_loss, codes 63 | 64 | def test(): 65 | model = VQProsodyEncoder() 66 | mel = torch.randn(2, 303, 80) 67 | zq, commit_loss, vq_loss = model(mel) 68 | 69 | print(zq.shape, commit_loss.shape, vq_loss.shape) 70 | 71 | -------------------------------------------------------------------------------- /prepare_ds.py: -------------------------------------------------------------------------------- 1 | ''' 2 | wavs dir 3 | ├── speaker1 4 | │ ├── s1wav1.wav 5 | │ ├── s1wav1.txt 6 | │ ├── s1wav2.wav 7 | │ ├── s1wav2.txt 8 | │ ├── ... 9 | ├── speaker2 10 | │ ├── s2wav1.wav 11 | │ ├── s2wav1.txt 12 | │ ├── ... 13 | 14 | cautions: stage 0 will delete all txt files in wavs dir 15 | ''' 16 | 17 | import os 18 | 19 | import glob 20 | from modules.tokenizer import TextTokenizer 21 | from multiprocessing import Pool 22 | from tqdm.auto import tqdm 23 | from utils.textgrid import read_textgrid 24 | 25 | import argparse 26 | 27 | from lhotse import validate_recordings_and_supervisions, CutSet, NumpyHdf5Writer, load_manifest_lazy, load_manifest 28 | from lhotse.audio import Recording, RecordingSet 29 | from lhotse.supervision import SupervisionSegment, SupervisionSet 30 | from lhotse.recipes.utils import read_manifests_if_cached 31 | from lhotse.utils import Seconds, compute_num_frames 32 | 33 | from functools import partial 34 | 35 | from modules.tokenizer import ( 36 | HIFIGAN_SR, 37 | HIFIGAN_HOP_LENGTH, 38 | MelSpecExtractor, 39 | AudioFeatExtraConfig 40 | ) 41 | from models.megatts2 import MegaG 42 | from modules.datamodule import TTSDataset, make_spk_cutset 43 | 44 | from utils.symbol_table import SymbolTable 45 | 46 | import soundfile as sf 47 | import librosa 48 | 49 | import torch 50 | import numpy as np 51 | 52 | def make_lab(tt, wav): 53 | id = wav.split('/')[-1].split('.')[0] 54 | folder = '/'.join(wav.split('/')[:-1]) 55 | # Create lab files 56 | with open(f'{folder}/{id}.txt', 'r') as f: 57 | txt = f.read() 58 | 59 | with open(f'{folder}/{id}.lab', 'w') as f: 60 | f.write(' '.join(tt.tokenize(txt))) 61 | 62 | 63 | class DatasetMaker: 64 | def __init__(self): 65 | parser = argparse.ArgumentParser() 66 | 67 | parser.add_argument('--stage', type=int, default=0, 68 | help='Stage to start from') 69 | parser.add_argument('--wavtxt_path', type=str, 70 | default='data/wavs/', help='Path to wav and txt files') 71 | parser.add_argument('--text_grid_path', type=str, 72 | default='data/textgrids/', help='Path to textgrid files') 73 | parser.add_argument('--ds_path', type=str, 74 | default='data/ds/', help='Path to save dataset') 75 | parser.add_argument('--num_workers', type=int, 76 | default=4, help='Number of workers') 77 | parser.add_argument('--test_set_ratio', type=float, 78 | default=0.03, help='Test set ratio') 79 | parser.add_argument('--trim_wav', type=bool, 80 | default=False, help='Trim wav by textgrid') 81 | parser.add_argument('--generator_ckpt', type=str, 82 | default='generator.ckpt', help='Load generator checkpoint') 83 | parser.add_argument('--generator_config', type=str, 84 | 
                             default='configs/config_gan.yaml', help='Load generator config')
85 | 
86 |         self.args = parser.parse_args()
87 | 
88 |         self.test_set_interval = int(1 / self.args.test_set_ratio)
89 | 
90 |     def make_labs(self):
91 |         wavs = glob.glob(f'{self.args.wavtxt_path}/**/*.wav', recursive=True)
92 |         tt = TextTokenizer()
93 | 
94 |         with Pool(self.args.num_workers) as p:
95 |             for _ in tqdm(p.imap(partial(make_lab, tt), wavs), total=len(wavs)):
96 |                 pass
97 | 
98 |     def make_ds(self):
99 |         tgs = glob.glob(
100 |             f'{self.args.text_grid_path}/**/*.TextGrid', recursive=True)
101 | 
102 |         recordings = [[], []]  # train, valid
103 |         supervisions = [[], []]
104 |         set_name = ['train', 'valid']
105 |         max_duration_token = 0
106 | 
107 |         for i, tg in tqdm(enumerate(tgs)):
108 |             id = tg.split('/')[-1].split('.')[0]
109 |             speaker = tg.split('/')[-2]
110 | 
111 |             intervals = [i for i in read_textgrid(tg) if (i[3] == 'phones')]
112 | 
113 |             y, sr = librosa.load(
114 |                 f'{self.args.wavtxt_path}/{speaker}/{id}.wav', sr=HIFIGAN_SR)
115 | 
116 |             if intervals[0][2] == '':
117 |                 intervals = intervals[1:]
118 |             if intervals[-1][2] == '':
119 |                 intervals = intervals[:-1]
120 |             if self.args.trim_wav:
121 |                 start = intervals[0][0]*sr
122 |                 stop = intervals[-1][1]*sr
123 |                 y = y[int(start):int(stop)]
124 |                 y = librosa.util.normalize(y)
125 | 
126 |                 sf.write(
127 |                     f'{self.args.wavtxt_path}/{speaker}/{id}.wav', y, HIFIGAN_SR)
128 | 
129 |             start = intervals[0][0]
130 |             stop = intervals[-1][1]
131 | 
132 |             frame_shift = HIFIGAN_HOP_LENGTH / HIFIGAN_SR
133 |             duration = round(y.shape[-1] / HIFIGAN_SR, ndigits=12)
134 |             n_frames = compute_num_frames(
135 |                 duration=duration,
136 |                 frame_shift=frame_shift,
137 |                 sampling_rate=HIFIGAN_SR,
138 |             )
139 | 
140 |             duration_tokens = []
141 |             phone_tokens = []
142 | 
143 |             for interval in intervals:
144 | 
145 |                 phone_stop = (interval[1] - start)
146 |                 n_frame_interval = int(phone_stop / frame_shift)
147 |                 duration_tokens.append(n_frame_interval - sum(duration_tokens))
148 |                 phone_tokens.append(interval[2] if interval[2] != '' else '')
149 | 
150 |             if sum(duration_tokens) > n_frames:
151 |                 print(f'{id} duration_tokens: {sum(duration_tokens)} must <= n_frames: {n_frames}')
152 |                 assert False
153 | 
154 |             recording = Recording.from_file(
155 |                 f'{self.args.wavtxt_path}/{speaker}/{id}.wav')
156 |             text = open(
157 |                 f'{self.args.wavtxt_path}/{speaker}/{id}.txt', 'r').read()
158 |             segment = SupervisionSegment(
159 |                 id=id,
160 |                 recording_id=id,
161 |                 start=0,
162 |                 duration=recording.duration,
163 |                 channel=0,
164 |                 language="CN",
165 |                 speaker=speaker,
166 |                 text=text,
167 |             )
168 | 
169 |             if abs(recording.duration - (stop - start)) > 0.01:
170 |                 print(f'{id} recording duration: {recording.duration} != {stop - start}')
171 |                 assert False
172 | 
173 |             set_id = 0 if i % self.test_set_interval else 1
174 |             recordings[set_id].append(recording)
175 |             supervisions[set_id].append(segment)
176 | 
177 |             segment.custom = {}
178 |             segment.custom['duration_tokens'] = duration_tokens
179 |             segment.custom['phone_tokens'] = phone_tokens
180 | 
181 |             max_duration_token = max(max_duration_token, len(duration_tokens))
182 | 
183 |             assert len(duration_tokens) == len(phone_tokens)
184 | 
185 |         for i in range(2):
186 |             recording_set = RecordingSet.from_recordings(recordings[i])
187 |             supervision_set = SupervisionSet.from_segments(supervisions[i])
188 |             validate_recordings_and_supervisions(
189 |                 recording_set, supervision_set)
190 | 
191 |             supervision_set.to_file(
192 |                 f"{self.args.ds_path}/supervisions_{set_name[i]}.jsonl.gz")
193 | recording_set.to_file( 194 | f"{self.args.ds_path}/recordings_{set_name[i]}.jsonl.gz") 195 | 196 | # Extract features 197 | manifests = read_manifests_if_cached( 198 | dataset_parts=['train', 'valid'], 199 | output_dir=self.args.ds_path, 200 | prefix="", 201 | suffix='jsonl.gz', 202 | types=["recordings", "supervisions"], 203 | ) 204 | 205 | for partition, m in manifests.items(): 206 | cut_set = CutSet.from_manifests( 207 | recordings=m["recordings"], 208 | supervisions=m["supervisions"], 209 | ) 210 | 211 | # extract 212 | cut_set = cut_set.compute_and_store_features( 213 | extractor=MelSpecExtractor(AudioFeatExtraConfig()), 214 | storage_path=f"{self.args.ds_path}/cuts_{partition}", 215 | storage_type=NumpyHdf5Writer, 216 | num_jobs=self.args.num_workers, 217 | ) 218 | 219 | cut_set.to_file( 220 | f"{self.args.ds_path}/cuts_{partition}.jsonl.gz") 221 | 222 | print(f'max_duration_token: {max_duration_token}') 223 | 224 | def extract_latent(self): 225 | 226 | os.system(f'mkdir -p {self.args.ds_path}/latents') 227 | 228 | G = MegaG.from_pretrained(dm.args.generator_ckpt, dm.args.generator_config) 229 | G = G.cuda() 230 | G.eval() 231 | 232 | cs_all = load_manifest(f'{dm.args.ds_path}/cuts_train.jsonl.gz') + load_manifest(f'{dm.args.ds_path}/cuts_valid.jsonl.gz') 233 | spk_cs = make_spk_cutset(cs_all) 234 | 235 | for spk in spk_cs.keys(): 236 | os.system(f'mkdir -p {self.args.ds_path}/latents/{spk}') 237 | 238 | ttsds = TTSDataset(spk_cs, f'{dm.args.ds_path}', 10) 239 | 240 | for c in tqdm(cs_all): 241 | id = c.recording_id 242 | spk = c.supervisions[0].speaker 243 | batch = ttsds.__getitem__(CutSet.from_cuts([c])) 244 | 245 | s2_latent = {} 246 | with torch.no_grad(): 247 | 248 | tc_latent, p_code = G.s2_latent( 249 | batch['phone_tokens'].cuda(), 250 | batch['tokens_lens'].cuda(), 251 | batch['mel_timbres'].cuda(), 252 | batch['mel_targets'].cuda() 253 | ) 254 | 255 | s2_latent['tc_latent'] = tc_latent.cpu().numpy() 256 | s2_latent['p_code'] = p_code.cpu().numpy() 257 | 258 | np.save(f'{self.args.ds_path}/latents/{spk}/{id}.npy', s2_latent) 259 | 260 | if __name__ == '__main__': 261 | dm = DatasetMaker() 262 | 263 | # Create lab files 264 | if dm.args.stage == 0: 265 | dm.make_labs() 266 | elif dm.args.stage == 1: 267 | dm.make_ds() 268 | 269 | # Test 270 | cs_train = load_manifest_lazy( 271 | f'{dm.args.ds_path}/cuts_train.jsonl.gz') 272 | cs_valid = load_manifest_lazy( 273 | f'{dm.args.ds_path}/cuts_valid.jsonl.gz') 274 | cs = cs_train + cs_valid 275 | 276 | unique_symbols = set() 277 | 278 | for c in tqdm(cs): 279 | unique_symbols.update(c.supervisions[0].custom["phone_tokens"]) 280 | 281 | unique_phonemes = SymbolTable() 282 | for s in sorted(list(unique_symbols)): 283 | unique_phonemes.add(s) 284 | 285 | unique_phonemes_file = f"unique_text_tokens.k2symbols" 286 | unique_phonemes.to_file(f'{dm.args.ds_path}/{unique_phonemes_file}') 287 | 288 | print(cs.describe()) 289 | elif dm.args.stage == 2: 290 | dm.extract_latent() 291 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.0+cu118 2 | torchaudio==2.1.0+cu118 3 | torchvision==0.16.0+cu118 4 | lightning==2.1.2 5 | lhotse==1.17.0 6 | h5py -------------------------------------------------------------------------------- /utils/distrib.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Torch distributed utilities.""" 8 | 9 | import typing as tp 10 | 11 | import torch 12 | 13 | 14 | def rank(): 15 | if torch.distributed.is_initialized(): 16 | return torch.distributed.get_rank() 17 | else: 18 | return 0 19 | 20 | 21 | def world_size(): 22 | if torch.distributed.is_initialized(): 23 | return torch.distributed.get_world_size() 24 | else: 25 | return 1 26 | 27 | 28 | def is_distributed(): 29 | return world_size() > 1 30 | 31 | 32 | def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): 33 | if is_distributed(): 34 | return torch.distributed.all_reduce(tensor, op) 35 | 36 | 37 | def _is_complex_or_float(tensor): 38 | return torch.is_floating_point(tensor) or torch.is_complex(tensor) 39 | 40 | 41 | def _check_number_of_params(params: tp.List[torch.Tensor]): 42 | # utility function to check that the number of params in all workers is the same, 43 | # and thus avoid a deadlock with distributed all reduce. 44 | if not is_distributed() or not params: 45 | return 46 | tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) 47 | all_reduce(tensor) 48 | if tensor.item() != len(params) * world_size(): 49 | # If not all the workers have the same number, for at least one of them, 50 | # this inequality will be verified. 51 | raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, " 52 | "at least one worker has a different one.") 53 | 54 | 55 | def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0): 56 | """Broadcast the tensors from the given parameters to all workers. 57 | This can be used to ensure that all workers have the same model to start with. 58 | """ 59 | if not is_distributed(): 60 | return 61 | tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] 62 | _check_number_of_params(tensors) 63 | handles = [] 64 | for tensor in tensors: 65 | handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True) 66 | handles.append(handle) 67 | for handle in handles: 68 | handle.wait() 69 | 70 | 71 | def sync_buffer(buffers, average=True): 72 | """ 73 | Sync grad for buffers. If average is False, broadcast instead of averaging. 74 | """ 75 | if not is_distributed(): 76 | return 77 | handles = [] 78 | for buffer in buffers: 79 | if torch.is_floating_point(buffer.data): 80 | if average: 81 | handle = torch.distributed.all_reduce( 82 | buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True) 83 | else: 84 | handle = torch.distributed.broadcast( 85 | buffer.data, src=0, async_op=True) 86 | handles.append((buffer, handle)) 87 | for buffer, handle in handles: 88 | handle.wait() 89 | if average: 90 | buffer.data /= world_size 91 | 92 | 93 | def sync_grad(params): 94 | """ 95 | Simpler alternative to DistributedDataParallel, that doesn't rely 96 | on any black magic. For simple models it can also be as fast. 97 | Just call this on your model parameters after the call to backward! 
98 | """ 99 | if not is_distributed(): 100 | return 101 | handles = [] 102 | for p in params: 103 | if p.grad is not None: 104 | handle = torch.distributed.all_reduce( 105 | p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True) 106 | handles.append((p, handle)) 107 | for p, handle in handles: 108 | handle.wait() 109 | p.grad.data /= world_size() 110 | 111 | 112 | def average_metrics(metrics: tp.Dict[str, float], count=1.): 113 | """Average a dictionary of metrics across all workers, using the optional 114 | `count` as unnormalized weight. 115 | """ 116 | if not is_distributed(): 117 | return metrics 118 | keys, values = zip(*metrics.items()) 119 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 120 | tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) 121 | tensor *= count 122 | all_reduce(tensor) 123 | averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() 124 | return dict(zip(keys, averaged)) 125 | -------------------------------------------------------------------------------- /utils/symbol_table.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) 2 | # 3 | # See ../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from dataclasses import dataclass 18 | from dataclasses import field 19 | from typing import Dict 20 | from typing import Generic 21 | from typing import List 22 | from typing import Optional 23 | from typing import TypeVar 24 | from typing import Union 25 | 26 | Symbol = TypeVar('Symbol') 27 | 28 | 29 | # Disable __repr__ otherwise it could freeze e.g. Jupyter. 30 | @dataclass(repr=False) 31 | class SymbolTable(Generic[Symbol]): 32 | '''SymbolTable that maps symbol IDs, found on the FSA arcs to 33 | actual objects. These objects can be arbitrary Python objects 34 | that can serve as keys in a dictionary (i.e. they need to be 35 | hashable and immutable). 36 | 37 | The SymbolTable can only be read to/written from disk if the 38 | symbols are strings. 39 | ''' 40 | _id2sym: Dict[int, Symbol] = field(default_factory=dict) 41 | '''Map an integer to a symbol. 42 | ''' 43 | 44 | _sym2id: Dict[Symbol, int] = field(default_factory=dict) 45 | '''Map a symbol to an integer. 46 | ''' 47 | 48 | _next_available_id: int = 1 49 | '''A helper internal field that helps adding new symbols 50 | to the table efficiently. 51 | ''' 52 | 53 | eps: Symbol = '' 54 | '''Null symbol, always mapped to index 0. 
55 | ''' 56 | 57 | def __post_init__(self): 58 | for idx, sym in self._id2sym.items(): 59 | assert self._sym2id[sym] == idx 60 | assert idx >= 0 61 | 62 | for sym, idx in self._sym2id.items(): 63 | assert idx >= 0 64 | assert self._id2sym[idx] == sym 65 | 66 | if 0 not in self._id2sym: 67 | self._id2sym[0] = self.eps 68 | self._sym2id[self.eps] = 0 69 | else: 70 | assert self._id2sym[0] == self.eps 71 | assert self._sym2id[self.eps] == 0 72 | 73 | self._next_available_id = max(self._id2sym) + 1 74 | 75 | @staticmethod 76 | def from_str(s: str) -> 'SymbolTable': 77 | '''Build a symbol table from a string. 78 | 79 | The string consists of lines. Every line has two fields separated 80 | by space(s), tab(s) or both. The first field is the symbol and the 81 | second the integer id of the symbol. 82 | 83 | Args: 84 | s: 85 | The input string with the format described above. 86 | Returns: 87 | An instance of :class:`SymbolTable`. 88 | ''' 89 | id2sym: Dict[int, str] = dict() 90 | sym2id: Dict[str, int] = dict() 91 | 92 | for line in s.split('\n'): 93 | fields = line.split() 94 | if len(fields) == 0: 95 | continue # skip empty lines 96 | assert len(fields) == 2, \ 97 | f'Expect a line with 2 fields. Given: {len(fields)}' 98 | sym, idx = fields[0], int(fields[1]) 99 | assert sym not in sym2id, f'Duplicated symbol {sym}' 100 | assert idx not in id2sym, f'Duplicated id {idx}' 101 | id2sym[idx] = sym 102 | sym2id[sym] = idx 103 | 104 | eps = id2sym.get(0, '') 105 | 106 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps) 107 | 108 | @staticmethod 109 | def from_file(filename: str) -> 'SymbolTable': 110 | '''Build a symbol table from file. 111 | 112 | Every line in the symbol table file has two fields separated by 113 | space(s), tab(s) or both. The following is an example file: 114 | 115 | .. code-block:: 116 | 117 | 0 118 | a 1 119 | b 2 120 | c 3 121 | 122 | Args: 123 | filename: 124 | Name of the symbol table file. Its format is documented above. 125 | 126 | Returns: 127 | An instance of :class:`SymbolTable`. 128 | 129 | ''' 130 | with open(filename, 'r', encoding='utf-8') as f: 131 | return SymbolTable.from_str(f.read().strip()) 132 | 133 | def to_str(self) -> str: 134 | ''' 135 | Returns: 136 | Return a string representation of this object. You can pass 137 | it to the method ``from_str`` to recreate an identical object. 138 | ''' 139 | s = '' 140 | for idx, symbol in sorted(self._id2sym.items()): 141 | s += f'{symbol} {idx}\n' 142 | return s 143 | 144 | def to_file(self, filename: str): 145 | '''Serialize the SymbolTable to a file. 146 | 147 | Every line in the symbol table file has two fields separated by 148 | space(s), tab(s) or both. The following is an example file: 149 | 150 | .. code-block:: 151 | 152 | 0 153 | a 1 154 | b 2 155 | c 3 156 | 157 | Args: 158 | filename: 159 | Name of the symbol table file. Its format is documented above. 160 | ''' 161 | with open(filename, 'w') as f: 162 | for idx, symbol in sorted(self._id2sym.items()): 163 | print(symbol, idx, file=f) 164 | 165 | def add(self, symbol: Symbol, index: Optional[int] = None) -> int: 166 | '''Add a new symbol to the SymbolTable. 167 | 168 | Args: 169 | symbol: 170 | The symbol to be added. 171 | index: 172 | Optional int id to which the symbol should be assigned. 173 | If it is not available, a ValueError will be raised. 174 | 175 | Returns: 176 | The int id to which the symbol has been assigned. 177 | ''' 178 | # Already in the table? Return its ID. 
179 | if symbol in self._sym2id: 180 | return self._sym2id[symbol] 181 | # Specific ID not provided - use next available. 182 | if index is None: 183 | index = self._next_available_id 184 | # Specific ID provided but not available. 185 | if index in self._id2sym: 186 | raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - " 187 | f"already occupied by {self._id2sym[index]}") 188 | self._sym2id[symbol] = index 189 | self._id2sym[index] = symbol 190 | 191 | # Update next available ID if needed 192 | if self._next_available_id <= index: 193 | self._next_available_id = index + 1 194 | 195 | return index 196 | 197 | def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]: 198 | '''Get a symbol for an id or get an id for a symbol 199 | 200 | Args: 201 | k: 202 | If it is an id, it tries to find the symbol corresponding 203 | to the id; if it is a symbol, it tries to find the id 204 | corresponding to the symbol. 205 | 206 | Returns: 207 | An id or a symbol depending on the given `k`. 208 | ''' 209 | if isinstance(k, int): 210 | return self._id2sym[k] 211 | else: 212 | return self._sym2id[k] 213 | 214 | def merge(self, other: 'SymbolTable') -> 'SymbolTable': 215 | '''Create a union of two SymbolTables. 216 | Raises an AssertionError if the same IDs are occupied by 217 | different symbols. 218 | 219 | Args: 220 | other: 221 | A symbol table to merge with ``self``. 222 | 223 | Returns: 224 | A new symbol table. 225 | ''' 226 | self._check_compatible(other) 227 | 228 | id2sym = {**self._id2sym, **other._id2sym} 229 | sym2id = {**self._sym2id, **other._sym2id} 230 | 231 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps) 232 | 233 | def _check_compatible(self, other: 'SymbolTable') -> None: 234 | # Epsilon compatibility 235 | assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \ 236 | f'{self.eps} != {other.eps}' 237 | # IDs compatibility 238 | common_ids = set(self._id2sym).intersection(other._id2sym) 239 | for idx in common_ids: 240 | assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \ 241 | f'self[idx] = "{self[idx]}", ' \ 242 | f'other[idx] = "{other[idx]}"' 243 | # Symbols compatibility 244 | common_symbols = set(self._sym2id).intersection(other._sym2id) 245 | for sym in common_symbols: 246 | assert self[sym] == other[sym], f'ID conflict for id: {sym}, ' \ 247 | f'self[sym] = "{self[sym]}", ' \ 248 | f'other[sym] = "{other[sym]}"' 249 | 250 | def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]: 251 | return self.get(item) 252 | 253 | def __contains__(self, item: Union[int, Symbol]) -> bool: 254 | if isinstance(item, int): 255 | return item in self._id2sym 256 | else: 257 | return item in self._sym2id 258 | 259 | def __len__(self) -> int: 260 | return len(self._id2sym) 261 | 262 | def __eq__(self, other: 'SymbolTable') -> bool: 263 | if len(self) != len(other): 264 | return False 265 | 266 | for s in self.symbols: 267 | if self[s] != other[s]: 268 | return False 269 | 270 | return True 271 | 272 | @property 273 | def ids(self) -> List[int]: 274 | '''Returns a list of integer IDs corresponding to the symbols. 275 | ''' 276 | ans = list(self._id2sym.keys()) 277 | ans.sort() 278 | return ans 279 | 280 | @property 281 | def symbols(self) -> List[Symbol]: 282 | '''Returns a list of symbols (e.g., strings) corresponding to 283 | the integer IDs. 
284 | ''' 285 | ans = list(self._sym2id.keys()) 286 | ans.sort() 287 | return ans 288 | -------------------------------------------------------------------------------- /utils/textgrid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | """ 3 | Copy from https://github.com/kylerbrown/textgrid/blob/master/textgrid.py 4 | """ 5 | 6 | from collections import namedtuple 7 | 8 | Entry = namedtuple("Entry", ["start", 9 | "stop", 10 | "name", 11 | "tier"]) 12 | 13 | def read_textgrid(filename, fileEncoding="utf-8"): 14 | """ 15 | Reads a TextGrid file into a dictionary object 16 | each dictionary has the following keys: 17 | "start" 18 | "stop" 19 | "name" 20 | "tier" 21 | 22 | Points and intervals use the same format, 23 | but the value for "start" and "stop" are the same 24 | 25 | Optionally, supply fileEncoding as argument. This defaults to "utf-8", tested with 'utf-16-be'. 26 | """ 27 | if isinstance(filename, str): 28 | with open(filename, "r", encoding=fileEncoding) as f: 29 | content = _read(f) 30 | elif hasattr(filename, "readlines"): 31 | content = _read(filename) 32 | else: 33 | raise TypeError("filename must be a string or a readable buffer") 34 | 35 | interval_lines = [i for i, line in enumerate(content) 36 | if line.startswith("intervals [") 37 | or line.startswith("points [")] 38 | # tier_lines, tiers = [(i, line.split('"')[-2]) 39 | # for i, line in enumerate(content) 40 | # if line.startswith("name =")] 41 | tier_lines = [] 42 | tiers = [] 43 | for i, line in enumerate(content): 44 | if line.startswith("name ="): 45 | tier_lines.append(i) 46 | tiers.append(line.split('"')[-2]) 47 | 48 | interval_tiers = _find_tiers(interval_lines, tier_lines, tiers) 49 | assert len(interval_lines) == len(interval_tiers) 50 | return [_build_entry(i, content, t) for i, t in zip(interval_lines, interval_tiers)] 51 | 52 | 53 | def _find_tiers(interval_lines, tier_lines, tiers): 54 | tier_pairs = zip(tier_lines, tiers) 55 | cur_tline, cur_tier = next(tier_pairs) 56 | next_tline, next_tier = next(tier_pairs, (None, None)) 57 | tiers = [] 58 | for il in interval_lines: 59 | if next_tline is not None and il > next_tline: 60 | cur_tline, cur_tier = next_tline, next_tier 61 | next_tline, next_tier = next(tier_pairs, (None, None)) 62 | tiers.append(cur_tier) 63 | return tiers 64 | 65 | 66 | def _read(f): 67 | return [x.strip() for x in f.readlines()] 68 | 69 | def write_csv(textgrid_list, filename=None, sep=",", header=True, save_gaps=False, meta=True): 70 | """ 71 | Writes a list of textgrid dictionaries to a csv file. 72 | If no filename is specified, csv is printed to standard out. 
73 | """ 74 | columns = list(Entry._fields) 75 | if filename: 76 | f = open(filename, "w") 77 | if header: 78 | hline = sep.join(columns) 79 | if filename: 80 | f.write(hline + "\n") 81 | else: 82 | print(hline) 83 | for entry in textgrid_list: 84 | if entry.name or save_gaps: # skip unlabeled intervals 85 | row = sep.join(str(x) for x in list(entry)) 86 | if filename: 87 | f.write(row + "\n") 88 | else: 89 | print(row) 90 | if filename: 91 | f.flush() 92 | f.close() 93 | if meta: 94 | with open(filename + ".meta", "w") as metaf: 95 | metaf.write("""---\nunits: s\ndatatype: 1002\n""") 96 | 97 | def _build_entry(i, content, tier): 98 | """ 99 | takes the ith line that begin an interval and returns 100 | a dictionary of values 101 | """ 102 | start = _get_float_val(content[i + 1]) # addition is cheap typechecking 103 | if content[i].startswith("intervals ["): 104 | offset = 1 105 | else: 106 | offset = 0 # for "point" objects 107 | stop = _get_float_val(content[i + 1 + offset]) 108 | label = _get_str_val(content[i + 2 + offset]) 109 | return Entry(start=start, stop=stop, name=label, tier=tier) 110 | 111 | 112 | def _get_float_val(string): 113 | """ 114 | returns the last word in a string as a float 115 | """ 116 | return float(string.split()[-1]) 117 | 118 | 119 | def _get_str_val(string): 120 | """ 121 | returns the last item in quotes from a string 122 | """ 123 | return string.split('"')[-2] 124 | 125 | 126 | def textgrid2csv(): 127 | import argparse 128 | parser = argparse.ArgumentParser(description="convert a TextGrid file to a CSV.") 129 | parser.add_argument("TextGrid", 130 | help="a TextGrid file to process") 131 | parser.add_argument("-o", "--output", help="(optional) outputfile") 132 | parser.add_argument("--sep", help="separator to use in CSV output", 133 | default=",") 134 | parser.add_argument("--noheader", help="no header for the CSV", 135 | action="store_false") 136 | parser.add_argument("--savegaps", help="preserves intervals with no label", 137 | action="store_true") 138 | args = parser.parse_args() 139 | tgrid = read_textgrid(args.TextGrid) 140 | write_csv(tgrid, args.output, args.sep, args.noheader, args.savegaps) 141 | 142 | 143 | if __name__ == "__main__": 144 | textgrid2csv() 145 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from matplotlib import pyplot as plt 6 | 7 | import numpy as np 8 | 9 | from typing import Any, Dict, Tuple, Union 10 | 11 | 12 | def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: 13 | assert lengths.ndim == 1, lengths.ndim 14 | max_len = max(max_len, lengths.max()) 15 | n = lengths.size(0) 16 | seq_range = torch.arange(0, max_len, device=lengths.device) 17 | expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) 18 | return expaned_lengths >= lengths.unsqueeze(-1) 19 | 20 | 21 | def make_attn_mask(lengths: torch.Tensor, num_heads: int, causal: False) -> torch.Tensor: 22 | 23 | key_padding_mask = make_pad_mask(lengths) 24 | 25 | bsz = key_padding_mask.size(0) 26 | seq_len = key_padding_mask.size(1) 27 | 28 | key_padding_mask = key_padding_mask.view(bsz, 1, 1, seq_len).expand(-1, num_heads, -1, -1) 29 | 30 | if causal: 31 | assert seq_len == lengths.max(), "Causal mask requires all lengths to be equal to max_len" 32 | causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, 
                                            device=key_padding_mask.device), diagonal=1)
33 |         causal_mask = causal_mask.view(1, 1, seq_len, seq_len).expand(bsz, num_heads, -1, -1)
34 |         causal_mask = causal_mask.logical_or(key_padding_mask)
35 |         return causal_mask.float().masked_fill(causal_mask, float("-inf"))
36 |     else:
37 |         key_padding_mask_float = key_padding_mask.float()
38 |         key_padding_mask_float = key_padding_mask_float.masked_fill(key_padding_mask, float("-inf"))
39 |         return key_padding_mask_float
40 | 
41 | def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
42 |     """
43 |     Save a matplotlib figure to a numpy array.
44 | 
45 |     Args:
46 |         fig (Figure): Matplotlib figure object.
47 | 
48 |     Returns:
49 |         ndarray: Numpy array representing the figure.
50 |     """
51 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
52 |     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
53 |     return data
54 | 
55 | def plot_spectrogram_to_numpy(spec_target: np.ndarray, spec_output: np.ndarray) -> np.ndarray:
56 |     """
57 |     Plot target and output spectrograms and convert the figure to a numpy array.
58 | 
59 |     Args:
60 |         spec_target, spec_output (ndarray): Target and output spectrogram data.
61 | 
62 |     Returns:
63 |         ndarray: Numpy array representing the plotted spectrograms.
64 |     """
65 | 
66 |     fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 12))
67 |     ax1.set_title("Target")
68 |     im = ax1.imshow(spec_target.astype(np.float32), aspect="auto", origin="lower", interpolation="none")
69 |     plt.colorbar(im, ax=ax1)
70 |     plt.xlabel("Frames")
71 |     plt.ylabel("Channels")
72 | 
73 |     ax2.set_title("Output")
74 |     im = ax2.imshow(spec_output.astype(np.float32), aspect="auto", origin="lower", interpolation="none")
75 |     plt.colorbar(im, ax=ax2)
76 |     plt.xlabel("Frames")
77 |     plt.ylabel("Channels")
78 | 
79 |     plt.tight_layout()
80 | 
81 |     fig.canvas.draw()
82 |     data = save_figure_to_numpy(fig)
83 |     plt.close()
84 |     return data
85 | 
86 | def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
87 |     """Instantiates a class with the given args and init.
88 | 
89 |     Args:
90 |         args: Positional arguments required for instantiation.
91 |         init: Dict of the form {"class_path":...,"init_args":...}.
92 | 
93 |     Returns:
94 |         The instantiated class object.
95 |     """
96 |     kwargs = init.get("init_args", {})
97 |     if not isinstance(args, tuple):
98 |         args = (args,)
99 |     class_module, class_name = init["class_path"].rsplit(".", 1)
100 |     module = __import__(class_module, fromlist=[class_name])
101 |     args_class = getattr(module, class_name)
102 |     return args_class(*args, **kwargs)
103 | 
104 | 
105 | if __name__ == "__main__":
106 |     lengths = torch.tensor([3, 5])
107 | 
108 |     b = torch.ones(3, 5, 5)
109 |     m = make_attn_mask(lengths, 1, True)
110 |     print(m)
111 |     print(b + m)
112 | 
--------------------------------------------------------------------------------
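
`instantiate_class` in `utils/utils.py` builds an object from a dict with `"class_path"` and `"init_args"` keys, the layout that Lightning-style YAML configs are parsed into. The following is a minimal usage sketch, not part of the repository; the `torch.nn.Linear` target and the argument values are purely illustrative:

```python
from utils.utils import instantiate_class

# Mirrors the {"class_path": ..., "init_args": ...} dicts produced by
# Lightning-style YAML configs; the target class here is only an example.
init = {
    "class_path": "torch.nn.Linear",
    "init_args": {"out_features": 8},
}

# Positional constructor arguments come first; a single non-tuple value is
# wrapped into a 1-tuple internally, so this resolves to nn.Linear(16, out_features=8).
layer = instantiate_class(16, init)
print(layer)  # Linear(in_features=16, out_features=8, bias=True)
```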